Skip to content

Commit 0f79699

Browse files
committed
Fix mocks
1 parent 8131209 commit 0f79699

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

src/pytorch_lightning/loggers/mlflow.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@
3838
LOCAL_FILE_URI_PREFIX = "file:"
3939
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
4040
if _MLFLOW_AVAILABLE:
41-
import mlflow
4241
from mlflow.entities import Metric, Param
4342
from mlflow.tracking import context, MlflowClient
4443
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
4544
else:
46-
mlflow, MlflowClient, context = None, None, None
45+
MlflowClient, context = None, None
4746
Metric, Param = None, None
4847
MLFLOW_RUN_NAME = "mlflow.runName"
4948

tests/tests_pytorch/loggers/test_all.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
LOGGER_CTX_MANAGERS = (
4141
mock.patch("pytorch_lightning.loggers.comet.comet_ml"),
4242
mock.patch("pytorch_lightning.loggers.comet.CometOfflineExperiment"),
43-
mock.patch("pytorch_lightning.loggers.mlflow.mlflow"),
43+
mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True),
4444
mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
45+
mock.patch("pytorch_lightning.loggers.mlflow.Metric"),
4546
mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
4647
mock.patch("pytorch_lightning.loggers.neptune._NEPTUNE_AVAILABLE", return_value=True),
4748
mock.patch("pytorch_lightning.loggers.wandb.wandb"),

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r
3333
return logger
3434

3535

36-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
36+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
3737
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
38-
def test_mlflow_logger_exists(client, mlflow, tmpdir):
38+
def test_mlflow_logger_exists(client, _, tmpdir):
3939
"""Test launching three independent loggers with either same or different experiment name."""
4040

4141
run1 = MagicMock()
@@ -87,9 +87,9 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
8787
assert logger3.run_id == "run-id-3"
8888

8989

90-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
90+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
9191
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
92-
def test_mlflow_run_name_setting(client, mlflow, tmpdir):
92+
def test_mlflow_run_name_setting(client, _, tmpdir):
9393
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
9494

9595
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})
@@ -114,9 +114,9 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
114114
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
115115

116116

117-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
117+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
118118
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
119-
def test_mlflow_run_id_setting(client, mlflow, tmpdir):
119+
def test_mlflow_run_id_setting(client, _, tmpdir):
120120
"""Test that the run_id argument uses the provided run_id."""
121121

122122
run = MagicMock()
@@ -135,9 +135,9 @@ def test_mlflow_run_id_setting(client, mlflow, tmpdir):
135135
client.reset_mock(return_value=True)
136136

137137

138-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
138+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
139139
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
140-
def test_mlflow_log_dir(client, mlflow, tmpdir):
140+
def test_mlflow_log_dir(client, _, tmpdir):
141141
"""Test that the trainer saves checkpoints in the logger's save dir."""
142142

143143
# simulate experiment creation with mlflow client mock
@@ -165,7 +165,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
165165
def test_mlflow_logger_dirs_creation(tmpdir):
166166
"""Test that the logger creates the folders and files in the right place."""
167167
if not _MLFLOW_AVAILABLE:
168-
pytest.xfail("test for explicit file creation requires mlflow dependency to be installed.")
168+
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
169169

170170
assert not os.listdir(tmpdir)
171171
logger = MLFlowLogger("test", save_dir=tmpdir)
@@ -201,9 +201,9 @@ def on_train_epoch_end(self, *args, **kwargs):
201201
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
202202

203203

204-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
204+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
205205
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
206-
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
206+
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
207207
"""Test that the logger experiment_id retrieved only once."""
208208
logger = MLFlowLogger("test", save_dir=tmpdir)
209209
_ = logger.experiment
@@ -213,9 +213,9 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
213213

214214

215215
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
216-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
216+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
217217
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
218-
def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, tmpdir):
218+
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
219219
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
220220
logger = MLFlowLogger("test", save_dir=tmpdir)
221221
metrics = {"[some_metric]": 10}
@@ -224,9 +224,9 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, _, tmpdir):
224224
logger.log_metrics(metrics)
225225

226226

227-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
227+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
228228
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
229-
def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
229+
def test_mlflow_logger_with_long_param_value(client, _, tmpdir):
230230
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
231231
logger = MLFlowLogger("test", save_dir=tmpdir)
232232
value = "test" * 100
@@ -240,9 +240,9 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
240240
@mock.patch("pytorch_lightning.loggers.mlflow.Metric")
241241
@mock.patch("pytorch_lightning.loggers.mlflow.Param")
242242
@mock.patch("pytorch_lightning.loggers.mlflow.time")
243-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
243+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
244244
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
245-
def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmpdir):
245+
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
246246
"""Test that the logger calls methods on the mlflow experiment correctly."""
247247
time.return_value = 1
248248

@@ -268,7 +268,7 @@ def test_mlflow_logger_experiment_calls(client, mlflow, time, param, metric, tmp
268268
)
269269

270270

271-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
271+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
272272
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
273273
def test_mlflow_logger_finalize_when_exception(*_):
274274
logger = MLFlowLogger("test")
@@ -286,7 +286,7 @@ def test_mlflow_logger_finalize_when_exception(*_):
286286
logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
287287

288288

289-
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
289+
@mock.patch("pytorch_lightning.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
290290
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
291291
@pytest.mark.parametrize("log_model", ["all", True, False])
292292
def test_mlflow_log_model(client, _, tmpdir, log_model):

0 commit comments

Comments
 (0)