Commit dcfaa06

Improve the checkpoint upgrade utility script (#15333)

1 parent fe8488d, commit dcfaa06

9 files changed: +219 −36 lines
docs/source-pytorch/common/checkpointing.rst

Lines changed: 17 additions & 9 deletions

@@ -12,33 +12,41 @@ Checkpointing
 .. Add callout items below this line

 .. displayitem::
-   :header: Basic
+   :header: Saving and loading checkpoints
    :description: Learn to save and load checkpoints
-   :col_css: col-md-3
+   :col_css: col-md-4
    :button_link: checkpointing_basic.html
    :height: 150
    :tag: basic

 .. displayitem::
-   :header: Intermediate
-   :description: Customize checkpointing behavior
-   :col_css: col-md-3
+   :header: Customize checkpointing behavior
+   :description: Learn how to change the behavior of checkpointing
+   :col_css: col-md-4
    :button_link: checkpointing_intermediate.html
    :height: 150
    :tag: intermediate

 .. displayitem::
-   :header: Advanced
+   :header: Upgrading checkpoints
+   :description: Learn how to upgrade old checkpoints to the newest Lightning version
+   :col_css: col-md-4
+   :button_link: checkpointing_migration.html
+   :height: 150
+   :tag: intermediate
+
+.. displayitem::
+   :header: Cloud-based checkpoints
    :description: Enable cloud-based checkpointing and composable checkpoints.
-   :col_css: col-md-3
+   :col_css: col-md-4
    :button_link: checkpointing_advanced.html
    :height: 150
    :tag: advanced

 .. displayitem::
-   :header: Expert
+   :header: Distributed checkpoints
    :description: Customize checkpointing for custom distributed strategies and accelerators.
-   :col_css: col-md-3
+   :col_css: col-md-4
    :button_link: checkpointing_expert.html
    :height: 150
    :tag: expert

docs/source-pytorch/common/checkpointing_advanced.rst

Lines changed: 3 additions & 3 deletions

@@ -1,8 +1,8 @@
 .. _checkpointing_advanced:

-########################
-Checkpointing (advanced)
-########################
+##################################
+Cloud-based checkpoints (advanced)
+##################################


 *****************
docs/source-pytorch/common/checkpointing_basic.rst

Lines changed: 3 additions & 3 deletions

@@ -2,9 +2,9 @@

 .. _checkpointing_basic:

-#####################
-Checkpointing (basic)
-#####################
+######################################
+Saving and loading checkpoints (basic)
+######################################
 **Audience:** All users

 ----
docs/source-pytorch/common/checkpointing_expert.rst

Lines changed: 3 additions & 3 deletions

@@ -2,9 +2,9 @@

 .. _checkpointing_expert:

-######################
-Checkpointing (expert)
-######################
+################################
+Distributed checkpoints (expert)
+################################

 *********************************
 Writing your own Checkpoint class
docs/source-pytorch/common/checkpointing_intermediate.rst

Lines changed: 4 additions & 4 deletions

@@ -1,10 +1,10 @@
 :orphan:

-.. _checkpointing_intermediate:
+.. _checkpointing_intermediate_1:

-############################
-Checkpointing (intermediate)
-############################
+###############################################
+Customize checkpointing behavior (intermediate)
+###############################################
 **Audience:** Users looking to customize the checkpointing behavior

 ----
docs/source-pytorch/common/checkpointing_migration.rst (new file)

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+:orphan:
+
+.. _checkpointing_intermediate_2:
+
+####################################
+Upgrading checkpoints (intermediate)
+####################################
+**Audience:** Users who are upgrading Lightning and their code and want to reuse their old checkpoints.
+
+----
+
+**************************************
+Resume training from an old checkpoint
+**************************************
+
+Next to the model weights and trainer state, a Lightning checkpoint contains the version number of Lightning with which the checkpoint was saved.
+When you load a checkpoint file, either by resuming training
+
+.. code-block:: python
+
+    trainer = Trainer(...)
+    trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")
+
+or by loading the state directly into your model,
+
+.. code-block:: python
+
+    model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")
+
+Lightning will automatically recognize that it is from an older version and migrate the internal structure so it can be loaded properly.
+This is done without any action required by the user.
+
+----
+
+************************************
+Upgrade checkpoint files permanently
+************************************
+
+When Lightning loads a checkpoint, it applies the version migration on the fly as explained above, but it does not modify your checkpoint files.
+You can upgrade checkpoint files permanently with the following command:
+
+.. code-block::
+
+    python -m pytorch_lightning.utilities.upgrade_checkpoint path/to/model.ckpt
+
+or a folder with multiple files:
+
+.. code-block::
+
+    python -m pytorch_lightning.utilities.upgrade_checkpoint /path/to/checkpoints/folder

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

--
+- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333))

 -
src/pytorch_lightning/utilities/upgrade_checkpoint.py

Lines changed: 63 additions & 13 deletions

@@ -11,30 +11,80 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
+import glob
 import logging
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
 from shutil import copyfile
+from typing import List

 import torch
+from tqdm import tqdm

 from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

-log = logging.getLogger(__name__)
+_log = logging.getLogger(__name__)

-if __name__ == "__main__":

-    parser = argparse.ArgumentParser(
+def _upgrade(args: Namespace) -> None:
+    path = Path(args.path).absolute()
+    extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}"
+    files: List[Path] = []
+
+    if not path.exists():
+        _log.error(
+            f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory"
+            f" containing checkpoints ending in {extension}."
+        )
+        exit(1)
+
+    if path.is_file():
+        files = [path]
+    if path.is_dir():
+        files = [Path(p) for p in glob.glob(str(path / "**" / f"*{extension}"), recursive=True)]
+    if not files:
+        _log.error(
+            f"No checkpoint files with extension {extension} were found in {path}."
+            f" HINT: Try setting the `--extension` option to specify the right file extension to look for."
+        )
+        exit(1)
+
+    _log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.")
+    for file in files:
+        backup_file = file.with_suffix(".bak")
+        if backup_file.exists():
+            # never overwrite backup files - they are the original, untouched checkpoints
+            continue
+        copyfile(file, backup_file)
+
+    _log.info("Upgrading checkpoints ...")
+    for file in tqdm(files):
+        with pl_legacy_patch():
+            checkpoint = torch.load(file)
+        migrate_checkpoint(checkpoint)
+        torch.save(checkpoint, file)
+
+    _log.info("Done.")
+
+
+def main() -> None:
+    parser = ArgumentParser(
         description=(
-            "Upgrade an old checkpoint to the current schema. This will also save a backup of the original file."
+            "A utility to upgrade old checkpoints to the format of the current Lightning version."
+            " This will also save a backup of the original files."
         )
     )
-    parser.add_argument("--file", help="filepath for a checkpoint to upgrade")
-
+    parser.add_argument("path", type=str, help="Path to a checkpoint file or a directory with checkpoints to upgrade")
+    parser.add_argument(
+        "--extension",
+        "-e",
+        type=str,
+        default=".ckpt",
+        help="The file extension to look for when searching for checkpoint files in a directory.",
+    )
     args = parser.parse_args()
+    _upgrade(args)
+

-    log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
-    copyfile(args.file, args.file + ".bak")
-    with pl_legacy_patch():
-        checkpoint = torch.load(args.file)
-    migrate_checkpoint(checkpoint)
-    torch.save(checkpoint, args.file)
+if __name__ == "__main__":
+    main()

tests/tests_pytorch/utilities/test_upgrade_checkpoint.py

Lines changed: 74 additions & 0 deletions

@@ -11,12 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from pathlib import Path
+from unittest import mock
+from unittest.mock import ANY
+
 import pytest

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.utilities.migration import migrate_checkpoint
 from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
+from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main


 @pytest.mark.parametrize(
@@ -47,3 +53,71 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
     updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
     assert updated_checkpoint == old_checkpoint == new_checkpoint
     assert _get_version(updated_checkpoint) == pl.__version__
+
+
+def test_upgrade_checkpoint_file_missing(tmp_path, caplog):
+    # path to single file (missing)
+    file = tmp_path / "checkpoint.ckpt"
+    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]):
+        with caplog.at_level(logging.ERROR):
+            with pytest.raises(SystemExit):
+                upgrade_main()
+    assert f"The path {file} does not exist" in caplog.text
+
+    caplog.clear()
+
+    # path to non-empty directory, but no checkpoints with matching extension
+    file.touch()
+    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]):
+        with caplog.at_level(logging.ERROR):
+            with pytest.raises(SystemExit):
+                upgrade_main()
+    assert "No checkpoint files with extension .other were found" in caplog.text
+
+
+@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save")
+@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load")
+@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint")
+def test_upgrade_checkpoint_single_file(migrate_mock, load_mock, save_mock, tmp_path):
+    file = tmp_path / "checkpoint.ckpt"
+    file.touch()
+    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]):
+        upgrade_main()
+
+    load_mock.assert_called_once_with(Path(file))
+    migrate_mock.assert_called_once()
+    save_mock.assert_called_once_with(ANY, Path(file))
+
+
+@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save")
+@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load")
+@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint")
+def test_upgrade_checkpoint_directory(migrate_mock, load_mock, save_mock, tmp_path):
+    top_files = [tmp_path / "top0.ckpt", tmp_path / "top1.ckpt"]
+    nested_files = [
+        tmp_path / "subdir0" / "nested0.ckpt",
+        tmp_path / "subdir0" / "nested1.other",
+        tmp_path / "subdir1" / "nested2.ckpt",
+    ]
+
+    for file in top_files + nested_files:
+        file.parent.mkdir(exist_ok=True, parents=True)
+        file.touch()
+
+    # directory with recursion
+    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path)]):
+        upgrade_main()
+
+    assert {c[0][0] for c in load_mock.call_args_list} == {
+        tmp_path / "top0.ckpt",
+        tmp_path / "top1.ckpt",
+        tmp_path / "subdir0" / "nested0.ckpt",
+        tmp_path / "subdir1" / "nested2.ckpt",
+    }
+    assert migrate_mock.call_count == 4
+    assert {c[0][1] for c in save_mock.call_args_list} == {
+        tmp_path / "top0.ckpt",
+        tmp_path / "top1.ckpt",
+        tmp_path / "subdir0" / "nested0.ckpt",
+        tmp_path / "subdir1" / "nested2.ckpt",
+    }
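The new tests drive the console entry point by patching `sys.argv` rather than spawning a subprocess. The same pattern works for any `argparse`-based main; the toy `cli_main` below is illustrative only, not Lightning code:

```python
from argparse import ArgumentParser
from unittest import mock


def cli_main() -> str:
    # stand-in for a console entry point that parses sys.argv
    parser = ArgumentParser()
    parser.add_argument("path")
    parser.add_argument("--extension", "-e", default=".ckpt")
    args = parser.parse_args()
    return f"{args.path}{args.extension}"


# argv[0] is the program name, exactly as in the upgrade-script tests
with mock.patch("sys.argv", ["upgrade_checkpoint.py", "model", "--extension", ".pt"]):
    assert cli_main() == "model.pt"
```

Patching `sys.argv` keeps the test in-process, so mocks on `torch.load`/`torch.save` (as used above) still apply, which a subprocess-based test could not offer.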
