@@ -387,9 +387,6 @@ def __init__(
387
387
# default .predict() loop
388
388
self .predict_loop = PredictionLoop ()
389
389
390
- # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
391
- self ._ckpt_path : Optional [str ] = None
392
-
393
390
# init callbacks
394
391
# Declare attributes to be set in _callback_connector on_trainer_init
395
392
self ._callback_connector .on_trainer_init (
@@ -569,14 +566,13 @@ def _fit_impl(
569
566
model , train_dataloaders = train_dataloaders , val_dataloaders = val_dataloaders , datamodule = datamodule
570
567
)
571
568
572
- ckpt_path = ckpt_path
573
- self ._ckpt_path = self ._checkpoint_connector ._set_ckpt_path (
569
+ ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
574
570
self .state .fn ,
575
571
ckpt_path ,
576
572
model_provided = True ,
577
573
model_connected = self .lightning_module is not None ,
578
574
)
579
- self ._run (model , ckpt_path = self . ckpt_path )
575
+ self ._run (model , ckpt_path = ckpt_path )
580
576
581
577
assert self .state .stopped
582
578
self .training = False
@@ -660,14 +656,10 @@ def _validate_impl(
660
656
# links data to the trainer
661
657
self ._data_connector .attach_data (model , val_dataloaders = dataloaders , datamodule = datamodule )
662
658
663
- self . _ckpt_path = self ._checkpoint_connector ._set_ckpt_path (
659
+ ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
664
660
self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
665
661
)
666
-
667
- self ._validated_ckpt_path = self .ckpt_path # TODO: remove in v1.8
668
-
669
- # run validate
670
- results = self ._run (model , ckpt_path = self .ckpt_path )
662
+ results = self ._run (model , ckpt_path = ckpt_path )
671
663
672
664
assert self .state .stopped
673
665
self .validating = False
@@ -753,14 +745,10 @@ def _test_impl(
753
745
# links data to the trainer
754
746
self ._data_connector .attach_data (model , test_dataloaders = dataloaders , datamodule = datamodule )
755
747
756
- self . _ckpt_path = self ._checkpoint_connector ._set_ckpt_path (
748
+ ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
757
749
self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
758
750
)
759
-
760
- self ._tested_ckpt_path = self .ckpt_path # TODO: remove in v1.8
761
-
762
- # run test
763
- results = self ._run (model , ckpt_path = self .ckpt_path )
751
+ results = self ._run (model , ckpt_path = ckpt_path )
764
752
765
753
assert self .state .stopped
766
754
self .testing = False
@@ -846,13 +834,10 @@ def _predict_impl(
846
834
# links data to the trainer
847
835
self ._data_connector .attach_data (model , predict_dataloaders = dataloaders , datamodule = datamodule )
848
836
849
- self . _ckpt_path = self ._checkpoint_connector ._set_ckpt_path (
837
+ ckpt_path = self ._checkpoint_connector ._select_ckpt_path (
850
838
self .state .fn , ckpt_path , model_provided = model_provided , model_connected = self .lightning_module is not None
851
839
)
852
-
853
- self ._predicted_ckpt_path = self .ckpt_path # TODO: remove in v1.8
854
-
855
- results = self ._run (model , ckpt_path = self .ckpt_path )
840
+ results = self ._run (model , ckpt_path = ckpt_path )
856
841
857
842
assert self .state .stopped
858
843
self .predicting = False
@@ -913,18 +898,8 @@ def tune(
913
898
914
899
return result
915
900
916
- def _restore_modules_and_callbacks (self , checkpoint_path : Optional [_PATH ] = None ) -> None :
917
- # restore modules after setup
918
- self ._checkpoint_connector .resume_start (checkpoint_path )
919
- self ._checkpoint_connector ._restore_quantization_callbacks ()
920
- self ._checkpoint_connector .restore_model ()
921
- self ._checkpoint_connector .restore_datamodule ()
922
- if self .state .fn == TrainerFn .FITTING :
923
- # restore callback states
924
- self ._checkpoint_connector .restore_callbacks ()
925
-
926
901
def _run (
927
- self , model : "pl.LightningModule" , ckpt_path : Optional [str ] = None
902
+ self , model : "pl.LightningModule" , ckpt_path : Optional [_PATH ] = None
928
903
) -> Optional [Union [_EVALUATE_OUTPUT , _PREDICT_OUTPUT ]]:
929
904
if model ._compiler_ctx is not None :
930
905
supported_strategies = [SingleDeviceStrategy , DDPStrategy , DDPFullyShardedNativeStrategy ]
@@ -973,7 +948,7 @@ def _run(
973
948
# check if we should delay restoring checkpoint till later
974
949
if not self .strategy .restore_checkpoint_after_setup :
975
950
log .detail (f"{ self .__class__ .__name__ } : restoring module and callbacks from checkpoint path: { ckpt_path } " )
976
- self ._restore_modules_and_callbacks (ckpt_path )
951
+ self ._checkpoint_connector . _restore_modules_and_callbacks (ckpt_path )
977
952
978
953
log .detail (f"{ self .__class__ .__name__ } : configuring sharded model" )
979
954
self ._call_configure_sharded_model () # allow user to setup in model sharded environment
@@ -1021,7 +996,7 @@ def _run(
1021
996
1022
997
if self .strategy .restore_checkpoint_after_setup :
1023
998
log .detail (f"{ self .__class__ .__name__ } : restoring module and callbacks from checkpoint path: { ckpt_path } " )
1024
- self ._restore_modules_and_callbacks (ckpt_path )
999
+ self ._checkpoint_connector . _restore_modules_and_callbacks (ckpt_path )
1025
1000
1026
1001
# restore optimizers, etc.
1027
1002
log .detail (f"{ self .__class__ .__name__ } : restoring training state" )
@@ -1806,12 +1781,30 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
1806
1781
return None
1807
1782
1808
1783
@property
1809
- def ckpt_path (self ) -> Optional [str ]:
1784
+ def ckpt_path (self ) -> Optional [_PATH ]:
1810
1785
"""Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
1811
1786
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
1812
1787
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
1813
1788
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
1814
- return self ._ckpt_path
1789
+ return self ._checkpoint_connector ._ckpt_path
1790
+
1791
+ @ckpt_path .setter
1792
+ def ckpt_path (self , ckpt_path : Optional [_PATH ]) -> None :
1793
+ """Allows you to manage which checkpoint is loaded statefully.
1794
+
1795
+ Examples::
1796
+
1797
+ trainer = Trainer()
1798
+ trainer.ckpt_path = "my/checkpoint/file.ckpt"
1799
+ trainer.fit(model)
1800
+ ...
1801
+
1802
+ # you will be in charge of resetting this
1803
+ trainer.ckpt_path = None
1804
+ trainer.test(model)
1805
+ """
1806
+ self ._checkpoint_connector ._ckpt_path = ckpt_path
1807
+ self ._checkpoint_connector ._user_managed = bool (ckpt_path )
1815
1808
1816
1809
def save_checkpoint (
1817
1810
self , filepath : _PATH , weights_only : bool = False , storage_options : Optional [Any ] = None
0 commit comments