MLflow
An open-source platform for machine learning, distributed as a Python package.
It is meant to support the overall machine learning cycle, which goes from raw data, to data preprocessing, to training (including hyperparameter tuning), to deployment, and then starts over again.
MLflow currently consists of 4 components:
- Tracking
- Projects
- Models
- Model Registry
Tracking
The key concepts in MLflow tracking are:
- Parameters: The key-value inputs for the model and training.
- Metrics: Numerical values used for evaluation.
- Tags and Notes: Information about a run.
- Artifacts: Files and data (e.g. paths to the training and validation data) and the serialized version of the model used.
- Source: Shell scripts, Python code, and the corresponding GitHub version (commit).
- Run: An instance of code executed by MLflow, recording the information above.
- Experiment: A collection of different runs.
Before anything else, we need to set the tracking URI. Local file storage is the easiest option, although database-backed approaches exist as well.
mlflow.set_tracking_uri(os.getcwd() + "/mlruns")
Some online sources also prepend the string file:// to this path. There was a time when using Hydra meant one should use hydra.utils.get_original_cwd() instead of os.getcwd(), but this no longer seems to be needed.
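For context, a minimal sketch of that older Hydra pattern (the config path and name below are hypothetical, not from the original notes):
import os
import hydra
import mlflow
from hydra.utils import get_original_cwd

@hydra.main(config_path="conf", config_name="config")
def main(cfg):
    # Hydra used to change the working directory per run, so the original
    # directory was resolved before building the file-based tracking URI.
    original_cwd = get_original_cwd()
    mlflow.set_tracking_uri("file://" + os.path.join(original_cwd, "mlruns"))

if __name__ == "__main__":
    main()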
An illustrative example of using the Python context manager for MLflow tracking is shown below:
import mlflow

data = load_data(data_path)
model = init_model(param1=p1)
model.fit(data.train, learning_rate=lr)
score = model.score(data.test)

with mlflow.start_run():
    mlflow.log_param("data_path", data_path)
    mlflow.log_param("param1", p1)
    mlflow.log_param("learning_rate", lr)
    mlflow.log_metric("score", score)
    mlflow.sklearn.log_model(model, "model")
See Yapi's answer in this StackOverflow post on how to log figures. Namely, given a Matplotlib Figure object fig, use:
mlflow.log_figure(fig, "figure_name.png")
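Putting it together, a minimal sketch with placeholder figure data:
import matplotlib.pyplot as plt
import mlflow

with mlflow.start_run():
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [1.0, 0.5, 0.25])  # placeholder data for illustration
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    # The second argument is the file path the figure gets under the run's artifacts.
    mlflow.log_figure(fig, "figures/loss_curve.png")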
To visualize the results, open the UI via the terminal.
mlflow ui
If the UI is running on a different SSH server, use SSH tunnelling. Assuming the service is running on localhost port 5000 on the SSH server and you want to use localhost port 8000 on your local machine, do:
ssh -L 8000:localhost:5000 user@sshserver.com
If you get the SSH tunnelling error "channel 1: open failed: administratively prohibited: open failed", replace localhost with 127.0.0.1.
If you want to stop the process running on an SSH server, e.g. the one listening on port 5000, kill it via:
fuser -k 5000/tcp
Logging
To log a specific artifact (file), e.g. one named image.png saved in the current working directory, use:
mlflow.log_artifact("image.png", "Images")
Here, Images creates a new directory within the collection of artifacts, which helps organize them if many are saved for each run.
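Relatedly, a whole directory of files can be logged in one call with mlflow.log_artifacts; a minimal sketch (the plots directory is hypothetical):
import mlflow

with mlflow.start_run():
    # Logs every file under the local "plots" directory into Images/ in the run's artifacts.
    mlflow.log_artifacts("plots", artifact_path="Images")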
To use sqlite as the backend store, run the MLflow UI connected to the database via
mlflow ui --backend-store-uri=sqlite:///mlruns.db --default-artifact-root=file:mlruns
Or,
mlflow server --backend-store-uri=sqlite:///mlruns.db --default-artifact-root=file:mlruns --workers=8 --gunicorn-opts="--timeout=180"
In the Python script, set the tracking URI to be the database.
mlflow.set_tracking_uri("sqlite:///mlruns.db")
Or, assuming a running PostgreSQL instance:
mlflow.set_tracking_uri(f"postgresql://{os.environ['POSTGRES_USER']}:{os.environ['POSTGRES_PASSWORD']}@localhost:{os.environ['POSTGRES_PORT']}/{os.environ['POSTGRES_DB']}")
In that case, run the server as follows, with the appropriate environment variables set.
mlflow server --backend-store-uri postgresql://$POSTGRES_USER:$POSTGRES_PASSWORD@127.0.0.1:$POSTGRES_PORT/$POSTGRES_DB --default-artifact-root file:mlruns
Or more recently,
mlflow server --backend-store-uri $MLFLOW_DB_URL --default-artifact-root file:mlruns
Where $MLFLOW_DB_URL is sourced from .env.
See this documentation page on how to save custom models, e.g. Flax neural networks and their weights.
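As a rough sketch of the custom python_function route described there (the Flax-specific details and the params.msgpack path are assumptions, not an MLflow Flax API):
import mlflow
import mlflow.pyfunc

class FlaxWrapper(mlflow.pyfunc.PythonModel):
    """Hypothetical wrapper around a Flax module and its serialized weights."""

    def load_context(self, context):
        # context.artifacts["weights"] resolves to a local copy of the file logged below.
        with open(context.artifacts["weights"], "rb") as f:
            self.weights_bytes = f.read()  # deserialize with flax.serialization in practice

    def predict(self, context, model_input):
        # Apply the Flax model here; returning the input unchanged is a placeholder.
        return model_input

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        artifact_path="flax_model",
        python_model=FlaxWrapper(),
        artifacts={"weights": "params.msgpack"},  # assumed path to the serialized weights
    )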
Analysing results
Sometimes it's easier to analyse the experiments via a notebook rather than using the UI.
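A minimal sketch using mlflow.search_runs, which returns a pandas DataFrame; the experiment name and the parameter/metric column names below assume the earlier example:
import mlflow

mlflow.set_tracking_uri("sqlite:///mlruns.db")  # or whichever backend store was used above

# One row per run; logged parameters and metrics appear as params.* and metrics.* columns.
runs = mlflow.search_runs(experiment_names=["Default"])
print(runs[["run_id", "params.learning_rate", "metrics.score"]].head())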