
Commit 9c3ab07

Revert "Fix: Make WandbLogger upload models from all ModelCheckpoint callbacks, not just one (#20191)"
This reverts commit bd9d114.

1 parent 838e8b6 · commit 9c3ab07

2 files changed: 5 additions & 47 deletions


src/lightning/pytorch/loggers/wandb.py

Lines changed: 5 additions & 9 deletions
@@ -278,7 +278,6 @@ def any_lightning_module_function_or_hook(self):
         prefix: A string to put at the beginning of metric keys.
         experiment: WandB experiment object. Automatically set when creating a run.
         checkpoint_name: Name of the model checkpoint artifact being logged.
-        add_file_policy: If "mutable", copies file to tempdirectory before upload.
         \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
 
     Raises:
@@ -305,7 +304,6 @@ def __init__(
         experiment: Union["Run", "RunDisabled", None] = None,
         prefix: str = "",
         checkpoint_name: Optional[str] = None,
-        add_file_policy: Literal["mutable", "immutable"] = "mutable",
         **kwargs: Any,
     ) -> None:
         if not _WANDB_AVAILABLE:
@@ -324,8 +322,7 @@ def __init__(
         self._prefix = prefix
         self._experiment = experiment
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
-        self.add_file_policy = add_file_policy
+        self._checkpoint_callback: Optional[ModelCheckpoint] = None
 
         # paths are processed as strings
         if save_dir is not None:
@@ -594,7 +591,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
-            self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback
+            self._checkpoint_callback = checkpoint_callback
 
     @staticmethod
     @rank_zero_only
@@ -647,9 +644,8 @@ def finalize(self, status: str) -> None:
             # Currently, checkpoints only get logged on success
             return
         # log checkpoints as artifacts
-        if self._experiment is not None:
-            for checkpoint_callback in self._checkpoint_callbacks.values():
-                self._scan_and_log_checkpoints(checkpoint_callback)
+        if self._checkpoint_callback and self._experiment is not None:
+            self._scan_and_log_checkpoints(self._checkpoint_callback)
 
     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         import wandb
@@ -679,7 +675,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
             if not self._checkpoint_name:
                 self._checkpoint_name = f"model-{self.experiment.id}"
             artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
-            artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
+            artifact.add_file(p, name="model.ckpt")
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
             self.experiment.log_artifact(artifact, aliases=aliases)
             # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
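
For context on what this revert restores (not part of the diff above): with log_model=True, after_save_checkpoint() once again keeps a single callback reference, overwriting it on each call, so finalize() uploads checkpoints from only one ModelCheckpoint. A minimal sketch of an affected setup, assuming the standard lightning.pytorch imports; the fit() call is illustrative only:

# Hypothetical sketch, not code from this commit. After the revert, only the
# checkpoints of the last ModelCheckpoint to call after_save_checkpoint() are
# uploaded at the end of training, since each call overwrites the stored reference.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(log_model=True)  # checkpoints are uploaded in finalize()
trainer = Trainer(
    logger=logger,
    max_epochs=3,
    callbacks=[
        ModelCheckpoint(monitor="epoch", save_top_k=2),  # this callback's reference
        ModelCheckpoint(monitor="step", save_top_k=2),   # may be overwritten by this one
    ],
)
# trainer.fit(model)  # after fit, finalize() logs artifacts for one callback only

This single-callback behavior is what #20191 had changed by tracking every ModelCheckpoint in a dict keyed by id(); the revert returns to one Optional[ModelCheckpoint] attribute.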

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 0 additions & 38 deletions
@@ -426,44 +426,6 @@ def test_wandb_log_model(wandb_mock, tmp_path):
     )
     wandb_mock.init().log_artifact.assert_called_with(wandb_mock.Artifact(), aliases=["latest", "best"])
 
-    # Test wandb artifact with two checkpoint_callbacks
-    wandb_mock.init().log_artifact.reset_mock()
-    wandb_mock.init.reset_mock()
-    wandb_mock.Artifact.reset_mock()
-    logger = WandbLogger(save_dir=tmp_path, log_model=True)
-    logger.experiment.id = "1"
-    logger.experiment.name = "run_name"
-    trainer = Trainer(
-        default_root_dir=tmp_path,
-        logger=logger,
-        max_epochs=3,
-        limit_train_batches=3,
-        limit_val_batches=3,
-        callbacks=[
-            ModelCheckpoint(monitor="epoch", save_top_k=2),
-            ModelCheckpoint(monitor="step", save_top_k=2),
-        ],
-    )
-    trainer.fit(model)
-    for name, val, version in [("epoch", 0, 2), ("step", 3, 3)]:
-        wandb_mock.Artifact.assert_any_call(
-            name="model-1",
-            type="model",
-            metadata={
-                "score": val,
-                "original_filename": f"epoch=0-step=3-v{version}.ckpt",
-                "ModelCheckpoint": {
-                    "monitor": name,
-                    "mode": "min",
-                    "save_last": None,
-                    "save_top_k": 2,
-                    "save_weights_only": False,
-                    "_every_n_train_steps": 0,
-                },
-            },
-        )
-    wandb_mock.init().log_artifact.assert_any_call(wandb_mock.Artifact(), aliases=["latest"])
-
 
 def test_wandb_log_model_with_score(wandb_mock, tmp_path):
     """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor."""
