@@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
33
33
return logger
34
34
35
35
36
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
36
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
37
37
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
38
- def test_mlflow_logger_exists (client , mlflow , tmpdir ):
38
+ def test_mlflow_logger_exists (client , _ , tmpdir ):
39
39
"""Test launching three independent loggers with either same or different experiment name."""
40
40
41
41
run1 = MagicMock ()
@@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
87
87
assert logger3 .run_id == "run-id-3"
88
88
89
89
90
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
90
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
91
91
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
92
- def test_mlflow_run_name_setting (client , mlflow , tmpdir ):
92
+ def test_mlflow_run_name_setting (client , _ , tmpdir ):
93
93
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
94
94
95
95
tags = resolve_tags ({MLFLOW_RUN_NAME : "run-name-1" })
@@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
114
114
client .return_value .create_run .assert_called_with (experiment_id = "exp-id" , tags = default_tags )
115
115
116
116
117
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
117
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
118
118
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
119
- def test_mlflow_run_id_setting (client , mlflow , tmpdir ):
119
+ def test_mlflow_run_id_setting (client , _ , tmpdir ):
120
120
"""Test that the run_id argument uses the provided run_id."""
121
121
122
122
run = MagicMock ()
@@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir):
135
135
client .reset_mock (return_value = True )
136
136
137
137
138
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
138
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
139
139
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
140
- def test_mlflow_log_dir (client , mlflow , tmpdir ):
140
+ def test_mlflow_log_dir (client , _ , tmpdir ):
141
141
"""Test that the trainer saves checkpoints in the logger's save dir."""
142
142
143
143
# simulate experiment creation with mlflow client mock
@@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
165
165
def test_mlflow_logger_dirs_creation (tmpdir ):
166
166
"""Test that the logger creates the folders and files in the right place."""
167
167
if not _MLFLOW_AVAILABLE :
168
- pytest .xfail ("test for explicit file creation requires mlflow dependency to be installed." )
168
+ pytest .skip ("test for explicit file creation requires mlflow dependency to be installed." )
169
169
170
170
assert not os .listdir (tmpdir )
171
171
logger = MLFlowLogger ("test" , save_dir = tmpdir )
@@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs):
201
201
assert os .listdir (trainer .checkpoint_callback .dirpath ) == [f"epoch=0-step={ limit_batches } .ckpt" ]
202
202
203
203
204
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
204
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
205
205
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
206
- def test_mlflow_experiment_id_retrieved_once (client , mlflow , tmpdir ):
206
+ def test_mlflow_experiment_id_retrieved_once (client , tmpdir ):
207
207
"""Test that the logger experiment_id retrieved only once."""
208
208
logger = MLFlowLogger ("test" , save_dir = tmpdir )
209
209
_ = logger .experiment
@@ -213,9 +213,9 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
213
213
214
214
215
215
@mock .patch ("pytorch_lightning.loggers.mlflow.Metric" )
216
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
216
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
217
217
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
218
- def test_mlflow_logger_with_unexpected_characters (client , mlflow , _ , tmpdir ):
218
+ def test_mlflow_logger_with_unexpected_characters (client , _ , __ , tmpdir ):
219
219
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
220
220
logger = MLFlowLogger ("test" , save_dir = tmpdir )
221
221
metrics = {"[some_metric]" : 10 }
@@ -224,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, tmpdir):
224
224
logger .log_metrics (metrics )
225
225
226
226
227
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
227
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
228
228
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
229
- def test_mlflow_logger_with_long_param_value (client , mlflow , tmpdir ):
229
+ def test_mlflow_logger_with_long_param_value (client , _ , tmpdir ):
230
230
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
231
231
logger = MLFlowLogger ("test" , save_dir = tmpdir )
232
232
value = "test" * 100
@@ -240,9 +240,9 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
240
240
@mock .patch ("pytorch_lightning.loggers.mlflow.Metric" )
241
241
@mock .patch ("pytorch_lightning.loggers.mlflow.Param" )
242
242
@mock .patch ("pytorch_lightning.loggers.mlflow.time" )
243
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
243
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
244
244
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
245
- def test_mlflow_logger_experiment_calls (client , mlflow , time , param , metric , tmpdir ):
245
+ def test_mlflow_logger_experiment_calls (client , _ , time , param , metric , tmpdir ):
246
246
"""Test that the logger calls methods on the mlflow experiment correctly."""
247
247
time .return_value = 1
248
248
@@ -268,7 +268,7 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmp
268
268
)
269
269
270
270
271
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
271
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
272
272
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
273
273
def test_mlflow_logger_finalize_when_exception (* _ ):
274
274
logger = MLFlowLogger ("test" )
@@ -286,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_):
286
286
logger .experiment .set_terminated .assert_called_once_with (logger .run_id , "FAILED" )
287
287
288
288
289
- @mock .patch ("pytorch_lightning.loggers.mlflow.mlflow" )
289
+ @mock .patch ("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE" , return_value = True )
290
290
@mock .patch ("pytorch_lightning.loggers.mlflow.MlflowClient" )
291
291
@pytest .mark .parametrize ("log_model" , ["all" , True , False ])
292
292
def test_mlflow_log_model (client , _ , tmpdir , log_model ):
0 commit comments