
Conversation

@awaelchli awaelchli (Contributor) commented Nov 9, 2022

What does this PR do?

A breaking change was introduced in #15300 regarding state loading for the model checkpoint callback. Checkpoints that have the state saved with save_on_train_epoch_end=True|False can still be loaded, but ModelCheckpoint won't restore its state correctly, since the state key generated by the current code no longer matches the key saved in the checkpoint.

We can fix this issue by dropping save_on_train_epoch_end from the state key. This way, all old ModelCheckpoint states can be loaded successfully without ambiguity about what save_on_train_epoch_end should be set to.
An extremely rare edge case in which a migration is not possible is also handled with a user warning.
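
The migration itself can be pictured as a small rewrite of the callback entries in the saved checkpoint dictionary. The sketch below is illustrative only and not the code added in this PR; the key format shown in the docstring, the helper names, and the exact warning text are assumptions:

import re
import warnings


def _drop_save_on_train_epoch_end(state_key: str) -> str:
    """Rewrite an old ModelCheckpoint state key by removing the save_on_train_epoch_end entry.

    Old keys look roughly like
    "ModelCheckpoint{'monitor': 'valid_loss1', ..., 'save_on_train_epoch_end': True}".
    """
    return re.sub(r", 'save_on_train_epoch_end': (True|False|None)", "", state_key)


def migrate_callback_state_keys(checkpoint: dict) -> dict:
    """Map callback states saved under old ModelCheckpoint keys onto the new, shorter keys."""
    old_states = checkpoint.get("callbacks", {})
    new_states = {}
    for key, state in old_states.items():
        new_key = _drop_save_on_train_epoch_end(key) if key.startswith("ModelCheckpoint") else key
        if new_key in new_states:
            # Two old keys differ only in save_on_train_epoch_end and collapse onto the
            # same new key: the migration is ambiguous, so keep the original keys and warn.
            warnings.warn(f"Could not migrate the ModelCheckpoint state for key {key!r} unambiguously.")
            return checkpoint
        new_states[new_key] = state
    checkpoint["callbacks"] = new_states
    return checkpoint

In this sketch, the migration would run on the loaded checkpoint dictionary (e.g. migrate_callback_state_keys(torch.load(name))) before the callback states are restored.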

Full repro script:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def on_train_start(self) -> None:
        print("The callback states at the beginning of training are")
        for c in self.trainer.callbacks:
            print(c.state_key, c.state_dict())

    def on_train_end(self) -> None:
        print("The callback states at the end of training are")
        for c in self.trainer.callbacks:
            print(c.state_key, c.state_dict())

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss1", loss)
        self.log("valid_loss2", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    checkpoint1 = ModelCheckpoint(monitor="valid_loss1")  # in v1.7 -> True
    checkpoint2 = ModelCheckpoint(monitor="valid_loss2", save_on_train_epoch_end=False)  # in v1.7 -> False
    checkpoint3 = ModelCheckpoint(monitor="train_loss", save_on_train_epoch_end=True)  # in v1.7 -> True

    name = "check/verify1-7-7.ckpt"
    resume = False  # train on 1.7.7, switch to 1.8.0 and flip resume to True
    
    if not resume:

        model = BoringModel()
        trainer = Trainer(
            default_root_dir=os.getcwd(),
            limit_train_batches=1,
            limit_val_batches=1,
            limit_test_batches=1,
            num_sanity_val_steps=0,
            max_epochs=1,
            callbacks=[checkpoint1, checkpoint2, checkpoint3],
            enable_model_summary=False,
        )
        trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
        trainer.save_checkpoint(name)
    else:
        model = BoringModel()
        trainer = Trainer(
            default_root_dir=os.getcwd(),
            limit_train_batches=1,
            limit_val_batches=1,
            limit_test_batches=1,
            num_sanity_val_steps=0,
            max_epochs=2,
            enable_model_summary=False,
            callbacks=[checkpoint1, checkpoint2, checkpoint3],
        )
        trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=name)


if __name__ == "__main__":
    run()

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

I made sure I had fun coding 🙃

cc @Borda @awaelchli @justusschock

@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Nov 9, 2022
@awaelchli awaelchli added the feature (Is an improvement or enhancement), checkpointing (Related to checkpointing), and breaking change (Includes a breaking change) labels and removed the pl (Generic label for PyTorch Lightning package) label Nov 9, 2022
@awaelchli awaelchli added this to the v1.9 milestone Nov 9, 2022
@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Nov 9, 2022
@awaelchli awaelchli marked this pull request as ready for review November 10, 2022 04:26
@awaelchli awaelchli changed the title Checkpoint migration for ModelCheckpoint state-key changes WIP: Checkpoint migration for ModelCheckpoint state-key changes Nov 10, 2022
@awaelchli awaelchli requested a review from tchaton as a code owner November 11, 2022 11:34
@awaelchli awaelchli changed the title WIP: Checkpoint migration for ModelCheckpoint state-key changes Checkpoint migration for ModelCheckpoint state-key changes Nov 11, 2022
@mergify mergify bot added the ready PRs ready to be merged label Nov 11, 2022
@carmocca carmocca (Contributor) left a comment

Great!

@awaelchli awaelchli enabled auto-merge (squash) November 11, 2022 12:21
@awaelchli awaelchli merged commit 18288eb into master Nov 11, 2022
@awaelchli awaelchli deleted the feature/migrate-state-key branch November 11, 2022 13:06
Labels
breaking change (Includes a breaking change), checkpointing (Related to checkpointing), feature (Is an improvement or enhancement), pl (Generic label for PyTorch Lightning package), ready (PRs ready to be merged)
5 participants