From b220ba88f8efaf0b997397880ab4aef267464bed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 5 Jan 2023 11:50:56 +0100 Subject: [PATCH] Add a `trainer.ckpt_path` setter for stateful loading --- src/pytorch_lightning/CHANGELOG.md | 2 + .../connectors/checkpoint_connector.py | 69 ++++++++++++++----- src/pytorch_lightning/trainer/trainer.py | 69 +++++++++---------- tests/tests_pytorch/core/test_datamodules.py | 2 +- .../connectors/test_checkpoint_connector.py | 64 +++++++++++++++-- tests/tests_pytorch/trainer/test_trainer.py | 6 +- 6 files changed, 149 insertions(+), 63 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 888cdc20e31c1..44e24e7f5a8b0 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161)) +- Added the `Trainer.ckpt_path = ...` setter to statefully set the checkpoint path to load. This can act as a replacement for the removed `Trainer(resume_from_checkpoint=...)` flag ([#16187](https://github.com/Lightning-AI/lightning/pull/16187)) + ### Removed - Removed the `pytorch_lightning.lite` module in favor of `lightning_fabric` ([#15953](https://github.com/Lightning-AI/lightning/pull/15953)) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fbec09f002837..08b7bdad05df2 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -48,7 +48,9 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.trainer = trainer - self.resume_checkpoint_path: Optional[_PATH] = None + self._ckpt_path: Optional[_PATH] = None + # flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter` + self._user_managed: bool = False self._loaded_checkpoint: Dict[str, Any] = {} @property @@ -73,7 +75,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: 3. from `checkpoint_path` file if provided 4. don't restore """ - self.resume_checkpoint_path = checkpoint_path + self._ckpt_path = checkpoint_path if not checkpoint_path: log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.") return @@ -83,9 +85,41 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) - def _set_ckpt_path( - self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool - ) -> Optional[str]: + def _select_ckpt_path( + self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool + ) -> Optional[_PATH]: + """Called by the ``Trainer`` to select the checkpoint path source.""" + if self._user_managed: + if ckpt_path: + rank_zero_warn( + f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you" + f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded." + ) + # reset the previous path + self._ckpt_path = None + self._user_managed = False + ckpt_path = self._parse_ckpt_path( + state_fn, + ckpt_path, + model_provided=model_provided, + model_connected=model_connected, + ) + else: + ckpt_path = self._ckpt_path + else: + ckpt_path = self._parse_ckpt_path( + state_fn, + ckpt_path, + model_provided=model_provided, + model_connected=model_connected, + ) + return ckpt_path + + def _parse_ckpt_path( + self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool + ) -> Optional[_PATH]: + """Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer + configuration.""" if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None: ckpt_path = "hpc" @@ -181,15 +215,12 @@ def resume_end(self) -> None: """Signal the connector that all states have resumed and memory for the checkpoint object can be released.""" assert self.trainer.state.fn is not None - if self.resume_checkpoint_path: - if self.trainer.state.fn == TrainerFn.FITTING: - rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") - elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING): - rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}") - self.resume_checkpoint_path = None - self._loaded_checkpoint = {} + if self._ckpt_path: + message = "Restored all states" if self.trainer.state.fn == TrainerFn.FITTING else "Loaded model weights" + rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}") - # clear cache after restore + # free memory + self._loaded_checkpoint = {} torch.cuda.empty_cache() # wait for all to catch up @@ -391,9 +422,15 @@ def restore_lr_schedulers(self) -> None: for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): config.scheduler.load_state_dict(lrs_state) - # ---------------------------------- - # PRIVATE OPS - # ---------------------------------- + def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: + # restore modules after setup + self.resume_start(checkpoint_path) + self._restore_quantization_callbacks() + self.restore_model() + self.restore_datamodule() + if self.trainer.state.fn == TrainerFn.FITTING: + # restore callback states + self.restore_callbacks() def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index ed642f6189a89..f94bbb81155d9 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -387,9 +387,6 @@ def __init__( # default .predict() loop self.predict_loop = PredictionLoop() - # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`. - self._ckpt_path: Optional[str] = None - # init callbacks # Declare attributes to be set in _callback_connector on_trainer_init self._callback_connector.on_trainer_init( @@ -569,14 +566,13 @@ def _fit_impl( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) - ckpt_path = ckpt_path - self._ckpt_path = self._checkpoint_connector._set_ckpt_path( + ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=True, model_connected=self.lightning_module is not None, ) - self._run(model, ckpt_path=self.ckpt_path) + self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.training = False @@ -660,14 +656,10 @@ def _validate_impl( # links data to the trainer self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - self._ckpt_path = self._checkpoint_connector._set_ckpt_path( + ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - - self._validated_ckpt_path = self.ckpt_path # TODO: remove in v1.8 - - # run validate - results = self._run(model, ckpt_path=self.ckpt_path) + results = self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.validating = False @@ -753,14 +745,10 @@ def _test_impl( # links data to the trainer self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - self._ckpt_path = self._checkpoint_connector._set_ckpt_path( + ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - - self._tested_ckpt_path = self.ckpt_path # TODO: remove in v1.8 - - # run test - results = self._run(model, ckpt_path=self.ckpt_path) + results = self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.testing = False @@ -846,13 +834,10 @@ def _predict_impl( # links data to the trainer self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - self._ckpt_path = self._checkpoint_connector._set_ckpt_path( + ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - - self._predicted_ckpt_path = self.ckpt_path # TODO: remove in v1.8 - - results = self._run(model, ckpt_path=self.ckpt_path) + results = self._run(model, ckpt_path=ckpt_path) assert self.state.stopped self.predicting = False @@ -913,18 +898,8 @@ def tune( return result - def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: - # restore modules after setup - self._checkpoint_connector.resume_start(checkpoint_path) - self._checkpoint_connector._restore_quantization_callbacks() - self._checkpoint_connector.restore_model() - self._checkpoint_connector.restore_datamodule() - if self.state.fn == TrainerFn.FITTING: - # restore callback states - self._checkpoint_connector.restore_callbacks() - def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[str] = None + self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: if model._compiler_ctx is not None: supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy] @@ -973,7 +948,7 @@ def _run( # check if we should delay restoring checkpoint till later if not self.strategy.restore_checkpoint_after_setup: log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") - self._restore_modules_and_callbacks(ckpt_path) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) log.detail(f"{self.__class__.__name__}: configuring sharded model") self._call_configure_sharded_model() # allow user to setup in model sharded environment @@ -1021,7 +996,7 @@ def _run( if self.strategy.restore_checkpoint_after_setup: log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}") - self._restore_modules_and_callbacks(ckpt_path) + self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path) # restore optimizers, etc. log.detail(f"{self.__class__.__name__}: restoring training state") @@ -1806,12 +1781,30 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]: return None @property - def ckpt_path(self) -> Optional[str]: + def ckpt_path(self) -> Optional[_PATH]: """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`, :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`, :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise.""" - return self._ckpt_path + return self._checkpoint_connector._ckpt_path + + @ckpt_path.setter + def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: + """Allows you to manage which checkpoint is loaded statefully. + + Examples:: + + trainer = Trainer() + trainer.ckpt_path = "my/checkpoint/file.ckpt" + trainer.fit(model) + ... + + # you will be in charge of resetting this + trainer.ckpt_path = None + trainer.test(model) + """ + self._checkpoint_connector._ckpt_path = ckpt_path + self._checkpoint_connector._user_managed = bool(ckpt_path) def save_checkpoint( self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 53493ed8cc103..5ce1a452d6e3c 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -230,7 +230,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: for trainer_fn in TrainerFn: trainer.state.fn = trainer_fn - trainer._restore_modules_and_callbacks(checkpoint_path) + trainer._checkpoint_connector._restore_modules_and_callbacks(checkpoint_path) assert dm.my_state_dict == {"my": "state_dict"} diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 511b42e60f509..0846b7a8a7d82 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -13,6 +13,7 @@ # limitations under the License. import os from unittest import mock +from unittest.mock import Mock import pytest import torch @@ -21,6 +22,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.migration.utils import _set_version def test_preloaded_checkpoint_lifecycle(tmpdir): @@ -31,26 +33,27 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): connector = trainer._checkpoint_connector - assert not connector.resume_checkpoint_path + assert not connector._ckpt_path assert not connector._loaded_checkpoint connector.resume_start() - assert not connector.resume_checkpoint_path + assert not connector._ckpt_path assert not connector._loaded_checkpoint connector.resume_end() - assert not connector.resume_checkpoint_path + assert not connector._ckpt_path assert not connector._loaded_checkpoint ckpt_path = trainer.checkpoint_callback.best_model_path trainer = Trainer(default_root_dir=tmpdir, max_steps=2) connector = trainer._checkpoint_connector connector.resume_start(ckpt_path) - assert connector.resume_checkpoint_path == ckpt_path + assert connector._ckpt_path == ckpt_path assert connector._loaded_checkpoint assert isinstance(connector._loaded_checkpoint, dict) trainer.state.fn = TrainerFn.FITTING connector.resume_end() - assert not connector.resume_checkpoint_path + # not cleared until next restoration, as the user might access it through `trainer.ckpt_path` + assert connector._ckpt_path == ckpt_path assert not connector._loaded_checkpoint @@ -166,3 +169,54 @@ def test_loops_restore(tmpdir): if fn2 != fn: trainer_loop2 = getattr(trainer, f"{fn2}_loop") trainer_loop2.load_state_dict.assert_not_called() + + +def test_stateful_trainer_ckpt_path_support(tmp_path): + """Tests support for the pattern used by NeMo's experiment manager.""" + model = BoringModel() + + # dummy ckpt data + ckpt_data = {"state_dict": model.state_dict(), "optimizer_states": {}, "lr_schedulers": {}} + _set_version(ckpt_data, "2.0.0") + + # save a "checkpoint" + ckpt_path = tmp_path / "foo.ckpt" + torch.save(ckpt_data, ckpt_path) + + # mock model checkpoint instance that has saved a last checkpoint + model_checkpoint = Mock(spec=ModelCheckpoint) + last_path = tmp_path / "last.ckpt" + torch.save(ckpt_data, last_path) + model_checkpoint._find_last_checkpoints.return_value = {last_path} + + trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, callbacks=model_checkpoint) + + # set the ckpt path statefully + trainer.ckpt_path = ckpt_path + trainer.fit(model) + assert trainer.ckpt_path == ckpt_path # not automatically cleaned + assert trainer._checkpoint_connector._user_managed + + # now conflict with ckpt_path functionally + with pytest.warns(UserWarning, match="trainer.ckpt_path =.*but then you passed"): + trainer.fit(model, ckpt_path="last") + assert trainer.ckpt_path == last_path + assert not trainer._checkpoint_connector._user_managed + + # mock model checkpoint instance that has saved a last checkpoint + best_path = tmp_path / "best.ckpt" + torch.save(ckpt_data, best_path) + model_checkpoint.best_model_path = best_path + + # `trainer.test` will use this over "best" if statefully set + trainer.ckpt_path = ckpt_path + trainer.test() + assert trainer.ckpt_path == ckpt_path + + # ckpt_path = "best" still works if it's reset + trainer.ckpt_path = None + # the state is cleared + assert trainer._checkpoint_connector._ckpt_path is None + assert not trainer._checkpoint_connector._user_managed + trainer.test() + assert trainer.ckpt_path == best_path diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index edc161853638c..dba7b35155f5c 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -720,7 +720,7 @@ def test_checkpoint_path_input_last_fault_tolerant(tmpdir, ckpt_path, fn): final_path = "foobar" with ctxt: - ckpt_path = trainer._checkpoint_connector._set_ckpt_path( + ckpt_path = trainer._checkpoint_connector._parse_ckpt_path( fn, ckpt_path, model_provided=fn == "fit", model_connected=True ) assert ckpt_path == final_path @@ -933,7 +933,7 @@ def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks(): trainer.state.fn = TrainerFn.TESTING with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"): - trainer._checkpoint_connector._set_ckpt_path( + trainer._checkpoint_connector._parse_ckpt_path( trainer.state.fn, ckpt_path="best", model_provided=False, model_connected=True ) @@ -1701,7 +1701,7 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(): trainer = Trainer(fast_dev_run=True) trainer.state.fn = TrainerFn.TESTING with pytest.raises(ValueError, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"): - trainer._checkpoint_connector._set_ckpt_path( + trainer._checkpoint_connector._parse_ckpt_path( trainer.state.fn, ckpt_path="best", model_provided=False, model_connected=True )