diff --git a/CHANGELOG.md b/CHANGELOG.md
index 112b5170cb5a8..6516ab9c5b0fd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
 
 
+- Added `pl_legacy_patch` load utility for loading old checkpoints that have pickled legacy Lightning attributes ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))
+
+
 ### Changed
 
 - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
@@ -242,9 +245,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
 
 
-- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
-
-
 - Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260))
 
 
@@ -323,6 +323,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `profiled_functions` argument from `PyTorchProfiler` ([#9178](https://github.com/PyTorchLightning/pytorch-lightning/pull/9178))
 
 
+- Removed deprecated `pytorch_lightning.utilities.argparse_utils` module ([#9166](https://github.com/PyTorchLightning/pytorch-lightning/pull/9166))
+
+
 - Removed deprecated property `Trainer.running_sanity_check` in favor of `Trainer.sanity_checking` ([#9209](https://github.com/PyTorchLightning/pytorch-lightning/pull/9209))
 
 
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 942f0e32cd8fe..a9950318b18b1 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -29,6 +29,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities.migration import pl_legacy_patch
 from pytorch_lightning.utilities.parsing import parse_class_init_keys
 
 log = logging.getLogger(__name__)
@@ -125,10 +126,11 @@ def load_from_checkpoint(
             pretrained_model.freeze()
             y_hat = pretrained_model(x)
         """
-        if map_location is not None:
-            checkpoint = pl_load(checkpoint_path, map_location=map_location)
-        else:
-            checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+        with pl_legacy_patch():
+            if map_location is not None:
+                checkpoint = pl_load(checkpoint_path, map_location=map_location)
+            else:
+                checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
         if hparams_file is not None:
             extension = hparams_file.split(".")[-1]
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index b750b0f81b26f..1dd87e3d4b5be 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.migration import pl_legacy_patch
 from pytorch_lightning.utilities.types import _PATH
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
@@ -65,7 +66,8 @@ def resume_start(self) -> None:
         self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
 
     def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
-        loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
+        with pl_legacy_patch():
+            loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
         if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
             raise ValueError(
                 "The checkpoint you're attempting to load follows an"
diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py
index ad885f58192fd..61443bea07cd7 100644
--- a/pytorch_lightning/utilities/argparse.py
+++ b/pytorch_lightning/utilities/argparse.py
@@ -295,13 +295,6 @@ def _gpus_allowed_type(x: str) -> Union[int, str]:
     return int(x)
 
 
-def _gpus_arg_default(x: str) -> Union[int, str]:  # pragma: no-cover
-    # unused, but here for backward compatibility with old checkpoints that need to be able to
-    # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
-    # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
-    pass
-
-
 def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
     if "." in str(x):
         return float(x)
diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py
deleted file mode 100644
index e3c2c3c86dd94..0000000000000
--- a/pytorch_lightning/utilities/argparse_utils.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from pytorch_lightning.utilities import rank_zero_deprecation
-
-rank_zero_deprecation("`argparse_utils` package has been renamed to `argparse` since v1.2 and will be removed in v2.0")
-
-# for backward compatibility with old checkpoints (versions < 1.2.0)
-# that need to be able to unpickle the function from the checkpoint
-from pytorch_lightning.utilities.argparse import _gpus_arg_default  # noqa: E402, F401  # isort: skip
diff --git a/pytorch_lightning/utilities/migration.py b/pytorch_lightning/utilities/migration.py
new file mode 100644
index 0000000000000..68f20403bb0c7
--- /dev/null
+++ b/pytorch_lightning/utilities/migration.py
@@ -0,0 +1,48 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+from types import ModuleType
+
+import pytorch_lightning.utilities.argparse
+
+
+class pl_legacy_patch:
+    """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
+    unpickling old checkpoints. The following patches apply.
+
+    1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
+       version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
+    2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
+       but still needs to be available for import for legacy checkpoints.
+
+    Example:
+
+        with pl_legacy_patch():
+            torch.load("path/to/legacy/checkpoint.ckpt")
+    """
+
+    def __enter__(self):
+        # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
+        legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
+        sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module
+
+        # `_gpus_arg_default` used to be imported from these locations
+        legacy_argparse_module._gpus_arg_default = lambda x: x
+        pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
+            delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
+        del sys.modules["pytorch_lightning.utilities.argparse_utils"]
diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py
index ddff8ee1d5ab3..34483ce39b925 100644
--- a/pytorch_lightning/utilities/upgrade_checkpoint.py
+++ b/pytorch_lightning/utilities/upgrade_checkpoint.py
@@ -18,6 +18,7 @@
 import torch
 
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.utilities.migration import pl_legacy_patch
 
 KEYS_MAPPING = {
     "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
@@ -58,4 +59,5 @@
 
     log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
     copyfile(args.file, args.file + ".bak")
-    upgrade_checkpoint(args.file)
+    with pl_legacy_patch():
+        upgrade_checkpoint(args.file)
diff --git a/tests/deprecated_api/test_remove_2-0.py b/tests/deprecated_api/test_remove_2-0.py
deleted file mode 100644
index a4a3eb2b3726d..0000000000000
--- a/tests/deprecated_api/test_remove_2-0.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Test deprecated functionality which will be removed in v1.4.0."""
-
-import pytest
-
-from tests.deprecated_api import _soft_unimport_module
-
-
-def test_v2_0_0_deprecated_imports():
-    _soft_unimport_module("pytorch_lightning.utilities.argparse_utils")
-    with pytest.deprecated_call(match="will be removed in v2.0"):
-        from pytorch_lightning.utilities.argparse_utils import _gpus_arg_default  # noqa: F401
diff --git a/tests/utilities/test_migration.py b/tests/utilities/test_migration.py
new file mode 100644
index 0000000000000..ee94ee690e798
--- /dev/null
+++ b/tests/utilities/test_migration.py
@@ -0,0 +1,36 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+import pytorch_lightning
+from pytorch_lightning.utilities.migration import pl_legacy_patch
+
+
+def test_patch_legacy_argparse_utils():
+    with pl_legacy_patch():
+        from pytorch_lightning.utilities import argparse_utils
+
+        assert callable(argparse_utils._gpus_arg_default)
+        assert "pytorch_lightning.utilities.argparse_utils" in sys.modules
+
+    assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules
+
+
+def test_patch_legacy_gpus_arg_default():
+    with pl_legacy_patch():
+        from pytorch_lightning.utilities.argparse import _gpus_arg_default
+
+        assert callable(_gpus_arg_default)
+    assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
+    assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
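Usage note (illustrative sketch, not part of the diff above): the snippet below shows how the `pl_legacy_patch` context manager introduced in `pytorch_lightning/utilities/migration.py` is meant to wrap deserialization of a pre-1.2.8 checkpoint, mirroring the docstring example; the checkpoint path is a placeholder.

import torch

from pytorch_lightning.utilities.migration import pl_legacy_patch

# While the context is active, the removed `pytorch_lightning.utilities.argparse_utils`
# module and the `_gpus_arg_default` function are temporarily re-registered, so
# unpickling an old checkpoint that references them does not raise
# ModuleNotFoundError or AttributeError. Both are cleaned up again on exit.
with pl_legacy_patch():
    checkpoint = torch.load("path/to/legacy/checkpoint.ckpt")  # placeholder path

print(sorted(checkpoint.keys()))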