
Commit c0de998

Merge branch 'master' into set-global_step-default-wandb-x-axis-sync_tensorboard=True
2 parents 6bb8e56 + 71793c6 commit c0de998

File tree: 10 files changed, +134 −10 lines

.github/workflows/call-clear-cache.yml

Lines changed: 4 additions & 4 deletions
@@ -23,18 +23,18 @@ on:
 jobs:
   cron-clear:
     if: github.event_name == 'schedule' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
     with:
-      scripts-ref: v0.11.8
+      scripts-ref: v0.14.0
       dry-run: ${{ github.event_name == 'pull_request' }}
       pattern: "latest|docs"
       age-days: 7
 
   direct-clear:
     if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request'
-    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.0
     with:
-      scripts-ref: v0.11.8
+      scripts-ref: v0.14.0
       dry-run: ${{ github.event_name == 'pull_request' }}
       pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging
       age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging

.github/workflows/ci-check-md-links.yml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ on:
 
 jobs:
   check-md-links:
-    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.14.0
     with:
       config-file: ".github/markdown-links-config.json"
       base-branch: "master"

.github/workflows/ci-schema.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ on:
 
 jobs:
   check:
-    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.12.0
+    uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.14.0
     with:
       # skip azure due to the wrong schema file by MSFT
       # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607

docs/source-pytorch/visualize/loggers.rst

Lines changed: 34 additions & 0 deletions
@@ -54,3 +54,37 @@ Track and Visualize Experiments
 
 </div>
 </div>
+
+.. _mlflow_logger:
+
+MLflow Logger
+-------------
+
+The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts.
+
+Example usage:
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.pytorch.loggers import MLFlowLogger
+
+    mlf_logger = MLFlowLogger(
+        experiment_name="lightning_logs",
+        tracking_uri="file:./ml-runs",
+        checkpoint_path_prefix="my_prefix"
+    )
+    trainer = L.Trainer(logger=mlf_logger)
+
+    # Your LightningModule definition
+    class LitModel(L.LightningModule):
+        def training_step(self, batch, batch_idx):
+            # example
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+        def any_lightning_module_function_or_hook(self):
+            self.logger.experiment.whatever_ml_flow_supports(...)
+
+    # Train your model
+    model = LitModel()
+    trainer.fit(model)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -8,17 +8,30 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+-
+
+
 ### Changed
 
 - Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
 
+
+- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
+
+
 ### Removed
 
+-
+
+
 ### Fixed
 
 - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
 
 
+- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
+
+
 ## [2.5.0] - 2024-12-19
 
 ### Added
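Editor's note on the two wandb-related entries above: a minimal sketch of how a user would combine the loggers so that TensorBoard logs sync into a wandb run. It assumes, per the WandbLogger docstring, that extra keyword arguments are forwarded to wandb.init (so sync_tensorboard=True reaches wandb); the project name and save directory are hypothetical.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# Extra kwargs on WandbLogger are passed through to wandb.init;
# sync_tensorboard=True mirrors TensorBoard events into the wandb run,
# and with #20611 the wandb x-axis defaults to TensorBoard's global_step.
wandb_logger = WandbLogger(project="demo", sync_tensorboard=True)  # "demo" is a made-up project name
tb_logger = TensorBoardLogger(save_dir="tb_logs")  # hypothetical save dir

# With #20610, the WandbLogger experiment is initialized first during setup,
# so the TensorBoard writer created afterwards can be picked up by wandb.
trainer = Trainer(logger=[tb_logger, wandb_logger])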

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 4 additions & 2 deletions
@@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self):
             :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
             which also logs every checkpoint during training.
           * if ``log_model == False`` (default), no checkpoint is logged.
-
+        checkpoint_path_prefix: A string to prefix the checkpoint artifact's path.
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -121,6 +121,7 @@ def __init__(
         tags: Optional[dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
         log_model: Literal[True, False, "all"] = False,
+        checkpoint_path_prefix: str = "",
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -147,6 +148,7 @@ def __init__(
         self._artifact_location = artifact_location
         self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
         self._initialized = False
+        self._checkpoint_path_prefix = checkpoint_path_prefix
 
         from mlflow.tracking import MlflowClient
 
@@ -361,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
 
             # Artifact path on mlflow
-            artifact_path = Path(p).stem
+            artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem
 
             # Log the checkpoint
             self.experiment.log_artifact(self._run_id, p, artifact_path)
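Editor's note: a small sketch of how the last hunk composes the artifact path, using pathlib exactly as in the changed line; the helper name and checkpoint filename below are hypothetical.

from pathlib import Path

def artifact_path_for(checkpoint_path: str, prefix: str = "") -> str:
    # Mirrors `Path(self._checkpoint_path_prefix) / Path(p).stem` from the diff above.
    return str(Path(prefix) / Path(checkpoint_path).stem)

ckpt = "lightning_logs/version_0/checkpoints/epoch=1-step=6.ckpt"  # hypothetical path

print(artifact_path_for(ckpt))               # "epoch=1-step=6" -> default behaviour unchanged
print(artifact_path_for(ckpt, "my_prefix"))  # "my_prefix/epoch=1-step=6" on POSIX systems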

src/lightning/pytorch/trainer/call.py

Lines changed: 6 additions & 1 deletion
@@ -21,6 +21,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
 from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
         if isinstance(module, _DeviceDtypeModuleMixin):
             module._device = trainer.strategy.root_device
 
+    # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
+    # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
+    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
-    for logger in trainer.loggers:
+    for logger in loggers:
         if hasattr(logger, "experiment"):
             _ = logger.experiment
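Editor's note: a minimal sketch (with stand-in logger classes, not the real Lightning ones) of why that sort key moves WandbLogger to the front. sorted() is stable and False < True, so entries where `not isinstance(..., WandbLogger)` evaluates to False come first while the relative order of the remaining loggers is preserved.

class WandbLogger: ...        # stand-in classes for illustration only
class TensorBoardLogger: ...
class CSVLogger: ...

loggers = [TensorBoardLogger(), CSVLogger(), WandbLogger()]
ordered = sorted(loggers, key=lambda logger: not isinstance(logger, WandbLogger))
print([type(lg).__name__ for lg in ordered])
# ['WandbLogger', 'TensorBoardLogger', 'CSVLogger']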

tests/tests_pytorch/core/test_results.py

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 from functools import partial
 
+import pytest
 import torch
 import torch.distributed as dist
 
@@ -48,6 +49,8 @@ def result_reduce_ddp_fn(strategy):
     assert actual.item() == dist.get_world_size()
 
 
+# flaky with "process 0 terminated with signal SIGABRT"
+@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException")
 @RunIf(skip_windows=True)
 def test_result_reduce_ddp():
     spawn_launch(result_reduce_ddp_fn, [torch.device("cpu")] * 2)

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 30 additions & 0 deletions
@@ -427,3 +427,33 @@ def test_set_tracking_uri(mlflow_mock):
     mlflow_mock.set_tracking_uri.assert_not_called()
     _ = logger.experiment
     mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
+
+
+@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
+def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path):
+    """Test that the logger creates the folders and files in the right place with a prefix."""
+    client = mlflow_mock.tracking.MlflowClient
+
+    # Get model, logger, trainer and train
+    model = BoringModel()
+    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix")
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    # Checkpoint log
+    assert client.return_value.log_artifact.call_count == 2
+    # Metadata and aliases log
+    assert client.return_value.log_artifacts.call_count == 2
+
+    # Check that the prefix is used in the artifact path
+    for call in client.return_value.log_artifact.call_args_list:
+        args, _ = call
+        assert str(args[2]).startswith("my_prefix")

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 38 additions & 1 deletion
@@ -24,7 +24,7 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.cli import LightningCLI
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
 
@@ -151,6 +151,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock):
     assert logger._experiment is not None
 
 
+def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
+    wandb_experiment_called = False
+
+    def tensorboard_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        assert wandb_experiment_called
+        return mock.MagicMock()
+
+    def wandb_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        wandb_experiment_called = True
+        return mock.MagicMock()
+
+    with (
+        mock.patch.object(
+            TensorBoardLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
+        ),
+        mock.patch.object(
+            WandbLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
+        ),
+    ):
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            log_every_n_steps=1,
+            limit_train_batches=0,
+            limit_val_batches=0,
+            max_steps=1,
+            logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
+        )
+        trainer.fit(model)
+
+
 def test_wandb_pickle(wandb_mock, tmp_path):
     """Verify that pickling trainer with wandb logger works.