Skip to content

Commit 87b11fb

Browse files
awaelchlicarmoccapre-commit-ci[bot]Borda
authored
add legacy load utility (#9166)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 491e4a2 commit 87b11fb

File tree

9 files changed

+102
-47
lines changed

9 files changed

+102
-47
lines changed

CHANGELOG.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
148148
- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
149149

150150

151+
- Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))
152+
153+
151154
### Changed
152155

153156
- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
@@ -242,9 +245,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
242245
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
243246

244247

245-
- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
246-
247-
248248
- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260))
249249

250250

@@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
325325
- Removed deprecated `profiled_functions` argument from `PyTorchProfiler` ([#9178](https://github.com/PyTorchLightning/pytorch-lightning/pull/9178))
326326

327327

328+
- Removed deprecated `pytorch_lighting.utilities.argparse_utils` module ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))
329+
330+
328331
- Removed deprecated property `Trainer.running_sanity_check` in favor of `Trainer.sanity_checking` ([#9209](https://github.com/PyTorchLightning/pytorch-lightning/pull/9209))
329332

330333

pytorch_lightning/core/saving.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pytorch_lightning.utilities.apply_func import apply_to_collection
3030
from pytorch_lightning.utilities.cloud_io import get_filesystem
3131
from pytorch_lightning.utilities.cloud_io import load as pl_load
32+
from pytorch_lightning.utilities.migration import pl_legacy_patch
3233
from pytorch_lightning.utilities.parsing import parse_class_init_keys
3334

3435
log = logging.getLogger(__name__)
@@ -125,10 +126,11 @@ def load_from_checkpoint(
125126
pretrained_model.freeze()
126127
y_hat = pretrained_model(x)
127128
"""
128-
if map_location is not None:
129-
checkpoint = pl_load(checkpoint_path, map_location=map_location)
130-
else:
131-
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
129+
with pl_legacy_patch():
130+
if map_location is not None:
131+
checkpoint = pl_load(checkpoint_path, map_location=map_location)
132+
else:
133+
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
132134

133135
if hparams_file is not None:
134136
extension = hparams_file.split(".")[-1]

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
2727
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2828
from pytorch_lightning.utilities.imports import _fault_tolerant_training
29+
from pytorch_lightning.utilities.migration import pl_legacy_patch
2930
from pytorch_lightning.utilities.types import _PATH
3031
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
3132

@@ -65,7 +66,8 @@ def resume_start(self) -> None:
6566
self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
6667

6768
def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
68-
loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
69+
with pl_legacy_patch():
70+
loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
6971
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
7072
raise ValueError(
7173
"The checkpoint you're attempting to load follows an"

pytorch_lightning/utilities/argparse.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,6 @@ def _gpus_allowed_type(x: str) -> Union[int, str]:
295295
return int(x)
296296

297297

298-
def _gpus_arg_default(x: str) -> Union[int, str]: # pragma: no-cover
299-
# unused, but here for backward compatibility with old checkpoints that need to be able to
300-
# unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
301-
# see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
302-
pass
303-
304-
305298
def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
306299
if "." in str(x):
307300
return float(x)

pytorch_lightning/utilities/argparse_utils.py

Lines changed: 0 additions & 7 deletions
This file was deleted.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import sys
15+
from types import ModuleType
16+
17+
import pytorch_lightning.utilities.argparse
18+
19+
20+
class pl_legacy_patch:
21+
"""Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
22+
unpickling old checkpoints. The following patches apply.
23+
24+
1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
25+
version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
26+
2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
27+
but still needs to be available for import for legacy checkpoints.
28+
29+
Example:
30+
31+
with pl_legacy_patch():
32+
torch.load("path/to/legacy/checkpoint.ckpt")
33+
"""
34+
35+
def __enter__(self):
36+
# `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
37+
legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
38+
sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module
39+
40+
# `_gpus_arg_default` used to be imported from these locations
41+
legacy_argparse_module._gpus_arg_default = lambda x: x
42+
pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
43+
return self
44+
45+
def __exit__(self, exc_type, exc_value, exc_traceback):
46+
if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
47+
delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
48+
del sys.modules["pytorch_lightning.utilities.argparse_utils"]

pytorch_lightning/utilities/upgrade_checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919

2020
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
21+
from pytorch_lightning.utilities.migration import pl_legacy_patch
2122

2223
KEYS_MAPPING = {
2324
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
@@ -58,4 +59,5 @@ def upgrade_checkpoint(filepath):
5859

5960
log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
6061
copyfile(args.file, args.file + ".bak")
61-
upgrade_checkpoint(args.file)
62+
with pl_legacy_patch():
63+
upgrade_checkpoint(args.file)

tests/deprecated_api/test_remove_2-0.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/utilities/test_migration.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import sys
15+
16+
import pytorch_lightning
17+
from pytorch_lightning.utilities.migration import pl_legacy_patch
18+
19+
20+
def test_patch_legacy_argparse_utils():
21+
with pl_legacy_patch():
22+
from pytorch_lightning.utilities import argparse_utils
23+
24+
assert callable(argparse_utils._gpus_arg_default)
25+
assert "pytorch_lightning.utilities.argparse_utils" in sys.modules
26+
27+
assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules
28+
29+
30+
def test_patch_legacy_gpus_arg_default():
31+
with pl_legacy_patch():
32+
from pytorch_lightning.utilities.argparse import _gpus_arg_default
33+
34+
assert callable(_gpus_arg_default)
35+
assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
36+
assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")

0 commit comments

Comments
 (0)