Skip to content

Commit b220ba8

Browse files
committed
Add a trainer.ckpt_path setter for stateful loading
1 parent 2b3bba6 commit b220ba8

File tree

6 files changed

+149
-63
lines changed

6 files changed

+149
-63
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161))
1313

14+
- Added the `Trainer.ckpt_path = ...` setter to statefully set the checkpoint path to load. This can act as a replacement for the removed `Trainer(resume_from_checkpoint=...)` flag ([#16187](https://github.com/Lightning-AI/lightning/pull/16187))
15+
1416
### Removed
1517

1618
- Removed the `pytorch_lightning.lite` module in favor of `lightning_fabric` ([#15953](https://github.com/Lightning-AI/lightning/pull/15953))

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@
4848
class CheckpointConnector:
4949
def __init__(self, trainer: "pl.Trainer") -> None:
5050
self.trainer = trainer
51-
self.resume_checkpoint_path: Optional[_PATH] = None
51+
self._ckpt_path: Optional[_PATH] = None
52+
# flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter`
53+
self._user_managed: bool = False
5254
self._loaded_checkpoint: Dict[str, Any] = {}
5355

5456
@property
@@ -73,7 +75,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
7375
3. from `checkpoint_path` file if provided
7476
4. don't restore
7577
"""
76-
self.resume_checkpoint_path = checkpoint_path
78+
self._ckpt_path = checkpoint_path
7779
if not checkpoint_path:
7880
log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
7981
return
@@ -83,9 +85,41 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
8385
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
8486
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
8587

86-
def _set_ckpt_path(
87-
self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool
88-
) -> Optional[str]:
88+
def _select_ckpt_path(
89+
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
90+
) -> Optional[_PATH]:
91+
"""Called by the ``Trainer`` to select the checkpoint path source."""
92+
if self._user_managed:
93+
if ckpt_path:
94+
rank_zero_warn(
95+
f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you"
96+
f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded."
97+
)
98+
# reset the previous path
99+
self._ckpt_path = None
100+
self._user_managed = False
101+
ckpt_path = self._parse_ckpt_path(
102+
state_fn,
103+
ckpt_path,
104+
model_provided=model_provided,
105+
model_connected=model_connected,
106+
)
107+
else:
108+
ckpt_path = self._ckpt_path
109+
else:
110+
ckpt_path = self._parse_ckpt_path(
111+
state_fn,
112+
ckpt_path,
113+
model_provided=model_provided,
114+
model_connected=model_connected,
115+
)
116+
return ckpt_path
117+
118+
def _parse_ckpt_path(
119+
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
120+
) -> Optional[_PATH]:
121+
"""Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer
122+
configuration."""
89123
if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None:
90124
ckpt_path = "hpc"
91125

@@ -181,15 +215,12 @@ def resume_end(self) -> None:
181215
"""Signal the connector that all states have resumed and memory for the checkpoint object can be
182216
released."""
183217
assert self.trainer.state.fn is not None
184-
if self.resume_checkpoint_path:
185-
if self.trainer.state.fn == TrainerFn.FITTING:
186-
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
187-
elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
188-
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
189-
self.resume_checkpoint_path = None
190-
self._loaded_checkpoint = {}
218+
if self._ckpt_path:
219+
message = "Restored all states" if self.trainer.state.fn == TrainerFn.FITTING else "Loaded model weights"
220+
rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}")
191221

192-
# clear cache after restore
222+
# free memory
223+
self._loaded_checkpoint = {}
193224
torch.cuda.empty_cache()
194225

195226
# wait for all to catch up
@@ -391,9 +422,15 @@ def restore_lr_schedulers(self) -> None:
391422
for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers):
392423
config.scheduler.load_state_dict(lrs_state)
393424

394-
# ----------------------------------
395-
# PRIVATE OPS
396-
# ----------------------------------
425+
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
426+
# restore modules after setup
427+
self.resume_start(checkpoint_path)
428+
self._restore_quantization_callbacks()
429+
self.restore_model()
430+
self.restore_datamodule()
431+
if self.trainer.state.fn == TrainerFn.FITTING:
432+
# restore callback states
433+
self.restore_callbacks()
397434

398435
def dump_checkpoint(self, weights_only: bool = False) -> dict:
399436
"""Creating a model checkpoint dictionary object from various component states.

src/pytorch_lightning/trainer/trainer.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,6 @@ def __init__(
387387
# default .predict() loop
388388
self.predict_loop = PredictionLoop()
389389

390-
# set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
391-
self._ckpt_path: Optional[str] = None
392-
393390
# init callbacks
394391
# Declare attributes to be set in _callback_connector on_trainer_init
395392
self._callback_connector.on_trainer_init(
@@ -569,14 +566,13 @@ def _fit_impl(
569566
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
570567
)
571568

572-
ckpt_path = ckpt_path
573-
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
569+
ckpt_path = self._checkpoint_connector._select_ckpt_path(
574570
self.state.fn,
575571
ckpt_path,
576572
model_provided=True,
577573
model_connected=self.lightning_module is not None,
578574
)
579-
self._run(model, ckpt_path=self.ckpt_path)
575+
self._run(model, ckpt_path=ckpt_path)
580576

581577
assert self.state.stopped
582578
self.training = False
@@ -660,14 +656,10 @@ def _validate_impl(
660656
# links data to the trainer
661657
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
662658

663-
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
659+
ckpt_path = self._checkpoint_connector._select_ckpt_path(
664660
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
665661
)
666-
667-
self._validated_ckpt_path = self.ckpt_path # TODO: remove in v1.8
668-
669-
# run validate
670-
results = self._run(model, ckpt_path=self.ckpt_path)
662+
results = self._run(model, ckpt_path=ckpt_path)
671663

672664
assert self.state.stopped
673665
self.validating = False
@@ -753,14 +745,10 @@ def _test_impl(
753745
# links data to the trainer
754746
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
755747

756-
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
748+
ckpt_path = self._checkpoint_connector._select_ckpt_path(
757749
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
758750
)
759-
760-
self._tested_ckpt_path = self.ckpt_path # TODO: remove in v1.8
761-
762-
# run test
763-
results = self._run(model, ckpt_path=self.ckpt_path)
751+
results = self._run(model, ckpt_path=ckpt_path)
764752

765753
assert self.state.stopped
766754
self.testing = False
@@ -846,13 +834,10 @@ def _predict_impl(
846834
# links data to the trainer
847835
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
848836

849-
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
837+
ckpt_path = self._checkpoint_connector._select_ckpt_path(
850838
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
851839
)
852-
853-
self._predicted_ckpt_path = self.ckpt_path # TODO: remove in v1.8
854-
855-
results = self._run(model, ckpt_path=self.ckpt_path)
840+
results = self._run(model, ckpt_path=ckpt_path)
856841

857842
assert self.state.stopped
858843
self.predicting = False
@@ -913,18 +898,8 @@ def tune(
913898

914899
return result
915900

916-
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
917-
# restore modules after setup
918-
self._checkpoint_connector.resume_start(checkpoint_path)
919-
self._checkpoint_connector._restore_quantization_callbacks()
920-
self._checkpoint_connector.restore_model()
921-
self._checkpoint_connector.restore_datamodule()
922-
if self.state.fn == TrainerFn.FITTING:
923-
# restore callback states
924-
self._checkpoint_connector.restore_callbacks()
925-
926901
def _run(
927-
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
902+
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
928903
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
929904
if model._compiler_ctx is not None:
930905
supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
@@ -973,7 +948,7 @@ def _run(
973948
# check if we should delay restoring checkpoint till later
974949
if not self.strategy.restore_checkpoint_after_setup:
975950
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
976-
self._restore_modules_and_callbacks(ckpt_path)
951+
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
977952

978953
log.detail(f"{self.__class__.__name__}: configuring sharded model")
979954
self._call_configure_sharded_model() # allow user to setup in model sharded environment
@@ -1021,7 +996,7 @@ def _run(
1021996

1022997
if self.strategy.restore_checkpoint_after_setup:
1023998
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
1024-
self._restore_modules_and_callbacks(ckpt_path)
999+
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
10251000

10261001
# restore optimizers, etc.
10271002
log.detail(f"{self.__class__.__name__}: restoring training state")
@@ -1806,12 +1781,30 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
18061781
return None
18071782

18081783
@property
1809-
def ckpt_path(self) -> Optional[str]:
1784+
def ckpt_path(self) -> Optional[_PATH]:
18101785
"""Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
18111786
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
18121787
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
18131788
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
1814-
return self._ckpt_path
1789+
return self._checkpoint_connector._ckpt_path
1790+
1791+
@ckpt_path.setter
1792+
def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
1793+
"""Allows you to manage which checkpoint is loaded statefully.
1794+
1795+
Examples::
1796+
1797+
trainer = Trainer()
1798+
trainer.ckpt_path = "my/checkpoint/file.ckpt"
1799+
trainer.fit(model)
1800+
...
1801+
1802+
# you will be in charge of resetting this
1803+
trainer.ckpt_path = None
1804+
trainer.test(model)
1805+
"""
1806+
self._checkpoint_connector._ckpt_path = ckpt_path
1807+
self._checkpoint_connector._user_managed = bool(ckpt_path)
18151808

18161809
def save_checkpoint(
18171810
self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None

tests/tests_pytorch/core/test_datamodules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
230230

231231
for trainer_fn in TrainerFn:
232232
trainer.state.fn = trainer_fn
233-
trainer._restore_modules_and_callbacks(checkpoint_path)
233+
trainer._checkpoint_connector._restore_modules_and_callbacks(checkpoint_path)
234234
assert dm.my_state_dict == {"my": "state_dict"}
235235

236236

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import os
1515
from unittest import mock
16+
from unittest.mock import Mock
1617

1718
import pytest
1819
import torch
@@ -21,6 +22,7 @@
2122
from pytorch_lightning.callbacks import ModelCheckpoint
2223
from pytorch_lightning.demos.boring_classes import BoringModel
2324
from pytorch_lightning.trainer.states import TrainerFn
25+
from pytorch_lightning.utilities.migration.utils import _set_version
2426

2527

2628
def test_preloaded_checkpoint_lifecycle(tmpdir):
@@ -31,26 +33,27 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
3133

3234
connector = trainer._checkpoint_connector
3335

34-
assert not connector.resume_checkpoint_path
36+
assert not connector._ckpt_path
3537
assert not connector._loaded_checkpoint
3638

3739
connector.resume_start()
38-
assert not connector.resume_checkpoint_path
40+
assert not connector._ckpt_path
3941
assert not connector._loaded_checkpoint
4042
connector.resume_end()
41-
assert not connector.resume_checkpoint_path
43+
assert not connector._ckpt_path
4244
assert not connector._loaded_checkpoint
4345

4446
ckpt_path = trainer.checkpoint_callback.best_model_path
4547
trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
4648
connector = trainer._checkpoint_connector
4749
connector.resume_start(ckpt_path)
48-
assert connector.resume_checkpoint_path == ckpt_path
50+
assert connector._ckpt_path == ckpt_path
4951
assert connector._loaded_checkpoint
5052
assert isinstance(connector._loaded_checkpoint, dict)
5153
trainer.state.fn = TrainerFn.FITTING
5254
connector.resume_end()
53-
assert not connector.resume_checkpoint_path
55+
# not cleared until next restoration, as the user might access it through `trainer.ckpt_path`
56+
assert connector._ckpt_path == ckpt_path
5457
assert not connector._loaded_checkpoint
5558

5659

@@ -166,3 +169,54 @@ def test_loops_restore(tmpdir):
166169
if fn2 != fn:
167170
trainer_loop2 = getattr(trainer, f"{fn2}_loop")
168171
trainer_loop2.load_state_dict.assert_not_called()
172+
173+
174+
def test_stateful_trainer_ckpt_path_support(tmp_path):
175+
"""Tests support for the pattern used by NeMo's experiment manager."""
176+
model = BoringModel()
177+
178+
# dummy ckpt data
179+
ckpt_data = {"state_dict": model.state_dict(), "optimizer_states": {}, "lr_schedulers": {}}
180+
_set_version(ckpt_data, "2.0.0")
181+
182+
# save a "checkpoint"
183+
ckpt_path = tmp_path / "foo.ckpt"
184+
torch.save(ckpt_data, ckpt_path)
185+
186+
# mock model checkpoint instance that has saved a last checkpoint
187+
model_checkpoint = Mock(spec=ModelCheckpoint)
188+
last_path = tmp_path / "last.ckpt"
189+
torch.save(ckpt_data, last_path)
190+
model_checkpoint._find_last_checkpoints.return_value = {last_path}
191+
192+
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, callbacks=model_checkpoint)
193+
194+
# set the ckpt path statefully
195+
trainer.ckpt_path = ckpt_path
196+
trainer.fit(model)
197+
assert trainer.ckpt_path == ckpt_path # not automatically cleaned
198+
assert trainer._checkpoint_connector._user_managed
199+
200+
# now create a conflict by also passing `ckpt_path` as an argument to `trainer.fit`
201+
with pytest.warns(UserWarning, match="trainer.ckpt_path =.*but then you passed"):
202+
trainer.fit(model, ckpt_path="last")
203+
assert trainer.ckpt_path == last_path
204+
assert not trainer._checkpoint_connector._user_managed
205+
206+
# make the mocked model checkpoint instance also report a saved best checkpoint
207+
best_path = tmp_path / "best.ckpt"
208+
torch.save(ckpt_data, best_path)
209+
model_checkpoint.best_model_path = best_path
210+
211+
# `trainer.test` will use this over "best" if statefully set
212+
trainer.ckpt_path = ckpt_path
213+
trainer.test()
214+
assert trainer.ckpt_path == ckpt_path
215+
216+
# ckpt_path = "best" still works if it's reset
217+
trainer.ckpt_path = None
218+
# the state is cleared
219+
assert trainer._checkpoint_connector._ckpt_path is None
220+
assert not trainer._checkpoint_connector._user_managed
221+
trainer.test()
222+
assert trainer.ckpt_path == best_path

0 commit comments

Comments
 (0)