Skip to content

Commit de93167

Browse files
authored
Fix LRScheduler import for PyTorch 2.0 (#15940)
* Fix LRScheduler import for PyTorch 2.0 * Add comment for posterity
1 parent 2041908 commit de93167

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/pytorch_lightning/utilities/types.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727
from torchmetrics import Metric
2828
from typing_extensions import Protocol, runtime_checkable
2929

30+
try:
31+
from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
32+
except ImportError:
33+
# For torch <= 1.13.x
34+
# TODO: Remove once minimum torch version is 1.14 (or 2.0)
35+
from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
36+
3037
from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
3138

3239
_NUMBER = Union[int, float]
@@ -111,9 +118,9 @@ def no_sync(self) -> Generator:
111118

112119

113120
# todo: improve LRSchedulerType naming/typing
114-
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
115-
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
116-
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
121+
LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
122+
LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
123+
LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
117124
LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
118125

119126

0 commit comments

Comments
 (0)