Skip to content
Merged
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed checks for local file protocol due to fsspec changes in 2023.10.0 ([#19023](https://github.com/Lightning-AI/lightning/pull/19023))


- Fixed automatic detection of 'last.ckpt' files to respect the extension when filtering ([#17072](https://github.com/Lightning-AI/lightning/pull/17072))



## [2.1.2] - 2023-11-15

Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
os.path.normpath(p)
for p in self._fs.ls(ckpt_path, detail=False)
if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1]
and os.path.split(p)[1].endswith(self.FILE_EXTENSION)
}
return set()

Expand Down
19 changes: 19 additions & 0 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,3 +1509,22 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
else:
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"} # no files deleted
assert set(os.listdir(second)) == {"epoch=0-step=6.ckpt", "epoch=0-step=8.ckpt"}


@pytest.mark.parametrize(
"folder_contents,expected",
[
([], []),
(["last"], []),
(["last", "last.ckpt"], ["last.ckpt"]),
(["log.txt", "last-v0.ckpt", "last-v1.ckpt"], ["last-v0.ckpt", "last-v1.ckpt"]),
],
)
def test_find_last_checkpoints(folder_contents, expected, tmp_path):
for file in folder_contents:
(tmp_path / file).touch()

trainer = Trainer()
callback = ModelCheckpoint(dirpath=tmp_path)
files = callback._find_last_checkpoints(trainer)
assert files == {str(tmp_path / p) for p in expected}