Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.pytorch.callbacks import Callback # noqa: E402
from lightning.pytorch.core import LightningDataModule, LightningModule # noqa: E402
from lightning.pytorch.lite import LightningLite # noqa: E402
from lightning.pytorch.trainer import Trainer # noqa: E402

import lightning.app # isort: skip # noqa: E402
Expand All @@ -61,7 +60,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
"LightningModule",
"Callback",
"seed_everything",
"LightningLite",
"Fabric",
"storage",
"pdb",
Expand Down
3 changes: 2 additions & 1 deletion tests/tests_pytorch/utilities/migration/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision():
assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"] # no migration was performed


def test_migrate_dropped_apex_amp_state():
def test_migrate_dropped_apex_amp_state(monkeypatch):
"""Test that the migration warns about collisions that would occur if the keys were modified."""
monkeypatch.setattr(pl, "__version__", "2.0.0") # pretend this version of Lightning is >= 2.0.0
old_checkpoint = {"amp_scaling_state": {"scale": 1.23}}
_set_version(old_checkpoint, "1.9.0") # pretend a checkpoint prior to 2.0.0
with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"):
Expand Down