
Commit 9020e4d

cgebbe and pre-commit-ci[bot] authored and committed
Fix: Make WandbLogger upload models from all ModelCheckpoint callbacks, not just one (#20191)
* test: add failing test using two callbacks
* fix: save all checkpoint callbacks to wandb
* chore: fix mypy
* chore: fix ruff

Co-authored-by: Christian Gebbe <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit bd9d114)
1 parent 80d1bce commit 9020e4d
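After this change, a Trainer configured with several ModelCheckpoint callbacks and log_model=True uploads checkpoints from every callback when the logger is finalized, instead of only the last callback registered. A minimal usage sketch, assuming a user-defined LightningModule (MyLitModule and the "demo" project name are hypothetical):

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# log_model=True defers checkpoint upload to logger finalization
logger = WandbLogger(project="demo", log_model=True)  # project name is an assumption

trainer = pl.Trainer(
    logger=logger,
    max_epochs=3,
    callbacks=[
        # Before this fix, only one of these two callbacks had its
        # checkpoints uploaded as W&B artifacts at the end of training.
        ModelCheckpoint(monitor="val_loss", save_top_k=2),
        ModelCheckpoint(monitor="val_acc", mode="max", save_top_k=2),
    ],
)
trainer.fit(MyLitModule())  # hypothetical LightningModule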

File tree

src/lightning/pytorch/loggers/wandb.py
tests/tests_pytorch/loggers/test_wandb.py

2 files changed: +43 -4 lines changed


src/lightning/pytorch/loggers/wandb.py

Lines changed: 5 additions & 4 deletions
@@ -322,7 +322,7 @@ def __init__(
         self._prefix = prefix
         self._experiment = experiment
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
 
         # paths are processed as strings
         if save_dir is not None:
@@ -591,7 +591,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
-            self._checkpoint_callback = checkpoint_callback
+            self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback
 
     @staticmethod
     @rank_zero_only
@@ -644,8 +644,9 @@ def finalize(self, status: str) -> None:
             # Currently, checkpoints only get logged on success
             return
         # log checkpoints as artifacts
-        if self._checkpoint_callback and self._experiment is not None:
-            self._scan_and_log_checkpoints(self._checkpoint_callback)
+        if self._experiment is not None:
+            for checkpoint_callback in self._checkpoint_callbacks.values():
+                self._scan_and_log_checkpoints(checkpoint_callback)
 
     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         import wandb
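Keying the new registry by id() in a dict, rather than appending to a list, means repeated after_save_checkpoint calls from the same callback overwrite a single entry instead of queuing duplicate uploads, while dict insertion order keeps the callbacks in first-registered order for finalize(). A standalone sketch of that behavior (the Cb stand-in class is illustrative, not from the codebase):

class Cb:
    """Stand-in for ModelCheckpoint (illustrative only)."""

callbacks: dict[int, Cb] = {}
a, b = Cb(), Cb()

# after_save_checkpoint may fire many times for the same callback instance
for cb in (a, b, a, b, a):
    callbacks[id(cb)] = cb  # keyed by id(): re-registering is a no-op

# each callback is kept exactly once, in first-seen order
assert list(callbacks.values()) == [a, b]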

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 38 additions & 0 deletions
@@ -426,6 +426,44 @@ def test_wandb_log_model(wandb_mock, tmp_path):
     )
     wandb_mock.init().log_artifact.assert_called_with(wandb_mock.Artifact(), aliases=["latest", "best"])
 
+    # Test wandb artifact with two checkpoint_callbacks
+    wandb_mock.init().log_artifact.reset_mock()
+    wandb_mock.init.reset_mock()
+    wandb_mock.Artifact.reset_mock()
+    logger = WandbLogger(save_dir=tmp_path, log_model=True)
+    logger.experiment.id = "1"
+    logger.experiment.name = "run_name"
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=3,
+        limit_train_batches=3,
+        limit_val_batches=3,
+        callbacks=[
+            ModelCheckpoint(monitor="epoch", save_top_k=2),
+            ModelCheckpoint(monitor="step", save_top_k=2),
+        ],
+    )
+    trainer.fit(model)
+    for name, val, version in [("epoch", 0, 2), ("step", 3, 3)]:
+        wandb_mock.Artifact.assert_any_call(
+            name="model-1",
+            type="model",
+            metadata={
+                "score": val,
+                "original_filename": f"epoch=0-step=3-v{version}.ckpt",
+                "ModelCheckpoint": {
+                    "monitor": name,
+                    "mode": "min",
+                    "save_last": None,
+                    "save_top_k": 2,
+                    "save_weights_only": False,
+                    "_every_n_train_steps": 0,
+                },
+            },
+        )
+    wandb_mock.init().log_artifact.assert_any_call(wandb_mock.Artifact(), aliases=["latest"])
+
 
 def test_wandb_log_model_with_score(wandb_mock, tmp_path):
     """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor."""

0 commit comments