Skip to content

Commit 38acba0

Browse files
Jake Schmidt, awaelchli, carmocca
authored
Batch MLFlowLogger requests (#15915)
Co-authored-by: Jake Schmidt <[email protected]> Co-authored-by: awaelchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 2577285 commit 38acba0

File tree

5 files changed

+59
-43
lines changed

5 files changed

+59
-43
lines changed

requirements/pytorch/loggers.info

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# all supported loggers. this list is here as a reference, but they are not installed in CI
22
neptune-client
33
comet-ml
4-
mlflow
4+
mlflow>=1.0.0
55
wandb

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5656
- The Trainer now raises an error if it is given multiple stateful callbacks of the same type with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634))
5757

5858

59+
- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))
60+
61+
5962
### Deprecated
6063

6164
- Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651))

src/pytorch_lightning/loggers/mlflow.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
from argparse import Namespace
2323
from pathlib import Path
2424
from time import time
25-
from typing import Any, Dict, Mapping, Optional, Union
25+
from typing import Any, Dict, List, Mapping, Optional, Union
2626

2727
import yaml
28-
from lightning_utilities.core.imports import module_available
28+
from lightning_utilities.core.imports import RequirementCache
2929
from torch import Tensor
3030
from typing_extensions import Literal
3131

@@ -36,15 +36,14 @@
3636

3737
log = logging.getLogger(__name__)
3838
LOCAL_FILE_URI_PREFIX = "file:"
39-
_MLFLOW_AVAILABLE = module_available("mlflow")
40-
try:
41-
import mlflow
39+
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
40+
if _MLFLOW_AVAILABLE:
41+
from mlflow.entities import Metric, Param
4242
from mlflow.tracking import context, MlflowClient
4343
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
44-
# todo: there seems to be still some remaining import error with Conda env
45-
except ModuleNotFoundError:
46-
_MLFLOW_AVAILABLE = False
47-
mlflow, MlflowClient, context = None, None, None
44+
else:
45+
MlflowClient, context = None, None
46+
Metric, Param = None, None
4847
MLFLOW_RUN_NAME = "mlflow.runName"
4948

5049
# before v1.1.0
@@ -147,10 +146,8 @@ def __init__(
147146
artifact_location: Optional[str] = None,
148147
run_id: Optional[str] = None,
149148
):
150-
if mlflow is None:
151-
raise ModuleNotFoundError(
152-
"You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`."
153-
)
149+
if not _MLFLOW_AVAILABLE:
150+
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
154151
super().__init__()
155152
if not tracking_uri:
156153
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
@@ -240,20 +237,25 @@ def experiment_id(self) -> Optional[str]:
240237
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
241238
params = _convert_params(params)
242239
params = _flatten_dict(params)
240+
params_list: List[Param] = []
241+
243242
for k, v in params.items():
243+
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
244244
if len(str(v)) > 250:
245245
rank_zero_warn(
246246
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
247247
)
248248
continue
249+
params_list.append(Param(key=k, value=v))
249250

250-
self.experiment.log_param(self.run_id, k, v)
251+
self.experiment.log_batch(run_id=self.run_id, params=params_list)
251252

252253
@rank_zero_only
253254
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
254255
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
255256

256257
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
258+
metrics_list: List[Metric] = []
257259

258260
timestamp_ms = int(time() * 1000)
259261
for k, v in metrics.items():
@@ -269,8 +271,9 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
269271
category=RuntimeWarning,
270272
)
271273
k = new_k
274+
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))
272275

273-
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
276+
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
274277

275278
@rank_zero_only
276279
def finalize(self, status: str = "success") -> None:

tests/tests_pytorch/loggers/test_all.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
LOGGER_CTX_MANAGERS = (
4141
mock.patch("pytorch_lightning.loggers.comet.comet_ml"),
4242
mock.patch("pytorch_lightning.loggers.comet.CometOfflineExperiment"),
43-
mock.patch("pytorch_lightning.loggers.mlflow.mlflow"),
43+
mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
4444
mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
45+
mock.patch("pytorch_lightning.loggers.mlflow.Metric"),
4546
mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
4647
mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
4748
mock.patch("pytorch_lightning.loggers.wandb.wandb"),
@@ -282,12 +283,14 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
282283
logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0)
283284

284285
# MLflow
285-
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
286-
"pytorch_lightning.loggers.mlflow.MlflowClient"
287-
):
286+
with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
287+
"pytorch_lightning.loggers.mlflow.Metric"
288+
) as Metric, mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"):
288289
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
289290
logger.log_metrics({"test": 1.0}, step=0)
290-
logger.experiment.log_metric.assert_called_once_with(ANY, "tmp-test", 1.0, ANY, 0)
291+
logger.experiment.log_batch.assert_called_once_with(
292+
run_id=ANY, metrics=[Metric(key="tmp-test", value=1.0, timestamp=ANY, step=0)]
293+
)
291294

