Commit 2577285

Authored by qmaruf, Borda, and carmocca
Use LRScheduler for torch >= 1.14 otherwise use _LRScheduler (#15768)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 4fea6bf · commit 2577285

11 files changed: 34 additions, 36 deletions
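
The change centers on picking the scheduler base class by torch version: per the commit, torch >= 1.14 exposes the public `LRScheduler` name, while older releases only have the private `_LRScheduler`. Below is a minimal sketch of the same selection, using a hypothetical local flag in place of lightning's `_TORCH_GREATER_EQUAL_1_14` and assuming `packaging` is installed:

import torch
from packaging.version import Version

# Hypothetical stand-in for `_TORCH_GREATER_EQUAL_1_14` from
# `lightning_lite.utilities.imports` (the real flag lives there).
_TORCH_GREATER_EQUAL_1_14 = Version(torch.__version__) >= Version("1.14.0")

if _TORCH_GREATER_EQUAL_1_14:
    # public name available in newer torch
    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
else:
    # private name used by older torch releases
    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler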

docs/source-pytorch/cli/lightning_cli_intermediate_2.rst (1 addition, 1 deletion)

@@ -201,7 +201,7 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
 
     python main.py fit --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
 
-Furthermore, any custom subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
+Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as learning rate scheduler:
 
 .. code:: python
 
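
As the updated documentation line states, any subclass of the torch scheduler base class works here. A minimal sketch of such a subclass (hypothetical name, with an import fallback for older torch):

from typing import List

import torch

try:  # public name on newer torch; fall back to the private one otherwise
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


class HalvingLR(LRScheduler):
    """Hypothetical scheduler: halve every parameter group's lr on each step()."""

    def get_lr(self) -> List[float]:
        if self.last_epoch == 0:
            # the first call happens inside __init__; keep the configured lrs
            return [group["lr"] for group in self.optimizer.param_groups]
        return [group["lr"] * 0.5 for group in self.optimizer.param_groups]


optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
scheduler = HalvingLR(optimizer)
scheduler.step()  # lr: 0.1 -> 0.05

Such a class can then be selected through the CLI in the same way as the built-in schedulers shown in the snippet above.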

src/lightning_lite/utilities/types.py (7 additions, 2 deletions)

@@ -19,7 +19,7 @@
 from torch.optim import Optimizer
 from typing_extensions import Protocol, runtime_checkable
 
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
 
 _PATH = Union[str, Path]
 _DEVICE = Union[torch.device, str, int]
@@ -63,7 +63,7 @@ def rank(self) -> int:
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_Stateful[str], Protocol):
+class LRScheduler(_Stateful[str], Protocol):
     optimizer: Optimizer
     base_lrs: List[float]
 
@@ -74,6 +74,11 @@ def step(self, epoch: Optional[int] = None) -> None:
         ...
 
 
+_TORCH_LRSCHEDULER = (
+    torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
+)
+
+
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
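
The renamed `LRScheduler` here is a `runtime_checkable` Protocol, so concrete torch schedulers match it structurally without subclassing anything. A small sketch of the idea, using a hypothetical stand-in protocol rather than the one in the diff:

from typing import List, Optional

import torch
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class _SupportsLRScheduler(Protocol):  # hypothetical stand-in, mirrors the protocol above
    optimizer: Optimizer
    base_lrs: List[float]

    def step(self, epoch: Optional[int] = None) -> None:
        ...


opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1)

# isinstance() on a runtime_checkable Protocol only checks that the listed
# members exist by name, so any torch scheduler passes the check.
assert isinstance(sched, _SupportsLRScheduler)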

src/pytorch_lightning/callbacks/stochastic_weight_avg.py (3 additions, 3 deletions)

@@ -23,7 +23,7 @@
 from torch.optim.swa_utils import SWALR
 
 import pytorch_lightning as pl
-from lightning_lite.utilities.types import _LRScheduler
+from lightning_lite.utilities.types import LRScheduler
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
 from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
@@ -125,7 +125,7 @@ def __init__(
         self._model_contains_batch_norm: Optional[bool] = None
         self._average_model: Optional["pl.LightningModule"] = None
         self._initialized = False
-        self._swa_scheduler: Optional[_LRScheduler] = None
+        self._swa_scheduler: Optional[LRScheduler] = None
         self._scheduler_state: Optional[Dict] = None
         self._init_n_averaged = 0
         self._latest_update_epoch = -1
@@ -192,7 +192,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
 
         assert trainer.max_epochs is not None
         self._swa_scheduler = cast(
-            _LRScheduler,
+            LRScheduler,
             SWALR(
                 optimizer,
                 swa_lr=self._swa_lrs,  # type: ignore[arg-type]
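
For context, `SWALR` is itself a subclass of the torch scheduler base class, which is why the callback can cast it to the `LRScheduler` protocol type. A standalone sketch of its use outside the callback (values are arbitrary):

import torch
from torch.optim.swa_utils import SWALR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# anneal each parameter group's lr towards `swa_lr` over `anneal_epochs` steps
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5, anneal_strategy="cos")

for _ in range(5):
    optimizer.step()
    swa_scheduler.step()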

src/pytorch_lightning/core/optimizer.py (2 additions, 2 deletions)

@@ -253,8 +253,8 @@ def _configure_optimizers(
             " Output from `model.configure_optimizers()` should be one of:\n"
             " * `Optimizer`\n"
             " * [`Optimizer`]\n"
-            " * ([`Optimizer`], [`_LRScheduler`])\n"
-            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
+            " * ([`Optimizer`], [`LRScheduler`])\n"
+            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
             ' * A list of the previously described dict format, with an optional "frequency" key (int)'
         )
     return optimizers, lr_schedulers, optimizer_frequencies, monitor
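
The updated error message lists the accepted return shapes of `configure_optimizers`. A minimal sketch of two of them in a hypothetical module (names chosen for illustration only):

import torch
from pytorch_lightning import LightningModule


class DemoModule(LightningModule):  # hypothetical, for illustration
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # the ([`Optimizer`], [`LRScheduler`]) form:
        return [optimizer], [scheduler]
        # or the dict form, where "lr_scheduler" is optional:
        # return {"optimizer": optimizer, "lr_scheduler": scheduler}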

src/pytorch_lightning/demos/boring_classes.py (2 additions, 2 deletions)

@@ -18,9 +18,9 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset
 
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning import LightningDataModule, LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -137,7 +137,7 @@ def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> No
         outputs = cast(List[Dict[str, Tensor]], outputs)
         torch.stack([x["y"] for x in outputs]).mean()
 
-    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
         return [optimizer], [lr_scheduler]

src/pytorch_lightning/strategies/deepspeed.py (2 additions, 2 deletions)

@@ -34,7 +34,7 @@
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.optimizer import _optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
 from pytorch_lightning.accelerators.cuda import CUDAAccelerator
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -426,7 +426,7 @@ def _setup_model_and_optimizer(
         self,
         model: Module,
         optimizer: Optional[Optimizer],
-        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+        lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None,
     ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 

src/pytorch_lightning/strategies/hivemind.py (2 additions, 2 deletions)

@@ -9,7 +9,7 @@
 
 import pytorch_lightning as pl
 from lightning_lite.utilities.enums import PrecisionType
-from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau
+from lightning_lite.utilities.types import LRScheduler, ReduceLROnPlateau
 from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
 from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -312,7 +312,7 @@ class HiveMindScheduler:
 
     base_lrs: List[float]
 
-    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
+    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
         self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}
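
The `__init__` above reuses the wrapped scheduler's state by copying its instance `__dict__` rather than inheriting from it. A generic sketch of that pattern (hypothetical class, not the Hivemind integration):

import torch


class SchedulerView:
    """Hypothetical wrapper: borrow a scheduler's state, supply custom stepping."""

    def __init__(self, scheduler) -> None:
        # copy the instance attributes (e.g. `base_lrs`, `last_epoch`) while
        # leaving out anything we explicitly do not want to carry over
        self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}

    def step(self) -> None:
        # custom stepping logic would go here; the copied attributes are on `self`
        self.last_epoch += 1


opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
view = SchedulerView(torch.optim.lr_scheduler.StepLR(opt, step_size=1))
view.step()
print(view.last_epoch, view.base_lrs)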

src/pytorch_lightning/tuner/lr_finder.py (7 additions, 7 deletions)

@@ -21,14 +21,14 @@
 import numpy as np
 import torch
 from lightning_utilities.core.imports import RequirementCache
-from torch.optim.lr_scheduler import _LRScheduler
 
 import pytorch_lightning as pl
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
+from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -124,7 +124,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
 
         args = (optimizer, self.lr_max, self.num_training)
         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
-        scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
+        scheduler = cast(LRScheduler, scheduler)
 
         trainer.strategy.optimizers = [optimizer]
         trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
@@ -404,7 +404,7 @@ def on_train_batch_end(
         self.losses.append(smoothed_loss)
 
 
-class _LinearLR(_LRScheduler):
+class _LinearLR(_TORCH_LRSCHEDULER):
     """Linearly increases the learning rate between two boundaries over a number of iterations.
 
     Args:
@@ -423,7 +423,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
@@ -439,7 +439,7 @@ def lr(self) -> Union[float, List[float]]:
         return self._lr
 
 
-class _ExponentialLR(_LRScheduler):
+class _ExponentialLR(_TORCH_LRSCHEDULER):
     """Exponentially increases the learning rate between two boundaries over a number of iterations.
 
     Arguments:
@@ -458,7 +458,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
 
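
`_LinearLR` and `_ExponentialLR` now inherit from whichever concrete class `_TORCH_LRSCHEDULER` resolves to. A hedged sketch of the same subclassing pattern, using a local import fallback instead of the lightning alias (not the tuner's actual implementation):

from typing import List

import torch

try:  # mirror what `_TORCH_LRSCHEDULER` resolves to, without importing lightning
    from torch.optim.lr_scheduler import LRScheduler as _Base
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as _Base


class LinearWarmup(_Base):
    """Hypothetical scheduler: ramp each lr linearly to `end_lr` over `num_iter` steps."""

    def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int) -> None:
        # attributes used by get_lr() must exist before the base __init__,
        # because the base class calls step() -> get_lr() once on construction
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer)

    def get_lr(self) -> List[float]:
        r = min((self.last_epoch + 1) / self.num_iter, 1.0)
        return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]


opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=1e-4)
sched = LinearWarmup(opt, end_lr=1e-1, num_iter=100)
for _ in range(100):
    opt.step()
    sched.step()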

src/pytorch_lightning/utilities/types.py (6 additions, 13 deletions)

@@ -27,14 +27,7 @@
 from torchmetrics import Metric
 from typing_extensions import Protocol, runtime_checkable
 
-try:
-    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
-except ImportError:
-    # For torch <= 1.13.x
-    # TODO: Remove once minimum torch version is 1.14 (or 2.0)
-    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
-
-from lightning_lite.utilities.types import _LRScheduler, ProcessGroup, ReduceLROnPlateau
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER, LRScheduler, ProcessGroup, ReduceLROnPlateau
 
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, Tensor, _NUMBER]
@@ -118,15 +111,15 @@ def no_sync(self) -> Generator:
 
 
 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
-LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
+LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau]
 
 
 @dataclass
 class LRSchedulerConfig:
-    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
     # no custom name
     name: Optional[str] = None
     # after epoch is over
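
With the aliases above, isinstance checks and scheduler configs keep working on both old and new torch. A small sketch, assuming this revision of pytorch_lightning is importable:

import torch
from pytorch_lightning.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple

opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1)

# the tuple resolves to the concrete base class for the installed torch version
assert isinstance(sched, LRSchedulerTypeTuple)

# same constructor call as in the lr_finder diff above
config = LRSchedulerConfig(sched, interval="step", opt_idx=0)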

tests/tests_pytorch/trainer/optimization/test_manual_optimization.py (1 addition, 1 deletion)

@@ -965,7 +965,7 @@ def training_step(self, batch, batch_idx):
     with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
         trainer.fit(model)
 
-    # If a lr scheduler inherits `torch.optim.lr_scheduler._LRScheduler`,
+    # If a lr scheduler inherits `torch.optim.lr_scheduler.LRScheduler`,
     # `.step()` is called once during its instantiation.
     # Thus, the call count should be 1, not 0.
     assert lr_step.call_count == 1
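
The test comment relies on the torch base class calling `step()` once from its own constructor. A standalone sketch of that behaviour, without the Trainer, assuming torch's scheduler constructor still does this:

from unittest.mock import patch

import torch

opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)

with patch("torch.optim.lr_scheduler.StepLR.step") as lr_step:
    torch.optim.lr_scheduler.StepLR(opt, step_size=1)

# the base scheduler calls step() once during __init__, hence 1 rather than 0
assert lr_step.call_count == 1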
