@@ -392,9 +392,6 @@ def __init__(
         # default .predict() loop
         self.predict_loop = PredictionLoop()
 
-        # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`.
-        self._ckpt_path: Optional[str] = None
-
         # init callbacks
         # Declare attributes to be set in _callback_connector on_trainer_init
         self._callback_connector.on_trainer_init(
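This hunk drops `_ckpt_path` from the `Trainer` itself; per the property change further down, that state now lives on the checkpoint connector. A hypothetical sketch of the connector-side state this diff implies; only the two attributes referenced elsewhere in the diff are shown, and the `CheckpointConnector.__init__` signature is an assumption:

```python
class CheckpointConnector:
    def __init__(self, trainer: "pl.Trainer") -> None:
        self.trainer = trainer
        # set when a checkpoint is loaded via `Trainer.{fit,validate,test,predict}`
        # or assigned through the new `Trainer.ckpt_path` setter
        self._ckpt_path: Optional[_PATH] = None
        # True while the user manages the path directly via the setter
        self._user_managed: bool = False
```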
@@ -574,14 +571,13 @@ def _fit_impl(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )
 
-        ckpt_path = ckpt_path
-        self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
+        ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
             model_provided=True,
             model_connected=self.lightning_module is not None,
         )
-        self._run(model, ckpt_path=self.ckpt_path)
+        self._run(model, ckpt_path=ckpt_path)
 
         assert self.state.stopped
         self.training = False
@@ -665,14 +661,10 @@ def _validate_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
-        self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
+        ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
-
-        self._validated_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
-
-        # run validate
-        results = self._run(model, ckpt_path=self.ckpt_path)
+        results = self._run(model, ckpt_path=ckpt_path)
 
         assert self.state.stopped
         self.validating = False
@@ -758,14 +750,10 @@ def _test_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
-        self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
+        ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
-
-        self._tested_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
-
-        # run test
-        results = self._run(model, ckpt_path=self.ckpt_path)
+        results = self._run(model, ckpt_path=ckpt_path)
 
         assert self.state.stopped
         self.testing = False
@@ -851,13 +839,10 @@ def _predict_impl(
         # links data to the trainer
         self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
-        self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
+        ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
-
-        self._predicted_ckpt_path = self.ckpt_path  # TODO: remove in v1.8
-
-        results = self._run(model, ckpt_path=self.ckpt_path)
+        results = self._run(model, ckpt_path=ckpt_path)
 
         assert self.state.stopped
         self.predicting = False
@@ -918,18 +903,8 @@ def tune(
 
         return result
 
-    def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
-        # restore modules after setup
-        self._checkpoint_connector.resume_start(checkpoint_path)
-        self._checkpoint_connector._restore_quantization_callbacks()
-        self._checkpoint_connector.restore_model()
-        self._checkpoint_connector.restore_datamodule()
-        if self.state.fn == TrainerFn.FITTING:
-            # restore callback states
-            self._checkpoint_connector.restore_callbacks()
-
     def _run(
-        self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
+        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if model._compiler_ctx is not None:
             supported_strategies = [SingleDeviceStrategy, DDPStrategy, DDPFullyShardedNativeStrategy]
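`_restore_modules_and_callbacks` is deleted from the `Trainer` here and re-invoked in the next two hunks as `self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)`. A minimal sketch of the relocated method, assuming the body moves to `CheckpointConnector` essentially verbatim, with the former `self._checkpoint_connector.*` calls becoming `self.*` and trainer state reached via `self.trainer`:

```python
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
    # restore modules after setup
    self.resume_start(checkpoint_path)
    self._restore_quantization_callbacks()
    self.restore_model()
    self.restore_datamodule()
    if self.trainer.state.fn == TrainerFn.FITTING:
        # restore callback states
        self.restore_callbacks()
```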
@@ -978,7 +953,7 @@ def _run(
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
             log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
-            self._restore_modules_and_callbacks(ckpt_path)
+            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
 
         log.detail(f"{self.__class__.__name__}: configuring sharded model")
         self._call_configure_sharded_model()  # allow user to setup in model sharded environment
@@ -1026,7 +1001,7 @@ def _run(
 
         if self.strategy.restore_checkpoint_after_setup:
             log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
-            self._restore_modules_and_callbacks(ckpt_path)
+            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
 
         # restore optimizers, etc.
         log.detail(f"{self.__class__.__name__}: restoring training state")
@@ -1811,12 +1786,30 @@ def progress_bar_callback(self) -> Optional[ProgressBarBase]:
         return None
 
     @property
-    def ckpt_path(self) -> Optional[str]:
+    def ckpt_path(self) -> Optional[_PATH]:
        """Set to the path/URL of a checkpoint loaded via :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`,
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`,
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`, or
        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
-        return self._ckpt_path
+        return self._checkpoint_connector._ckpt_path
+
+    @ckpt_path.setter
+    def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
+        """Allows you to manage which checkpoint is loaded statefully.
+
+        Examples::
+
+            trainer = Trainer()
+            trainer.ckpt_path = "my/checkpoint/file.ckpt"
+            trainer.fit(model)
+            ...
+
+            # you will be in charge of resetting this
+            trainer.ckpt_path = None
+            trainer.test(model)
+        """
+        self._checkpoint_connector._ckpt_path = ckpt_path
+        self._checkpoint_connector._user_managed = bool(ckpt_path)
 
     def save_checkpoint(
         self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
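A short usage sketch of the new stateful `ckpt_path` setter, extending the docstring example above. The reuse across entry points follows from the property docstring; `model` is assumed to be an existing `LightningModule`, and the precedence between a user-managed path and a per-call `ckpt_path` argument is not specified by this diff:

```python
import pytorch_lightning as pl

trainer = pl.Trainer()

# Pin a checkpoint statefully; `bool("my/checkpoint/file.ckpt")` is True,
# so the connector's `_user_managed` flag is set alongside the path.
trainer.ckpt_path = "my/checkpoint/file.ckpt"
trainer.fit(model)
trainer.validate(model)  # presumably resumes from the same pinned checkpoint

# You are in charge of resetting; assigning None also clears
# `_user_managed`, since `bool(None)` is False.
trainer.ckpt_path = None
trainer.test(model)
```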