From 945fbf8c825aba3073e8073da26edbdbffe7666e Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 20 Jan 2023 13:36:22 +0200 Subject: [PATCH 1/7] Two fixes for handling edge cases in MLflow logging * Long parameter values are truncated, instead of not logged at all. * In case there are more than 100 parameters, they are logged in chunks. --- src/pytorch_lightning/loggers/mlflow.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 980d4e4bccb9e..2abe283c31e1d 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -238,18 +238,14 @@ def experiment_id(self) -> Optional[str]: def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) - params_list: List[Param] = [] - for k, v in params.items(): - # TODO: mlflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 - if len(str(v)) > 250: - rank_zero_warn( - f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning - ) - continue - params_list.append(Param(key=k, value=v)) + # Truncate parameter values to 250 characters. + # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 + params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] - self.experiment.log_batch(run_id=self.run_id, params=params_list) + # Log in chunks of 100 parameters (the maximum allowed by MLflow). 
+ for idx in range(0, len(params_list), 100): + self.experiment.log_batch(run_id=self.run_id, params=params_list[idx:idx + 100]) @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: From 1f16ec040db5cac52e121f90e4dddef80665b415 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Jan 2023 11:49:39 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/loggers/mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index 2abe283c31e1d..4b1088a6f4604 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -245,7 +245,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # Log in chunks of 100 parameters (the maximum allowed by MLflow). 
for idx in range(0, len(params_list), 100): - self.experiment.log_batch(run_id=self.run_id, params=params_list[idx:idx + 100]) + self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100]) @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: From 30839ad60560c5f099d123534a2ed9e738c1e6c8 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 20 Jan 2023 14:14:47 +0200 Subject: [PATCH 3/7] Added a unit test for logging more than 100 parameters to MLflow --- tests/tests_pytorch/loggers/test_mlflow.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 23de563270cfe..5fd06b7fa4f2a 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -227,14 +227,23 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_long_param_value(client, _, tmpdir): - """Test that the logger raises warning with special characters not accepted by MLFlow.""" + """Test that the logger doesn't crash when logging a long parameter value.""" logger = MLFlowLogger("test", save_dir=tmpdir) - value = "test" * 100 + value = "test" * 200 key = "test_param" params = {key: value} - with pytest.warns(RuntimeWarning, match=f"Discard {key}={value}"): - logger.log_hyperparams(params) + logger.log_hyperparams(params) + + +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_logger_with_many_params(client, _, tmpdir): + """Test that the logger doesn't crash when logging more than 100 parameters.""" + logger = MLFlowLogger("test", save_dir=tmpdir) + params = 
{f"test_param_{idx}": f"test_value_{idx}" for idx in range(200)} + + logger.log_hyperparams(params) @mock.patch("pytorch_lightning.loggers.mlflow.Metric") From 298caa28aba9cafd8ff94ba2722983824204ca98 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 20 Jan 2023 15:13:24 +0200 Subject: [PATCH 4/7] Param mock is needed when MLflow is not available --- tests/tests_pytorch/loggers/test_mlflow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 5fd06b7fa4f2a..fa44665826c30 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -224,6 +224,7 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir): logger.log_metrics(metrics) +@mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_long_param_value(client, _, tmpdir): @@ -236,6 +237,7 @@ def test_mlflow_logger_with_long_param_value(client, _, tmpdir): logger.log_hyperparams(params) +@mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_many_params(client, _, tmpdir): From 46ce1fe1a35e0d2a7e84be12b740c6fdf318b189 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Fri, 20 Jan 2023 15:24:03 +0200 Subject: [PATCH 5/7] Trying to get mocking to work --- tests/tests_pytorch/loggers/test_mlflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index fa44665826c30..2d43261b525cd 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -227,7 +227,7 @@ def 
test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_long_param_value(client, _, tmpdir): +def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir): """Test that the logger doesn't crash when logging a long parameter value.""" logger = MLFlowLogger("test", save_dir=tmpdir) value = "test" * 200 @@ -240,7 +240,7 @@ def test_mlflow_logger_with_long_param_value(client, _, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_many_params(client, _, tmpdir): +def test_mlflow_logger_with_many_params(client, _, param, tmpdir): """Test that the logger doesn't crash when logging more than 100 parameters.""" logger = MLFlowLogger("test", save_dir=tmpdir) params = {f"test_param_{idx}": f"test_value_{idx}" for idx in range(200)} From 57ea0be724be1d21533a835e2707a7dc52b027b6 Mon Sep 17 00:00:00 2001 From: Seppo Enarvi Date: Mon, 23 Jan 2023 11:32:22 +0200 Subject: [PATCH 6/7] Updated CHANGELOG and verified the mock calls. --- src/pytorch_lightning/CHANGELOG.md | 2 + tests/tests_pytorch/loggers/test_mlflow.py | 55 ++++++++++++---------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index a4d5a7d12fc38..e8e6431aa3a57 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -134,6 +134,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed an issue with `MLFlowLogger` logging the wrong keys with `.log_hyperparams()` ([#16418](https://github.com/Lightning-AI/lightning/pull/16418)) +- Fixed logging of more than 100 parameters with `MLFlowLogger`; long parameter values are now truncated instead of being discarded ([#16451](https://github.com/Lightning-AI/lightning/pull/16451)) + ## [1.9.0] - 2023-01-17 diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 2d43261b525cd..2c7fbaff64da2 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -224,30 +224,6 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir): logger.log_metrics(metrics) -@mock.patch("pytorch_lightning.loggers.mlflow.Param") -@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) -@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir): - """Test that the logger doesn't crash when logging a long parameter value.""" - logger = MLFlowLogger("test", save_dir=tmpdir) - value = "test" * 200 - key = "test_param" - params = {key: value} - - logger.log_hyperparams(params) - - -@mock.patch("pytorch_lightning.loggers.mlflow.Param") -@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) -@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") -def test_mlflow_logger_with_many_params(client, _, param, tmpdir): - """Test that the logger doesn't crash when logging more than 100 parameters.""" - logger = MLFlowLogger("test", save_dir=tmpdir) - params = {f"test_param_{idx}": f"test_value_{idx}" for idx in range(200)} - - logger.log_hyperparams(params) - - @mock.patch("pytorch_lightning.loggers.mlflow.Metric") @mock.patch("pytorch_lightning.loggers.mlflow.Param") @mock.patch("pytorch_lightning.loggers.mlflow.time") @@ -281,6 +257,37 @@ def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir): ) +def 
_check_value_length(value, *args, **kwargs): + assert len(value) <= 250 + + +@mock.patch("pytorch_lightning.loggers.mlflow.Param", side_effect=_check_value_length) +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir): + """Test that long parameter values are truncated to 250 characters.""" + logger = MLFlowLogger("test", save_dir=tmpdir) + + params = {"test": "test_param" * 50} + logger.log_hyperparams(params) + + # assert_called_once_with() won't properly check the parameter value. + logger.experiment.log_batch.assert_called_once() + + +@mock.patch("pytorch_lightning.loggers.mlflow.Param") +@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) +@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") +def test_mlflow_logger_with_many_params(client, _, param, tmpdir): + """Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters.""" + logger = MLFlowLogger("test", save_dir=tmpdir) + + params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)} + logger.log_hyperparams(params) + + assert logger.experiment.log_batch.call_count == 2 + + @pytest.mark.parametrize( "status,expected", [ From c3782878ddc42191801bedf4b51d056e362ad5a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Jan 2023 09:37:19 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/loggers/test_mlflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 2c7fbaff64da2..14879ed10480e 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ 
b/tests/tests_pytorch/loggers/test_mlflow.py @@ -279,7 +279,8 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_many_params(client, _, param, tmpdir): - """Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters.""" + """Test that when logging more than 100 parameters, they will be split into batches of at most 100 + parameters.""" logger = MLFlowLogger("test", save_dir=tmpdir) params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}