Skip to content

Commit 75b5042

Browse files
authored
Validate that state-key is unique when using multiple callbacks of the same type (#15634)
1 parent 61c1f69 commit 75b5042

File tree

5 files changed

+58
-20
lines changed

5 files changed

+58
-20
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3535

3636
- The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300))
3737

38+
- The Trainer now raises an error if it is given multiple stateful callbacks of the same type with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634))
39+
40+
3841
### Fixed
3942

4043
- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))

src/pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pytorch_lightning.callbacks.timer import Timer
3535
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3636
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
37+
from pytorch_lightning.utilities.model_helpers import is_overridden
3738
from pytorch_lightning.utilities.rank_zero import rank_zero_info
3839

3940
_log = logging.getLogger(__name__)
@@ -82,6 +83,7 @@ def on_trainer_init(
8283
self._configure_fault_tolerance_callbacks()
8384

8485
self.trainer.callbacks.extend(_configure_external_callbacks())
86+
_validate_callbacks_list(self.trainer.callbacks)
8587

8688
# push all model checkpoint callbacks to the end
8789
# it is important that these are the last callbacks to run
@@ -290,3 +292,18 @@ def _configure_external_callbacks() -> List[Callback]:
290292
)
291293
external_callbacks.extend(callbacks_list)
292294
return external_callbacks
295+
296+
297+
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
298+
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
299+
seen_callbacks = set()
300+
for callback in stateful_callbacks:
301+
if callback.state_key in seen_callbacks:
302+
raise RuntimeError(
303+
f"Found more than one stateful callback of type `{type(callback).__name__}`. In the current"
304+
" configuration, this callback does not support being saved alongside other instances of the same type."
305+
f" Please consult the documentation of `{type(callback).__name__}` regarding valid settings for"
306+
" the callback state to be checkpointable."
307+
" HINT: The `callback.state_key` must be unique among all callbacks in the Trainer."
308+
)
309+
seen_callbacks.add(callback.state_key)

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -981,8 +981,8 @@ def assert_checkpoint_log_dir(idx):
981981
def test_configure_model_checkpoint(tmpdir):
982982
"""Test all valid and invalid ways a checkpoint callback can be passed to the Trainer."""
983983
kwargs = dict(default_root_dir=tmpdir)
984-
callback1 = ModelCheckpoint()
985-
callback2 = ModelCheckpoint()
984+
callback1 = ModelCheckpoint(monitor="foo")
985+
callback2 = ModelCheckpoint(monitor="bar")
986986

987987
# no callbacks
988988
trainer = Trainer(enable_checkpointing=False, callbacks=[], **kwargs)

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
from unittest.mock import Mock
1818

19+
import pytest
1920
import torch
2021

2122
from pytorch_lightning import Callback, LightningModule, Trainer
@@ -36,8 +37,8 @@
3637

3738
def test_checkpoint_callbacks_are_last(tmpdir):
3839
"""Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
39-
checkpoint1 = ModelCheckpoint(tmpdir)
40-
checkpoint2 = ModelCheckpoint(tmpdir)
40+
checkpoint1 = ModelCheckpoint(tmpdir, monitor="foo")
41+
checkpoint2 = ModelCheckpoint(tmpdir, monitor="bar")
4142
model_summary = ModelSummary()
4243
early_stopping = EarlyStopping(monitor="foo")
4344
lr_monitor = LearningRateMonitor()
@@ -179,7 +180,8 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
179180
cb_connector._attach_model_callbacks()
180181
return trainer
181182

182-
early_stopping = EarlyStopping(monitor="foo")
183+
early_stopping1 = EarlyStopping(monitor="red")
184+
early_stopping2 = EarlyStopping(monitor="blue")
183185
progress_bar = TQDMProgressBar()
184186
lr_monitor = LearningRateMonitor()
185187
grad_accumulation = GradientAccumulationScheduler({1: 1})
@@ -189,40 +191,40 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
189191
assert trainer.callbacks == [trainer.accumulation_scheduler]
190192

191193
# callbacks of different types
192-
trainer = _attach_callbacks(trainer_callbacks=[early_stopping], model_callbacks=[progress_bar])
193-
assert trainer.callbacks == [early_stopping, trainer.accumulation_scheduler, progress_bar]
194+
trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar])
195+
assert trainer.callbacks == [early_stopping1, trainer.accumulation_scheduler, progress_bar]
194196

195197
# same callback type twice, different instance
196198
trainer = _attach_callbacks(
197-
trainer_callbacks=[progress_bar, EarlyStopping(monitor="foo")],
198-
model_callbacks=[early_stopping],
199+
trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")],
200+
model_callbacks=[early_stopping1],
199201
)
200-
assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]
202+
assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping1]
201203

202204
# multiple callbacks of the same type in trainer
203205
trainer = _attach_callbacks(
204206
trainer_callbacks=[
205207
LearningRateMonitor(),
206-
EarlyStopping(monitor="foo"),
208+
EarlyStopping(monitor="yellow"),
207209
LearningRateMonitor(),
208-
EarlyStopping(monitor="foo"),
210+
EarlyStopping(monitor="black"),
209211
],
210-
model_callbacks=[early_stopping, lr_monitor],
212+
model_callbacks=[early_stopping1, lr_monitor],
211213
)
212-
assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
214+
assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping1, lr_monitor]
213215

214216
# multiple callbacks of the same type, in both trainer and model
215217
trainer = _attach_callbacks(
216218
trainer_callbacks=[
217219
LearningRateMonitor(),
218220
progress_bar,
219-
EarlyStopping(monitor="foo"),
221+
EarlyStopping(monitor="yellow"),
220222
LearningRateMonitor(),
221-
EarlyStopping(monitor="foo"),
223+
EarlyStopping(monitor="black"),
222224
],
223-
model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
225+
model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2],
224226
)
225-
assert trainer.callbacks == [progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping]
227+
assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2]
226228

227229

228230
def test_attach_model_callbacks_override_info(caplog):
@@ -296,3 +298,19 @@ def _make_entry_point_query_mock(callback_factory):
296298
import_path = "pkg_resources.iter_entry_points"
297299
with mock.patch(import_path, query_mock):
298300
yield
301+
302+
303+
def test_validate_unique_callback_state_key():
304+
"""Test that we raise an error if the state keys collide, leading to missing state in the checkpoint."""
305+
306+
class MockCallback(Callback):
307+
@property
308+
def state_key(self):
309+
return "same_key"
310+
311+
def state_dict(self):
312+
# pretend these callbacks are stateful by overriding the `state_dict` hook
313+
return {"state": 1}
314+
315+
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
316+
Trainer(callbacks=[MockCallback(), MockCallback()])

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,9 +924,9 @@ def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks():
924924
"""Test that a warning is raised if best ckpt callback is used for evaluation configured with multiple
925925
checkpoints."""
926926

927-
ckpt_callback1 = ModelCheckpoint()
927+
ckpt_callback1 = ModelCheckpoint(monitor="foo")
928928
ckpt_callback1.best_model_path = "foo_best_model.ckpt"
929-
ckpt_callback2 = ModelCheckpoint()
929+
ckpt_callback2 = ModelCheckpoint(monitor="bar")
930930
ckpt_callback2.best_model_path = "bar_best_model.ckpt"
931931
trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2])
932932
trainer.state.fn = TrainerFn.TESTING

0 commit comments

Comments
 (0)