Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
self.resume_start(checkpoint_path)
self.restore_model()
self.restore_datamodule()
if self.trainer.state.fn == TrainerFn.FITTING:
# restore callback states
self.restore_callbacks()
self.restore_callbacks()

def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration.utils import _set_version
Expand Down Expand Up @@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
model.load_state_dict.assert_called_once_with(ANY, strict=expected)


@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
    """Test that callbacks are properly restored in non-fit phases.

    A checkpoint is written with the callback flag set to ``True``. A fresh trainer and callback are then run
    through ``validate``/``test``/``predict`` with ``ckpt_path`` pointing at that checkpoint, and the fresh callback
    is expected to have picked the flag up from the checkpoint.

    The checkpoint connector is deliberately *not* driven by hand here: calling ``restore_callbacks()`` manually
    before the entry point would flip the flag on its own and the final assertion would pass even if the non-fit
    entry point never restored anything.

    """

    class TestCallback(Callback):
        def __init__(self):
            self.restored = False

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # Read back the entry written by `on_save_checkpoint` below, keyed by the plain class name.
            if "callbacks" in checkpoint:
                callback_state = checkpoint["callbacks"][self.__class__.__name__]
                self.restored = callback_state["restored"]

        def state_dict(self):
            return {"restored": self.restored}

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Store this callback's state under its class name so `on_load_checkpoint` can find it.
            checkpoint["callbacks"] = checkpoint.get("callbacks", {})
            checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()

    # First create and train a model with the callback
    callback = TestCallback()
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
    trainer.fit(model)

    # Set the callback state to True before saving
    callback.restored = True
    ckpt_path = tmp_path / "checkpoint.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Now create new instances and test restoration
    new_callback = TestCallback()
    new_model = BoringModel()
    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])

    # Nothing has touched the checkpoint yet, so the flag must still be at its initial value.
    assert not new_callback.restored

    # Run the evaluation phase (validate/test/predict); this call alone is responsible for the restore.
    fn = getattr(new_trainer, trainer_fn)
    fn(new_model, ckpt_path=ckpt_path)

    assert new_callback.restored  # Should be True after loading the checkpoint
Loading