Skip to content

Commit 67d3844

Browse files
Fix last checkpoint finding in filtered files with correct extension (#17072)
Co-authored-by: awaelchli <[email protected]>
1 parent d4614d0 commit 67d3844

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4747
- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))
4848

4949

50+
- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))
51+
52+
5053

5154
## [2.1.2] - 2023-11-15
5255

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -626,12 +626,13 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
626626
def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
627627
# find all checkpoints in the folder
628628
ckpt_path = self.__resolve_ckpt_dir(trainer)
629+
last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?"
630+
631+
def _is_last(path: Path) -> bool:
632+
return path.suffix == self.FILE_EXTENSION and bool(re.match(last_pattern, path.stem))
633+
629634
if self._fs.exists(ckpt_path):
630-
return {
631-
os.path.normpath(p)
632-
for p in self._fs.ls(ckpt_path, detail=False)
633-
if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1]
634-
}
635+
return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if _is_last(Path(p))}
635636
return set()
636637

637638
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,3 +1509,28 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
15091509
else:
15101510
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"} # no files deleted
15111511
assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"}
1512+
1513+
1514+
@pytest.mark.parametrize(
    ("name", "extension", "folder_contents", "expected"),
    [
        ("last", ".ckpt", {}, {}),
        ("any", ".any", {}, {}),
        ("last", ".ckpt", {"last"}, {}),
        ("any", ".any", {"last"}, {}),
        ("last", ".ckpt", {"last", "last.ckpt"}, {"last.ckpt"}),
        ("other", ".pt", {"last", "last.pt", "other.pt"}, {"other.pt"}),
        ("last", ".ckpt", {"log.txt", "last-v0.ckpt", "last-v1.ckpt"}, {"last-v0.ckpt", "last-v1.ckpt"}),
        ("other", ".pt", {"log.txt", "last-v0.ckpt", "other-v0.pt", "other-v1.pt"}, {"other-v0.pt", "other-v1.pt"}),
    ],
)
def test_find_last_checkpoints(name, extension, folder_contents, expected, tmp_path):
    """Check that ``_find_last_checkpoints`` honors both the configured last-name and the file extension.

    Each case creates the files in ``folder_contents`` inside ``tmp_path``, overrides ``CHECKPOINT_NAME_LAST``
    and ``FILE_EXTENSION`` on the callback instance, and expects exactly the files in ``expected`` to be found.
    """
    # The empty cases use ``{}`` (an empty dict); the values are only iterated, so this acts as "no files".
    for filename in folder_contents:
        tmp_path.joinpath(filename).touch()

    trainer = Trainer()
    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path)
    checkpoint_callback.CHECKPOINT_NAME_LAST = name
    checkpoint_callback.FILE_EXTENSION = extension

    found = checkpoint_callback._find_last_checkpoints(trainer)
    expected_paths = {str(tmp_path / filename) for filename in expected}
    assert found == expected_paths

0 commit comments

Comments
 (0)