Commit df09370

Run on_train_epoch_end after the LM for callbacks that monitor (#16567)
1 parent 645416e · commit df09370

File tree: 5 files changed (+28, −13 lines)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Renamed the `pl.utilities.exceptions.GracefulExitException` to `SIGTERMException` ([#16501](https://github.com/Lightning-AI/lightning/pull/16501))

+- The `Callback.on_train_epoch_end` hook now runs after the `LightningModule.on_train_epoch_end` hook for instances of `EarlyStopping` and `Checkpoint` callbacks ([#16567](https://github.com/Lightning-AI/lightning/pull/16567))

 - The `LightningModule.{un}toggle_optimizer` methods no longer accept a `optimizer_idx` argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly ([#16560](https://github.com/Lightning-AI/lightning/pull/16560))
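In user terms, the changelog entry means that a metric logged inside `LightningModule.on_train_epoch_end` is now visible to `EarlyStopping` and `Checkpoint` callbacks in the same epoch. A minimal sketch of that pattern (the model, metric name, and trainer settings are illustrative, not part of this commit):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.demos.boring_classes import BoringModel


class EpochEndLoggingModel(BoringModel):
    def on_train_epoch_end(self):
        # Logged at epoch end; with this commit, EarlyStopping's epoch-end
        # check runs *after* this hook, so it sees the fresh value.
        self.log("train_epoch_metric", float(self.current_epoch))


trainer = pl.Trainer(
    max_epochs=5,
    limit_train_batches=2,
    limit_val_batches=0,  # no validation: monitor a train-epoch metric instead
    callbacks=[EarlyStopping(monitor="train_epoch_metric", check_on_train_epoch_end=True)],
)
trainer.fit(EpochEndLoggingModel())
```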

src/pytorch_lightning/loops/fit_loop.py

Lines changed: 5 additions & 1 deletion

@@ -302,8 +302,12 @@ def on_advance_end(self) -> None:
         self.epoch_progress.increment_processed()

         # call train epoch end hooks
-        self.trainer._call_callback_hooks("on_train_epoch_end")
+        # we always call callback hooks first, but here we need to make an exception for the callbacks that
+        # monitor a metric, otherwise they wouldn't be able to monitor a key logged in
+        # `LightningModule.on_train_epoch_end`
+        self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=False)
         self.trainer._call_lightning_module_hook("on_train_epoch_end")
+        self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=True)

         self.trainer._logger_connector.on_epoch_end()
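The net effect of this hunk is a three-phase dispatch at each train epoch end: non-monitoring callbacks first, then the `LightningModule` hook, then `EarlyStopping`/`Checkpoint` instances. A runnable sketch that prints the resulting order (the spy classes are hypothetical, written only to make the ordering visible):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, EarlyStopping
from pytorch_lightning.demos.boring_classes import BoringModel


class PlainCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print("1. non-monitoring Callback.on_train_epoch_end")


class OrderModel(BoringModel):
    def on_train_epoch_end(self):
        print("2. LightningModule.on_train_epoch_end")
        self.log("foo", 0.0)  # becomes visible to the monitoring pass below


class SpyEarlyStopping(EarlyStopping):
    # isinstance(cb, EarlyStopping) holds, so this runs in the second pass
    def on_train_epoch_end(self, trainer, pl_module):
        print("3. monitoring EarlyStopping.on_train_epoch_end")
        super().on_train_epoch_end(trainer, pl_module)


trainer = pl.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=0,
    enable_checkpointing=False,  # keep the default ModelCheckpoint out of the picture
    callbacks=[PlainCallback(), SpyEarlyStopping(monitor="foo", check_on_train_epoch_end=True)],
)
trainer.fit(OrderModel())
```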

src/pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 1 deletion

@@ -1173,6 +1173,7 @@ def _call_callback_hooks(
         self,
         hook_name: str,
         *args: Any,
+        monitoring_callbacks: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
@@ -1182,7 +1183,14 @@ def _call_callback_hooks(
         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = hook_name

-        for callback in self.callbacks:
+        callbacks = self.callbacks
+        if monitoring_callbacks is True:
+            # the list of "monitoring callbacks" is hard-coded to these two. we could add an API to define this
+            callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))]
+        elif monitoring_callbacks is False:
+            callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]
+
+        for callback in callbacks:
             fn = getattr(callback, hook_name)
             if callable(fn):
                 with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
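The new `monitoring_callbacks` argument is a tri-state filter: `None` keeps the full callback list, `True` keeps only `EarlyStopping`/`Checkpoint` instances, and `False` keeps everything else. A standalone sketch of that selection logic (the helper name is hypothetical; the real code inlines this inside `_call_callback_hooks`):

```python
from typing import List, Optional

from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping


def _select_callbacks(callbacks: List[Callback], monitoring_callbacks: Optional[bool] = None) -> List[Callback]:
    # None: no filtering; True: only "monitoring callbacks"; False: all the rest
    if monitoring_callbacks is True:
        return [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))]
    if monitoring_callbacks is False:
        return [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]
    return callbacks
```

Because the check uses `isinstance`, any subclass of `EarlyStopping` or `Checkpoint` (including `ModelCheckpoint`) counts as a monitoring callback; as the in-code comment notes, the pair is hard-coded rather than exposed through an API.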

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 7 additions & 7 deletions

@@ -134,7 +134,7 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec
     class ModelOverrideValidationReturn(BoringModel):
         validation_return_values = torch.tensor(loss_values)

-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             loss = self.validation_return_values[self.current_epoch]
             self.log("test_val_loss", loss)

@@ -164,7 +164,7 @@ def test_early_stopping_patience_train(
     class ModelOverrideTrainReturn(BoringModel):
         train_return_values = torch.tensor(loss_values)

-        def training_epoch_end(self, outputs):
+        def on_train_epoch_end(self):
             loss = self.train_return_values[self.current_epoch]
             self.log("train_loss", loss)

@@ -187,7 +187,7 @@ def training_epoch_end(self, outputs):
     assert trainer.current_epoch - 1 == expected_stop_epoch


-def test_pickling(tmpdir):
+def test_pickling():
     early_stopping = EarlyStopping(monitor="foo")

     early_stopping_pickled = pickle.dumps(early_stopping)

@@ -226,7 +226,7 @@ def test_early_stopping_no_val_step(tmpdir):
 )
 def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
     class CurrentModel(BoringModel):
-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             val_loss = losses[self.current_epoch]
             self.log("abc", val_loss)

@@ -252,7 +252,7 @@ def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
     expected_stop_epoch = 2

     class CurrentModel(BoringModel):
-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             val_loss = losses[self.current_epoch]
             self.log("val_loss", val_loss)

@@ -352,12 +352,12 @@ def _epoch_end(self) -> None:
         self.log("abc", torch.tensor(loss))
         self.log("cba", torch.tensor(0))

-    def training_epoch_end(self, outputs):
+    def on_train_epoch_end(self):
         if not self.early_stop_on_train:
             return
         self._epoch_end()

-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         if self.early_stop_on_train:
             return
         self._epoch_end()
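The test updates above also migrate from the `*_epoch_end(self, outputs)` hooks to the `on_*_epoch_end(self)` variants, which no longer receive the collected step outputs. A sketch of that migration for a model that still needs per-step values (class and attribute names are illustrative):

```python
import torch
from pytorch_lightning.demos.boring_classes import BoringModel


class MigratedModel(BoringModel):
    def __init__(self):
        super().__init__()
        self._val_losses = []  # cache what the `outputs` argument used to carry

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()  # BoringModel's single linear layer
        self._val_losses.append(loss)
        return loss

    def on_validation_epoch_end(self):
        # replaces `validation_epoch_end(self, outputs)`
        self.log("val_loss", torch.stack(self._val_losses).mean())
        self._val_losses.clear()
```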

tests/tests_pytorch/models/test_hooks.py

Lines changed: 6 additions & 4 deletions

@@ -548,11 +548,11 @@ def training_step(self, batch, batch_idx):
         dict(name="on_validation_model_train"),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
-        # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
+        dict(name="on_train_epoch_end"),  # before ModelCheckpoint because it's a "monitoring callback"
+        # `ModelCheckpoint.save_checkpoint` is called here
         dict(name="Callback.state_dict"),
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-        dict(name="on_train_epoch_end"),
         dict(name="Callback.on_train_end", args=(trainer, model)),
         dict(name="on_train_end"),
         dict(name="Callback.on_fit_end", args=(trainer, model)),

@@ -627,10 +627,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
+        dict(name="on_train_epoch_end"),  # before ModelCheckpoint because it's a "monitoring callback"
+        # `ModelCheckpoint.save_checkpoint` is called here
         dict(name="Callback.state_dict"),
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-        dict(name="on_train_epoch_end"),
         dict(name="Callback.on_train_end", args=(trainer, model)),
         dict(name="on_train_end"),
         dict(name="Callback.on_fit_end", args=(trainer, model)),

@@ -706,10 +707,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
         *model._train_batch(trainer, model, steps_after_reload, current_batch=1),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
+        dict(name="on_train_epoch_end"),  # before ModelCheckpoint because it's a "monitoring callback"
+        # `ModelCheckpoint.save_checkpoint` is called here
         dict(name="Callback.state_dict"),
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-        dict(name="on_train_epoch_end"),
         dict(name="Callback.on_train_end", args=(trainer, model)),
         dict(name="on_train_end"),
         dict(name="Callback.on_fit_end", args=(trainer, model)),

0 commit comments