Skip to content

Commit 4f0e6ad

Browse files
SeanNaren authored and Borda committed
Add function to remove checkpoint to allow override for extended classes (#16067)
(cherry picked from commit 10cc677)
1 parent e5d5901 commit 4f0e6ad

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
640640
previous, self.last_model_path = self.last_model_path, filepath
641641
self._save_checkpoint(trainer, filepath)
642642
if previous and previous != filepath:
643-
trainer.strategy.remove_checkpoint(previous)
643+
self._remove_checkpoint(trainer, previous)
644644

645645
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
646646
assert self.monitor
@@ -659,7 +659,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
659659
previous, self.best_model_path = self.best_model_path, filepath
660660
self._save_checkpoint(trainer, filepath)
661661
if self.save_top_k == 1 and previous and previous != filepath:
662-
trainer.strategy.remove_checkpoint(previous)
662+
self._remove_checkpoint(trainer, previous)
663663

664664
def _update_best_and_save(
665665
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
@@ -701,7 +701,7 @@ def _update_best_and_save(
701701
self._save_checkpoint(trainer, filepath)
702702

703703
if del_filepath is not None and filepath != del_filepath:
704-
trainer.strategy.remove_checkpoint(del_filepath)
704+
self._remove_checkpoint(trainer, del_filepath)
705705

706706
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
707707
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
@@ -718,3 +718,7 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
718718
state to diverge between ranks."""
719719
exists = self._fs.exists(filepath)
720720
return trainer.strategy.broadcast(exists)
721+
722+
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
723+
"""Calls the strategy to remove the checkpoint file."""
724+
trainer.strategy.remove_checkpoint(filepath)

0 commit comments

Comments (0)