Skip to content

Commit b5fa896

Browse files
Make LightningModule torch.jit.script-able again (#15947)
* Make LightningModule torch.jit.script-able again * remove skip Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 67a47d4 commit b5fa896

File tree

5 files changed

+15
-39
lines changed

5 files changed

+15
-39
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9191
- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620))
9292

9393

94+
- Fixed `torch.jit.script`-ing a LightningModule causing an unintended error message about deprecated `use_amp` property ([#15947](https://github.com/Lightning-AI/lightning/pull/15947))
95+
96+
9497
## [1.8.3] - 2022-11-22
9598

9699
### Changed

src/pytorch_lightning/_graveyard/core.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from typing import Any
1515

16-
from pytorch_lightning import LightningDataModule, LightningModule
16+
from pytorch_lightning import LightningDataModule
1717

1818

1919
def _on_save_checkpoint(_: LightningDataModule, __: Any) -> None:
@@ -32,28 +32,6 @@ def _on_load_checkpoint(_: LightningDataModule, __: Any) -> None:
3232
)
3333

3434

35-
def _use_amp(_: LightningModule) -> None:
36-
# Remove in v2.0.0 and the skip in `__jit_unused_properties__`
37-
if not LightningModule._jit_is_scripting:
38-
# cannot use `AttributeError` as it messes up with `nn.Module.__getattr__`
39-
raise RuntimeError(
40-
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
41-
" Please use `Trainer.amp_backend`.",
42-
)
43-
44-
45-
def _use_amp_setter(_: LightningModule, __: bool) -> None:
46-
# Remove in v2.0.0
47-
# cannot use `AttributeError` as it messes up with `nn.Module.__getattr__`
48-
raise RuntimeError(
49-
"`LightningModule.use_amp` was deprecated in v1.6 and is no longer accessible as of v1.8."
50-
" Please use `Trainer.amp_backend`.",
51-
)
52-
53-
54-
# Properties
55-
LightningModule.use_amp = property(fget=_use_amp, fset=_use_amp_setter)
56-
5735
# Methods
5836
LightningDataModule.on_save_checkpoint = _on_save_checkpoint
5937
LightningDataModule.on_load_checkpoint = _on_load_checkpoint

src/pytorch_lightning/core/module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ class LightningModule(
8888
"automatic_optimization",
8989
"truncated_bptt_steps",
9090
"trainer",
91-
"use_amp", # from graveyard
9291
]
9392
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
9493
+ HyperparametersMixin.__jit_unused_properties__

tests/tests_pytorch/core/test_lightning_module.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,17 @@ def test_proper_refcount():
425425
assert sys.getrefcount(torch_module) == sys.getrefcount(lightning_module)
426426

427427

428+
def test_lightning_module_scriptable():
429+
"""Test that the LightningModule is `torch.jit.script`-able.
430+
431+
Regression test for #15917.
432+
"""
433+
model = BoringModel()
434+
trainer = Trainer()
435+
model.trainer = trainer
436+
torch.jit.script(model)
437+
438+
428439
def test_trainer_reference_recursively():
429440
ensemble = LightningModule()
430441
inner = LightningModule()

tests/tests_pytorch/graveyard/test_core.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,3 @@ def on_load_checkpoint(self, checkpoint):
5353
match="`LightningDataModule.on_load_checkpoint`.*no longer supported as of v1.8.",
5454
):
5555
trainer.fit(model, OnLoadDataModule())
56-
57-
58-
def test_v2_0_0_lightning_module_unsupported_use_amp():
59-
model = BoringModel()
60-
with pytest.raises(
61-
RuntimeError,
62-
match="`LightningModule.use_amp`.*no longer accessible as of v1.8.",
63-
):
64-
model.use_amp
65-
66-
with pytest.raises(
67-
RuntimeError,
68-
match="`LightningModule.use_amp`.*no longer accessible as of v1.8.",
69-
):
70-
model.use_amp = False

0 commit comments

Comments
 (0)