Skip to content

Commit edf90d1

Browse files
carmoccaawaelchli
andcommitted
Add amp_scaling_state (apex) migration (#16161)
Co-authored-by: Adrian Wälchli <[email protected]> Fixes #16149 (comment)
1 parent 719b125 commit edf90d1

File tree

4 files changed

+30
-2
lines changed

4 files changed

+30
-2
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
## [fabricLaunchPLVersion] - 202Y-MM-DD
99

10+
### Added
11+
12+
- Added migration logic to warn about checkpoints with apex AMP state ([#16161](https://github.com/Lightning-AI/lightning/pull/16161))
13+
1014
### Removed
1115

1216
- Removed the `pytorch_lightning.lite` module in favor of `lightning_fabric` ([#15953](https://github.com/Lightning-AI/lightning/pull/15953))

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,6 @@ def restore_precision_plugin_state(self) -> None:
293293
prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__])
294294

295295
# old checkpoints compatibility
296-
if "amp_scaling_state" in self._loaded_checkpoint:
297-
rank_zero_warn("This checkpoint contains apex AMP data, but apex support has been removed.")
298296
if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, MixedPrecisionPlugin):
299297
prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
300298

src/pytorch_lightning/utilities/migration/migration.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
4646
"1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
4747
"1.6.5": [_migrate_loop_batches_that_stepped],
4848
"1.9.0": [_migrate_model_checkpoint_save_on_train_epoch_end_default],
49+
"2.0.0": [_drop_apex_amp_state],
4950
}
5051

5152

@@ -203,3 +204,18 @@ def new_key(old_key: str) -> str:
203204

204205
checkpoint["callbacks"] = new_callback_states
205206
return checkpoint
207+
208+
209+
def _drop_apex_amp_state(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
210+
"""Apex support was removed in v2.0.0, and this migration drops it from the state-keys saved in the checkpoint
211+
dict.
212+
213+
Version: 2.0.0
214+
Commit: e544676ff434ed96c6dd3b4e73a708bcb27ebcf1
215+
PR: #16149
216+
"""
217+
key = "amp_scaling_state"
218+
if key in checkpoint:
219+
rank_zero_warn("This checkpoint contains apex AMP data, but apex support has been removed in v2.0.0.")
220+
del checkpoint[key]
221+
return checkpoint

tests/tests_pytorch/utilities/migration/test_migration.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,13 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision():
146146
with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"):
147147
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="1.9.0")
148148
assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"] # no migration was performed
149+
150+
151+
def test_migrate_dropped_apex_amp_state(monkeypatch):
152+
"""Test that the migration warns about collisions that would occur if the keys were modified."""
153+
monkeypatch.setattr(pl, "__version__", "2.0.0") # pretend this version of Lightning is >= 2.0.0
154+
old_checkpoint = {"amp_scaling_state": {"scale": 1.23}}
155+
_set_version(old_checkpoint, "1.9.0") # pretend a checkpoint prior to 2.0.0
156+
with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"):
157+
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
158+
assert "amp_scaling_state" not in updated_checkpoint

0 commit comments

Comments
 (0)