292295
# Neptune
293296
with mock.patch("pytorch_lightning.loggers.neptune.neptune"), mock.patch(
@@ -340,7 +343,7 @@ def test_logger_default_name(tmpdir, monkeypatch):
340343
assert logger.name == "lightning_logs"
341344

342345
# MLflow
343-
with mock.patch("pytorch_lightning.loggers.mlflow.mlflow"), mock.patch(
346+
with mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
344347
"pytorch_lightning.loggers.mlflow.MlflowClient"
345348
) as mlflow_client:
346349
mlflow_client().get_experiment_by_name.return_value = None

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
3333
return logger
3434

3535

36-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
36+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
3737
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
38-
def test_mlflow_logger_exists(client, mlflow, tmpdir):
38+
def test_mlflow_logger_exists(client, _, tmpdir):
3939
"""Test launching three independent loggers with either same or different experiment name."""
4040

4141
run1 = MagicMock()
@@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
8787
assert logger3.run_id == "run-id-3"
8888

8989

90-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
90+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
9191
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
92-
def test_mlflow_run_name_setting(client, mlflow, tmpdir):
92+
def test_mlflow_run_name_setting(client, _, tmpdir):
9393
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
9494

9595
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
@@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
114114
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
115115

116116

117-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
117+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
118118
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
119-
def test_mlflow_run_id_setting(client, mlflow, tmpdir):
119+
def test_mlflow_run_id_setting(client, _, tmpdir):
120120
"""Test that the run_id argument uses the provided run_id."""
121121

122122
run = MagicMock()
@@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir):
135135
client.reset_mock(return_value=True)
136136

137137

138-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
138+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
139139
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
140-
def test_mlflow_log_dir(client, mlflow, tmpdir):
140+
def test_mlflow_log_dir(client, _, tmpdir):
141141
"""Test that the trainer saves checkpoints in the logger's save dir."""
142142

143143
# simulate experiment creation with mlflow client mock
@@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
165165
def test_mlflow_logger_dirs_creation(tmpdir):
166166
"""Test that the logger creates the folders and files in the right place."""
167167
if not _MLFLOW_AVAILABLE:
168-
pytest.xfail("test for explicit file creation requires mlflow dependency to be installed.")
168+
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
169169

170170
assert not os.listdir(tmpdir)
171171
logger = MLFlowLogger("test", save_dir=tmpdir)
@@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs):
201201
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
202202

203203

204-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
204+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
205205
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
206-
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
206+
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
207207
"""Test that the logger experiment_id retrieved only once."""
208208
logger = MLFlowLogger("test", save_dir=tmpdir)
209209
_ = logger.experiment
@@ -212,9 +212,10 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
212212
assert logger.experiment.get_experiment_by_name.call_count == 1
213213

214214

215-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
215+
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
216+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
216217
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
217-
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
218+
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
218219
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
219220
logger = MLFlowLogger("test", save_dir=tmpdir)
220221
metrics = {"[some_metric]": 10}
@@ -223,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
223224
logger.log_metrics(metrics)
224225

225226

226-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
227+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
227228
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
228-
def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
229+
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
229230
"""Test that the logger raises warning with parameter values that are too long for MLFlow."""
230231
logger = MLFlowLogger("test", save_dir=tmpdir)
231232
value = "test" * 100
@@ -236,10 +237,12 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
236237
logger.log_hyperparams(params)
237238

238239

240+
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
241+
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
239242
@mock.patch("pytorch_lightning.loggers.mlflow.time")
240-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
243+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
241244
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
242-
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
245+
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
243246
"""Test that the logger calls methods on the mlflow experiment correctly."""
244247
time.return_value = 1
245248

@@ -249,19 +252,23 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
249252
params = {"test": "test_param"}
250253
logger.log_hyperparams(params)
251254

252-
logger.experiment.log_param.assert_called_once_with(logger.run_id, "test", "test_param")
255+
logger.experiment.log_batch.assert_called_once_with(
256+
run_id=logger.run_id, params=[param(key="test", value="test_param")]
257+
)
253258

254259
metrics = {"some_metric": 10}
255260
logger.log_metrics(metrics)
256261

257-
logger.experiment.log_metric.assert_called_once_with(logger.run_id, "some_metric", 10, 1000, None)
262+
logger.experiment.log_batch.assert_called_with(
263+
run_id=logger.run_id, metrics=[metric(key="some_metric", value=10, timestamp=1000, step=0)]
264+
)
258265

259266
logger._mlflow_client.create_experiment.assert_called_once_with(
260267
name="test", artifact_location="my_artifact_location"
261268
)
262269

263270

264-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
271+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
265272
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
266273
def test_mlflow_logger_finalize_when_exception(*_):
267274
logger = MLFlowLogger("test")
@@ -279,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_):
279286
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
280287

281288

282-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
289+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
283290
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
284291
@pytest.mark.parametrize("log_model", ["all", True, False])
285292
def test_mlflow_log_model(client, _, tmpdir, log_model):

0 commit comments

Comments
 (0)