Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418))

- Fixed logging more than 100 parameters with `MLFlowLogger` and long values are truncated ([#16451](https://github.com/Lightning-AI/lightning/pull/16451))



## [1.9.0] - 2023-01-17
Expand Down
16 changes: 6 additions & 10 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,14 @@ def experiment_id(self) -> Optional[str]:
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
params = _flatten_dict(params)
params_list: List[Param] = []

for k, v in params.items():
# TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
if len(str(v)) > 250:
rank_zero_warn(
f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
)
continue
params_list.append(Param(key=k, value=v))
# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]

self.experiment.log_batch(run_id=self.run_id, params=params_list)
# Log in chunks of 100 parameters (the maximum allowed by MLflow).
for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
Expand Down
45 changes: 32 additions & 13 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,6 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
logger.log_metrics(metrics)


@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
logger = MLFlowLogger("test", save_dir=tmpdir)
value = "test" * 100
key = "test_param"
params = {key: value}

with pytest.warns(RuntimeWarning, match=f"Discard {key}={value}"):
logger.log_hyperparams(params)


@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow.time")
Expand Down Expand Up @@ -270,6 +257,38 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
)


def _check_value_length(value, *args, **kwargs):
assert len(value) <= 250


@mock.patch("pytorch_lightning.loggers.mlflow.Param", side_effect=_check_value_length)
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
"""Test that long parameter values are truncated to 250 characters."""
logger = MLFlowLogger("test", save_dir=tmpdir)

params = {"test": "test_param" * 50}
logger.log_hyperparams(params)

# assert_called_once_with() won't properly check the parameter value.
logger.experiment.log_batch.assert_called_once()


@mock.patch("pytorch_lightning.loggers.mlflow.Param")
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
"""Test that the when logging more than 100 parameters, it will be split into batches of at most 100
parameters."""
logger = MLFlowLogger("test", save_dir=tmpdir)

params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
logger.log_hyperparams(params)

assert logger.experiment.log_batch.call_count == 2


@pytest.mark.parametrize(
"status,expected",
[
Expand Down