Merged

37 commits
c1b271e
Support DDP for LRFinder
rohitgr7 Oct 25, 2022
bac7042
chlog
rohitgr7 Oct 25, 2022
dfd0b06
Merge branch 'master' into feat/lr_finder_ddp
rohitgr7 Oct 25, 2022
60b0235
Apply suggestions from code review
rohitgr7 Oct 25, 2022
a8726e4
Support DDP for BatchSizeFinder
rohitgr7 Oct 25, 2022
eb15a8a
Merge branch 'feat/lr_finder_ddp' into feat/bs_finder_ddp
rohitgr7 Oct 25, 2022
ca61bfa
chlog
rohitgr7 Oct 25, 2022
7a3a1ad
docs
rohitgr7 Oct 25, 2022
88a0090
docs
rohitgr7 Oct 25, 2022
21a0666
Merge branch 'feat/lr_finder_ddp' into feat/bs_finder_ddp
rohitgr7 Oct 25, 2022
9ba35a2
state key
awaelchli Nov 9, 2022
0f3739d
update
awaelchli Nov 9, 2022
dc30511
typo
awaelchli Nov 9, 2022
988424d
Merge branch 'master' into feature/migrate-state-key
awaelchli Nov 10, 2022
800393a
handle edge case
awaelchli Nov 10, 2022
e10adce
Merge branch 'master' into feat/bs_finder_ddp
awaelchli Nov 10, 2022
de83d92
remove dead code
awaelchli Nov 10, 2022
d9d0f1e
sync the found batch size
awaelchli Nov 10, 2022
edd1a0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
cc96f52
fix keyerror
awaelchli Nov 10, 2022
89ff79e
skip legacy key
awaelchli Nov 10, 2022
bdbed4d
Merge branch 'master' into feature/migrate-state-key
awaelchli Nov 10, 2022
b2263e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
7da2ee5
verify script
awaelchli Nov 10, 2022
392e060
wip
awaelchli Nov 10, 2022
56ebdc4
fix
awaelchli Nov 10, 2022
544c0fd
x
awaelchli Nov 10, 2022
2239370
x
awaelchli Nov 10, 2022
b2e92bf
x
awaelchli Nov 10, 2022
288b37d
Merge branch 'master' into feature/migrate-state-key
awaelchli Nov 11, 2022
a3ef361
update
awaelchli Nov 11, 2022
83808fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 11, 2022
4fb0484
test updates
awaelchli Nov 11, 2022
c87207f
update docstring
awaelchli Nov 11, 2022
c86481f
unused imports
awaelchli Nov 11, 2022
5c22004
remove debug script
awaelchli Nov 11, 2022
5cb4e3f
Update src/pytorch_lightning/utilities/migration/migration.py
awaelchli Nov 11, 2022
48 changes: 47 additions & 1 deletion src/pytorch_lightning/utilities/migration/migration.py
@@ -28,9 +28,13 @@
cp model.ckpt model.ckpt.backup
python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt
"""
import re
from typing import Any, Callable, Dict, List, Type, Union

from lightning_utilities.core.rank_zero import rank_zero_warn

import pytorch_lightning as pl
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

@@ -43,6 +47,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
        "0.10.0": [_migrate_model_checkpoint_early_stopping],
        "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
        "1.6.5": [_migrate_loop_batches_that_stepped],
        "1.9.0": [_migrate_model_checkpoint_save_on_train_epoch_end_default],
    }


@@ -160,3 +165,44 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
},
"state_dict": {},
}


def _migrate_model_checkpoint_save_on_train_epoch_end_default(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
    """Changes the value of ``save_on_train_epoch_end`` inside the state key of ``ModelCheckpoint`` callbacks.

    The initial value of ``ModelCheckpoint.save_on_train_epoch_end`` before training (and before loading the state)
    has changed. After this breaking change, Lightning is no longer able to determine whether
    ``save_on_train_epoch_end=True|False`` was set by the user or set internally (according to the old logic).

    Checkpoints created with ``ModelCheckpoint(..., save_on_train_epoch_end=True|False)`` will be loaded as if
    ``save_on_train_epoch_end`` was set to ``None`` to mitigate the impact of this breaking change.

    Version: 1.9.0
    Commit: f4ca56
    PR: #15300
    """
    if "callbacks" not in checkpoint:
        return checkpoint

    def new_key(old_key: Union[str, Type[pl.Callback]]) -> Union[str, Type[pl.Callback]]:
        if not isinstance(old_key, str):
            # this is a legacy state key (the type of the callback)
            return old_key
        if not old_key.startswith("ModelCheckpoint"):
            return old_key
        return re.sub("'save_on_train_epoch_end': (True|False)", "'save_on_train_epoch_end': None", old_key)

    num_keys = len(checkpoint["callbacks"])
    new_callback_states = {new_key(old_key): state for old_key, state in checkpoint["callbacks"].items()}
    if len(new_callback_states) < num_keys:
        rank_zero_warn(
            "You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys"
            " that would end up colliding with each other after an upgrade, which means we can't differentiate"
            " which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint`"
            " callbacks will not be able to reload the state.",
            category=PossibleUserWarning,
        )
        return checkpoint

    checkpoint["callbacks"] = new_callback_states
    return checkpoint
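
For context, a minimal standalone sketch of what the `new_key` rewrite above does to a legacy `ModelCheckpoint` state key. The shortened key below is hypothetical; real keys contain the full argument set shown in the tests that follow.

    import re

    def rewrite_key(old_key: str) -> str:
        # Keys that don't belong to a ModelCheckpoint pass through unchanged;
        # for ModelCheckpoint state keys, a stored True/False is rewritten to None.
        if not old_key.startswith("ModelCheckpoint"):
            return old_key
        return re.sub("'save_on_train_epoch_end': (True|False)", "'save_on_train_epoch_end': None", old_key)

    # Hypothetical, shortened state key for illustration
    old = "ModelCheckpoint{'monitor': None, 'save_on_train_epoch_end': True}"
    assert rewrite_key(old) == "ModelCheckpoint{'monitor': None, 'save_on_train_epoch_end': None}"
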
36 changes: 36 additions & 0 deletions tests/tests_pytorch/utilities/migration/test_migration.py
@@ -17,6 +17,7 @@
import torch

import pytorch_lightning as pl
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
@@ -109,3 +110,38 @@ def on_train_start(self) -> None:
    trainer.fit(model, ckpt_path=ckpt_path)
    new_loop = trainer.fit_loop.epoch_loop
    assert new_loop.global_step == new_loop._batches_that_stepped == 2


def test_migrate_model_checkpoint_save_on_train_epoch_end_default():
    # None -> None
    legacy_state_key_none = "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': None}"  # noqa: E501
    old_checkpoint = {"callbacks": {legacy_state_key_none: {"dummy": 0}}, "global_step": 0, "epoch": 1}
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
    assert updated_checkpoint["callbacks"] == {legacy_state_key_none: {"dummy": 0}}  # None -> None

    # True -> None
    legacy_state_key_true = "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': True}"  # noqa: E501
    old_checkpoint = {"callbacks": {legacy_state_key_true: {"dummy": 0}}, "global_step": 0, "epoch": 1}
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
    assert updated_checkpoint["callbacks"] == {legacy_state_key_none: {"dummy": 0}}  # True -> None

    # False -> None
    legacy_state_key_false = "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None, 'save_on_train_epoch_end': False}"  # noqa: E501
    old_checkpoint = {"callbacks": {legacy_state_key_false: {"dummy": 0}}, "global_step": 0, "epoch": 1}
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
    assert updated_checkpoint["callbacks"] == {legacy_state_key_none: {"dummy": 0}}  # False -> None

    # Simulate a collision
    # False -> None and True -> None
    old_checkpoint = {
        "callbacks": {legacy_state_key_false: {"dummy": 0}, legacy_state_key_true: {"dummy": 0}},
        "global_step": 0,
        "epoch": 1,
    }
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"):
        updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
    assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"]  # no migration was performed
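
For context, a rough sketch of the mechanism that consumes an index like `_migration_index()`: each registered migration runs when the checkpoint's recorded version is older than the version it is keyed under, in ascending order. The function name, the index parameter, and the `pytorch-lightning_version` key are assumptions for illustration, not the exact Lightning implementation.

    from typing import Any, Callable, Dict, List

    from packaging.version import Version

    _CHECKPOINT = Dict[str, Any]

    def apply_migrations(checkpoint: _CHECKPOINT, index: Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]) -> _CHECKPOINT:
        # Assumes the checkpoint stores the version that produced it under "pytorch-lightning_version"
        ckpt_version = Version(checkpoint.get("pytorch-lightning_version", "0.0.0"))
        for version in sorted(index, key=Version):
            if Version(version) <= ckpt_version:
                continue  # checkpoint was created at or after this version; nothing to do
            for migration in index[version]:
                checkpoint = migration(checkpoint)
        return checkpoint

    # Example: a pre-1.9.0 checkpoint only runs the migrations registered for newer versions
    legacy = {"pytorch-lightning_version": "1.8.9", "callbacks": {}}
    migrated = apply_migrations(legacy, {"1.9.0": [lambda ckpt: {**ckpt, "migrated": True}]})
    assert migrated["migrated"] is True
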