@@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
33
33
return logger
34
34
35
35
36
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
36
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
37
37
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
38
- def test_mlflow_logger_exists (client , mlflow , tmpdir ):
38
+ def test_mlflow_logger_exists (client , _ , tmpdir ):
39
39
"""Test launching three independent loggers with either same or different experiment name."""
40
40
41
41
run1 = MagicMock ()
@@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
87
87
assert logger3 .run_id == "run-id-3"
88
88
89
89
90
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
90
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
91
91
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
92
- def test_mlflow_run_name_setting (client , mlflow , tmpdir ):
92
+ def test_mlflow_run_name_setting (client , _ , tmpdir ):
93
93
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
94
94
95
95
tags = resolve_tags ({MLFLOW_RUN_NAME : "run-name-1" })
@@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
114
114
client .return_value .create_run .assert_called_with (experiment_id = "exp-id" , tags = default_tags )
115
115
116
116
117
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
117
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
118
118
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
119
- def test_mlflow_run_id_setting (client , mlflow , tmpdir ):
119
+ def test_mlflow_run_id_setting (client , _ , tmpdir ):
120
120
"""Test that the run_id argument uses the provided run_id."""
121
121
122
122
run = MagicMock ()
@@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir):
135
135
client .reset_mock (return_value = True )
136
136
137
137
138
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
138
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
139
139
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
140
- def test_mlflow_log_dir (client , mlflow , tmpdir ):
140
+ def test_mlflow_log_dir (client , _ , tmpdir ):
141
141
"""Test that the trainer saves checkpoints in the logger's save dir."""
142
142
143
143
# simulate experiment creation with mlflow client mock
@@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
165
165
def test_mlflow_logger_dirs_creation (tmpdir ):
166
166
"""Test that the logger creates the folders and files in the right place."""
167
167
if not _MLFLOW_AVAILABLE :
168
- pytest .xfail ("test for explicit file creation requires mlflow dependency to be installed." )
168
+ pytest .skip ("test for explicit file creation requires mlflow dependency to be installed." )
169
169
170
170
assert not os .listdir (tmpdir )
171
171
logger = MLFlowLogger ("test" , save_dir = tmpdir )
@@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs):
201
201
assert os .listdir (trainer .checkpoint_callback .dirpath ) == [f"epoch=0-step={ limit_batches } .ckpt" ]
202
202
203
203
204
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
204
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
205
205
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
206
- def test_mlflow_experiment_id_retrieved_once (client , mlflow , tmpdir ):
206
+ def test_mlflow_experiment_id_retrieved_once (client , tmpdir ):
207
207
"""Test that the logger experiment_id retrieved only once."""
208
208
logger = MLFlowLogger ("test" , save_dir = tmpdir )
209
209
_ = logger .experiment
@@ -212,9 +212,10 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
212
212
assert logger .experiment .get_experiment_by_name .call_count == 1
213
213
214
214
215
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
215
+ @mock .patch ("pytorch_lightning.loggers.mlflow.Metric" )
216
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
216
217
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
217
- def test_mlflow_logger_with_unexpected_characters (client , mlflow , tmpdir ):
218
+ def test_mlflow_logger_with_unexpected_characters (client , _ , __ , tmpdir ):
218
219
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
219
220
logger = MLFlowLogger ("test" , save_dir = tmpdir )
220
221
metrics = {"[some_metric]" : 10 }
@@ -223,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
223
224
logger .log_metrics (metrics )
224
225
225
226
226
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
227
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
227
228
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
228
- def test_mlflow_logger_with_long_param_value (client , mlflow , tmpdir ):
229
+ def test_mlflow_logger_with_long_param_value (client , _ , tmpdir ):
229
230
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
230
231
logger = MLFlowLogger ("test" , save_dir = tmpdir )
231
232
value = "test" * 100
@@ -236,10 +237,12 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
236
237
logger .log_hyperparams (params )
237
238
238
239
240
+ @mock .patch ("pytorch_lightning.loggers.mlflow.Metric" )
241
+ @mock .patch ("pytorch_lightning.loggers.mlflow.Param" )
239
242
@mock .patch ("pytorch_lightning.loggers.mlflow.time" )
240
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
243
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
241
244
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
242
- def test_mlflow_logger_experiment_calls (client , mlflow , time , tmpdir ):
245
+ def test_mlflow_logger_experiment_calls (client , _ , time , param , metric , tmpdir ):
243
246
"""Test that the logger calls methods on the mlflow experiment correctly."""
244
247
time .return_value = 1
245
248
@@ -249,19 +252,23 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
249
252
params = {"test" : "test_param" }
250
253
logger .log_hyperparams (params )
251
254
252
- logger .experiment .log_param .assert_called_once_with (logger .run_id , "test" , "test_param" )
255
+ logger .experiment .log_batch .assert_called_once_with (
256
+ run_id = logger .run_id , params = [param (key = "test_param" , value = "test_param" )]
257
+ )
253
258
254
259
metrics = {"some_metric" : 10 }
255
260
logger .log_metrics (metrics )
256
261
257
- logger .experiment .log_metric .assert_called_once_with (logger .run_id , "some_metric" , 10 , 1000 , None )
262
+ logger .experiment .log_batch .assert_called_with (
263
+ run_id = logger .run_id , metrics = [metric (key = "some_metric" , value = 10 , timestamp = 1000 , step = 0 )]
264
+ )
258
265
259
266
logger ._mlflow_client .create_experiment .assert_called_once_with (
260
267
name = "test" , artifact_location = "my_artifact_location"
261
268
)
262
269
263
270
264
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
271
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
265
272
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
266
273
def test_mlflow_logger_finalize_when_exception (* _ ):
267
274
logger = MLFlowLogger ("test" )
@@ -279,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_):
279
286
logger .experiment .set_terminated .assert_called_once_with (logger .run_id , "FAILED" )
280
287
281
288
282
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
289
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
283
290
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
284
291
@pytest .mark .parametrize ("log_model" , ["all" , True , False ])
285
292
def test_mlflow_log_model (client , _ , tmpdir , log_model ):
0 commit comments