Skip to content
Merged
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))


- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))



## [2.1.2] - 2023-11-15

Expand Down
11 changes: 6 additions & 5 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,12 +626,13 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
# find all checkpoints in the folder
ckpt_path = self.__resolve_ckpt_dir(trainer)
last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?"

def _is_last(path: Path) -> bool:
return path.suffix == self.FILE_EXTENSION and bool(re.match(last_pattern, path.stem))

if self._fs.exists(ckpt_path):
return {
os.path.normpath(p)
for p in self._fs.ls(ckpt_path, detail=False)
if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1]
}
return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if _is_last(Path(p))}
return set()

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
Expand Down
25 changes: 25 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,3 +1509,28 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
else:
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"} # no files deleted
assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"}


@pytest.mark.parametrize(
("name", "extension", "folder_contents", "expected"),
[
("last", ".ckpt", {}, {}),
("any", ".any", {}, {}),
("last", ".ckpt", {"last"}, {}),
("any", ".any", {"last"}, {}),
("last", ".ckpt", {"last", "last.ckpt"}, {"last.ckpt"}),
("other", ".pt", {"last", "last.pt", "other.pt"}, {"other.pt"}),
("last", ".ckpt", {"log.txt", "last-v0.ckpt", "last-v1.ckpt"}, {"last-v0.ckpt", "last-v1.ckpt"}),
("other", ".pt", {"log.txt", "last-v0.ckpt", "other-v0.pt", "other-v1.pt"}, {"other-v0.pt", "other-v1.pt"}),
],
)
def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_path):
for file in folder_contents:
(tmp_path / file).touch()

trainer = Trainer()
callback = ModelCheckpoint(dirpath=tmp_path)
callback.CHECKPOINT_NAME_LAST = name
callback.FILE_EXTENSION = extension
files = callback._find_last_checkpoints(trainer)
assert files == {str(tmp_path / p) for p in expected}