Skip to content

Commit edc9986

Browse files
authored
Apply dynamo to training_step, validation_step, test_step, predict_step (#15957)
* Apply dynamo to training_step, validation_step, test_step, predict_step * Add entry to CHANGELOG.md
1 parent 4983083 commit edc9986

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added support for `torch.compile` ([#15922](https://github.com/Lightning-AI/lightning/pull/15922), [15957](https://github.com/Lightning-AI/lightning/pull/15957))
13+
14+
1215
- Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304))
1316

1417

src/pytorch_lightning/core/module.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,9 +1979,17 @@ def from_compiled(cls, model: "torch._dynamo.OptimizedModule") -> "pl.LightningM
19791979
"compiler": "dynamo",
19801980
"dynamo_ctx": model.dynamo_ctx,
19811981
"original_forward": orig_module.forward,
1982+
"original_training_step": orig_module.training_step,
1983+
"original_validation_step": orig_module.validation_step,
1984+
"original_test_step": orig_module.test_step,
1985+
"original_predict_step": orig_module.predict_step,
19821986
}
19831987

19841988
orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[assignment]
1989+
orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[assignment]
1990+
orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[assignment]
1991+
orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[assignment]
1992+
orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[assignment]
19851993
return orig_module
19861994

19871995
@classmethod
@@ -2010,6 +2018,10 @@ def to_uncompiled(cls, model: Union["pl.LightningModule", "torch._dynamo.Optimiz
20102018
raise ValueError("`model` must either be an instance of torch._dynamo.OptimizedModule or LightningModule")
20112019

20122020
model.forward = model._compiler_ctx["original_forward"] # type: ignore[assignment]
2021+
model.training_step = model._compiler_ctx["original_training_step"] # type: ignore[assignment]
2022+
model.validation_step = model._compiler_ctx["original_validation_step"] # type: ignore[assignment]
2023+
model.test_step = model._compiler_ctx["original_test_step"] # type: ignore[assignment]
2024+
model.predict_step = model._compiler_ctx["original_predict_step"] # type: ignore[assignment]
20132025
model._compiler_ctx = None
20142026

20152027
return model

tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torch.optim import Adam, SGD
2222

2323
from pytorch_lightning import LightningModule, Trainer
24-
from pytorch_lightning.demos.boring_classes import BoringModel, DemoModel
24+
from pytorch_lightning.demos.boring_classes import BoringModel
2525
from pytorch_lightning.loggers import TensorBoardLogger
2626
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2727
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
@@ -457,15 +457,32 @@ def test_trainer_reference_recursively():
457457
@RunIf(min_torch="1.14.0.dev20221202")
458458
def test_compile_uncompile():
459459

460-
lit_model = DemoModel()
460+
lit_model = BoringModel()
461461
model_compiled = torch.compile(lit_model)
462462

463463
lit_model_compiled = LightningModule.from_compiled(model_compiled)
464464

465+
def has_dynamo(fn):
466+
return any(el for el in dir(fn) if el.startswith("_torchdynamo"))
467+
465468
assert isinstance(lit_model_compiled, LightningModule)
466469
assert lit_model_compiled._compiler_ctx is not None
470+
assert has_dynamo(lit_model_compiled.forward)
471+
assert has_dynamo(lit_model_compiled.training_step)
472+
assert has_dynamo(lit_model_compiled.validation_step)
473+
assert has_dynamo(lit_model_compiled.test_step)
474+
assert has_dynamo(lit_model_compiled.predict_step)
467475

468476
lit_model_orig = LightningModule.to_uncompiled(lit_model)
469477

470478
assert lit_model_orig._compiler_ctx is None
471479
assert lit_model_orig.forward == lit_model.forward
480+
assert lit_model_orig.training_step == lit_model.training_step
481+
assert lit_model_orig.validation_step == lit_model.validation_step
482+
assert lit_model_orig.test_step == lit_model.test_step
483+
assert lit_model_orig.predict_step == lit_model.predict_step
484+
assert not has_dynamo(lit_model_orig.forward)
485+
assert not has_dynamo(lit_model_orig.training_step)
486+
assert not has_dynamo(lit_model_orig.validation_step)
487+
assert not has_dynamo(lit_model_orig.test_step)
488+
assert not has_dynamo(lit_model_orig.predict_step)

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
from pytorch_lightning.demos.boring_classes import (
4646
BoringDataModule,
4747
BoringModel,
48-
DemoModel,
4948
RandomDataset,
5049
RandomIterableDataset,
5150
RandomIterableDatasetWithLen,
@@ -2245,7 +2244,7 @@ def on_fit_start(self):
22452244
# TODO: replace with 1.14 when it is released
22462245
@RunIf(min_torch="1.14.0.dev20221202")
22472246
def test_trainer_compiled_model():
2248-
model = DemoModel()
2247+
model = BoringModel()
22492248

22502249
model = torch.compile(model)
22512250

0 commit comments

Comments
 (0)