2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161))

- Added the `Trainer.ckpt_path = ...` setter to statefully set the checkpoint path to load. This can act as a replacement for the removed `Trainer(resume_from_checkpoint=...)` flag ([#16187](https://github.com/Lightning-AI/lightning/pull/16187))

### Removed

- Removed the `pytorch_lightning.lite` module in favor of `lightning_fabric` ([#15953](https://github.com/Lightning-AI/lightning/pull/15953))
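For context, a minimal usage sketch of the new setter (illustrative; it mirrors the `Examples::` block added to the `ckpt_path` setter docstring further down in this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(fast_dev_run=True)

# statefully select which checkpoint the next fit/validate/test/predict call loads
trainer.ckpt_path = "my/checkpoint/file.ckpt"
trainer.fit(model)

# the path is not cleared automatically; reset it yourself before the next call
trainer.ckpt_path = None
trainer.test(model)
```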
69 changes: 53 additions & 16 deletions src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -48,7 +48,9 @@
class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer") -> None:
self.trainer = trainer
self.resume_checkpoint_path: Optional[_PATH] = None
self._ckpt_path: Optional[_PATH] = None
# flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter`
self._user_managed: bool = False
self._loaded_checkpoint: Dict[str, Any] = {}

@property
@@ -73,7 +75,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
3. from `checkpoint_path` file if provided
4. don't restore
"""
self.resume_checkpoint_path = checkpoint_path
self._ckpt_path = checkpoint_path
if not checkpoint_path:
log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
return
@@ -83,9 +85,41 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)

def _set_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[str], model_provided: bool, model_connected: bool
) -> Optional[str]:
def _select_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
) -> Optional[_PATH]:
"""Called by the ``Trainer`` to select the checkpoint path source."""
if self._user_managed:
if ckpt_path:
rank_zero_warn(
f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you"
f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded."
)
# reset the previous path
self._ckpt_path = None
self._user_managed = False
ckpt_path = self._parse_ckpt_path(
state_fn,
ckpt_path,
model_provided=model_provided,
model_connected=model_connected,
)
else:
ckpt_path = self._ckpt_path
else:
ckpt_path = self._parse_ckpt_path(
state_fn,
ckpt_path,
model_provided=model_provided,
model_connected=model_connected,
)
return ckpt_path

def _parse_ckpt_path(
self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
) -> Optional[_PATH]:
"""Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer
configuration."""
if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None:
ckpt_path = "hpc"

@@ -181,15 +215,12 @@ def resume_end(self) -> None:
"""Signal the connector that all states have resumed and memory for the checkpoint object can be
released."""
assert self.trainer.state.fn is not None
if self.resume_checkpoint_path:
if self.trainer.state.fn == TrainerFn.FITTING:
rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
self.resume_checkpoint_path = None
self._loaded_checkpoint = {}
if self._ckpt_path:
message = "Restored all states" if self.trainer.state.fn == TrainerFn.FITTING else "Loaded model weights"
rank_zero_info(f"{message} from the checkpoint at {self._ckpt_path}")

# clear cache after restore
# free memory
self._loaded_checkpoint = {}
torch.cuda.empty_cache()

# wait for all to catch up
@@ -391,9 +422,15 @@ def restore_lr_schedulers(self) -> None:
for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers):
config.scheduler.load_state_dict(lrs_state)

# ----------------------------------
# PRIVATE OPS
# ----------------------------------
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self.resume_start(checkpoint_path)
self._restore_quantization_callbacks()
self.restore_model()
self.restore_datamodule()
if self.trainer.state.fn == TrainerFn.FITTING:
# restore callback states
self.restore_callbacks()

def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.
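As a reading aid, here is a condensed, standalone sketch of the precedence the new `_select_ckpt_path` implements (illustrative only; in the real method, the non-user-managed branch delegates special values such as "best", "last", and "hpc" to `_parse_ckpt_path`):

```python
from typing import Optional


def select_ckpt_path(stateful_path: Optional[str], arg_path: Optional[str], user_managed: bool) -> Optional[str]:
    """Toy model of the selection logic: an explicit ``ckpt_path=`` argument always
    wins over a statefully set ``trainer.ckpt_path``; otherwise the stateful path is used."""
    if user_managed:
        if arg_path:
            # mirrors the connector's warning: the argument takes precedence and
            # user management is reset
            return arg_path
        return stateful_path
    return arg_path  # the real code resolves special values ("best", "last", "hpc") here


assert select_ckpt_path("foo.ckpt", None, user_managed=True) == "foo.ckpt"
assert select_ckpt_path("foo.ckpt", "last.ckpt", user_managed=True) == "last.ckpt"
assert select_ckpt_path(None, "best", user_managed=False) == "best"
```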
69 changes: 31 additions & 38 deletions src/pytorch_lightning/trainer/trainer.py
@@ -387,9 +387,6 @@ def __init__(
# default .predict() loop
self.predict_loop = PredictionLoop()

# set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
self._ckpt_path: Optional[str] = None

# init callbacks
# Declare attributes to be set in _callback_connector on_trainer_init
self._callback_connector.on_trainer_init(
@@ -569,14 +566,13 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

ckpt_path = ckpt_path
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn,
ckpt_path,
model_provided=True,
model_connected=self.lightning_module is not None,
)
self._run(model, ckpt_path=self.ckpt_path)
self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.training = False
@@ -660,14 +656,10 @@ def _validate_impl(
# links data to the trainer
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

self._validated_ckpt_path = self.ckpt_path # TODO: remove in v1.8

# run validate
results = self._run(model, ckpt_path=self.ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.validating = False
@@ -753,14 +745,10 @@ def _test_impl(
# links data to the trainer
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

self._tested_ckpt_path = self.ckpt_path # TODO: remove in v1.8

# run test
results = self._run(model, ckpt_path=self.ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.testing = False
@@ -846,13 +834,10 @@ def _predict_impl(
# links data to the trainer
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

self._predicted_ckpt_path = self.ckpt_path # TODO: remove in v1.8

results = self._run(model, ckpt_path=self.ckpt_path)
results = self._run(model, ckpt_path=ckpt_path)

assert self.state.stopped
self.predicting = False
@@ -913,18 +898,8 @@ def tune(

return result

def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
self._checkpoint_connector.resume_start(checkpoint_path)
self._checkpoint_connector._restore_quantization_callbacks()
self._checkpoint_connector.restore_model()
self._checkpoint_connector.restore_datamodule()
if self.state.fn == TrainerFn.FITTING:
# restore callback states
self._checkpoint_connector.restore_callbacks()

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if model._compiler_ctx is not None:
supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
@@ -973,7 +948,7 @@ def _run(
# check if we should delay restoring checkpoint till later
if not self.strategy.restore_checkpoint_after_setup:
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

log.detail(f"{self.__class__.__name__}: configuring sharded model")
self._call_configure_sharded_model() # allow user to setup in model sharded environment
@@ -1021,7 +996,7 @@ def _run(

if self.strategy.restore_checkpoint_after_setup:
log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)
self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
log.detail(f"{self.__class__.__name__}: restoring training state")
@@ -1806,12 +1781,30 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
return None

@property
def ckpt_path(self) -> Optional[str]:
def ckpt_path(self) -> Optional[_PATH]:
"""Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
return self._ckpt_path
return self._checkpoint_connector._ckpt_path

@ckpt_path.setter
def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
"""Allows you to manage which checkpoint is loaded statefully.

Examples::

trainer = Trainer()
trainer.ckpt_path = "my/checkpoint/file.ckpt"
trainer.fit(model)
...

# you will be in charge of resetting this
trainer.ckpt_path = None
trainer.test(model)
"""
self._checkpoint_connector._ckpt_path = ckpt_path
self._checkpoint_connector._user_managed = bool(ckpt_path)

def save_checkpoint(
self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_datamodules.py
@@ -230,7 +230,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

for trainer_fn in TrainerFn:
trainer.state.fn = trainer_fn
trainer._restore_modules_and_callbacks(checkpoint_path)
trainer._checkpoint_connector._restore_modules_and_callbacks(checkpoint_path)
assert dm.my_state_dict == {"my": "state_dict"}


@@ -13,6 +13,7 @@
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
@@ -21,6 +22,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.migration.utils import _set_version


def test_preloaded_checkpoint_lifecycle(tmpdir):
@@ -31,26 +33,27 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):

connector = trainer._checkpoint_connector

assert not connector.resume_checkpoint_path
assert not connector._ckpt_path
assert not connector._loaded_checkpoint

connector.resume_start()
assert not connector.resume_checkpoint_path
assert not connector._ckpt_path
assert not connector._loaded_checkpoint
connector.resume_end()
assert not connector.resume_checkpoint_path
assert not connector._ckpt_path
assert not connector._loaded_checkpoint

ckpt_path = trainer.checkpoint_callback.best_model_path
trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
connector = trainer._checkpoint_connector
connector.resume_start(ckpt_path)
assert connector.resume_checkpoint_path == ckpt_path
assert connector._ckpt_path == ckpt_path
assert connector._loaded_checkpoint
assert isinstance(connector._loaded_checkpoint, dict)
trainer.state.fn = TrainerFn.FITTING
connector.resume_end()
assert not connector.resume_checkpoint_path
# not cleared until next restoration, as the user might access it through `trainer.ckpt_path`
assert connector._ckpt_path == ckpt_path
assert not connector._loaded_checkpoint


@@ -166,3 +169,54 @@ def test_loops_restore(tmpdir):
if fn2 != fn:
trainer_loop2 = getattr(trainer, f"{fn2}_loop")
trainer_loop2.load_state_dict.assert_not_called()


def test_stateful_trainer_ckpt_path_support(tmp_path):
"""Tests support for the pattern used by NeMo's experiment manager."""
model = BoringModel()

# dummy ckpt data
ckpt_data = {"state_dict": model.state_dict(), "optimizer_states": {}, "lr_schedulers": {}}
_set_version(ckpt_data, "2.0.0")

# save a "checkpoint"
ckpt_path = tmp_path / "foo.ckpt"
torch.save(ckpt_data, ckpt_path)

# mock model checkpoint instance that has saved a last checkpoint
model_checkpoint = Mock(spec=ModelCheckpoint)
last_path = tmp_path / "last.ckpt"
torch.save(ckpt_data, last_path)
model_checkpoint._find_last_checkpoints.return_value = {last_path}

trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, callbacks=model_checkpoint)

# set the ckpt path statefully
trainer.ckpt_path = ckpt_path
trainer.fit(model)
assert trainer.ckpt_path == ckpt_path # not automatically cleaned
assert trainer._checkpoint_connector._user_managed

# now create a conflict by also passing `ckpt_path` directly to `fit`
with pytest.warns(UserWarning, match="trainer.ckpt_path =.*but then you passed"):
trainer.fit(model, ckpt_path="last")
assert trainer.ckpt_path == last_path
assert not trainer._checkpoint_connector._user_managed

# mock model checkpoint instance that has saved a best checkpoint
best_path = tmp_path / "best.ckpt"
torch.save(ckpt_data, best_path)
model_checkpoint.best_model_path = best_path

# `trainer.test` will use this over "best" if statefully set
trainer.ckpt_path = ckpt_path
trainer.test()
assert trainer.ckpt_path == ckpt_path

# ckpt_path = "best" still works if it's reset
trainer.ckpt_path = None
# the state is cleared
assert trainer._checkpoint_connector._ckpt_path is None
assert not trainer._checkpoint_connector._user_managed
trainer.test()
assert trainer.ckpt_path == best_path