import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
- from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
+ from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
- from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT
+ from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT


# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
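Note: the replacement base class now comes from lightning_lite instead of torch. As a rough, assumed sketch (not the actual lightning_lite source), the alias boils down to pointing at torch's scheduler base class under whichever name the installed torch version exposes:

import torch

# assumed definition: prefer the public LRScheduler name if this torch exposes it,
# otherwise fall back to the older private _LRScheduler base class
_TORCH_LRSCHEDULER = getattr(
    torch.optim.lr_scheduler, "LRScheduler", torch.optim.lr_scheduler._LRScheduler
)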
@@ -124,7 +124,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:

        args = (optimizer, self.lr_max, self.num_training)
        scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
-        scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
+        scheduler = cast(LRScheduler, scheduler)

        trainer.strategy.optimizers = [optimizer]
        trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
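The exchanged scheduler is registered with interval="step", i.e. the ramp advances once per optimizer step rather than once per epoch. A self-contained plain-torch sketch of that per-batch stepping behaviour (toy model and stand-in scheduler, not the finder's own classes):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# stand-in ramp: grow the lr by ~10% every step, similar in spirit to _ExponentialLR
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.1)

for _ in range(100):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # stepped after every batch, which is what interval="step" requests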
@@ -404,7 +404,7 @@ def on_train_batch_end(
        self.losses.append(smoothed_loss)


-class _LinearLR(_LRScheduler):
+class _LinearLR(_TORCH_LRSCHEDULER):
    """Linearly increases the learning rate between two boundaries over a number of iterations.

    Args:
@@ -423,7 +423,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter

@@ -439,7 +439,7 @@ def lr(self) -> Union[float, List[float]]:
        return self._lr


-class _ExponentialLR(_LRScheduler):
+class _ExponentialLR(_TORCH_LRSCHEDULER):
    """Exponentially increases the learning rate between two boundaries over a number of iterations.

    Arguments:
@@ -458,7 +458,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

-    def get_lr(self) -> List[float]:  # type: ignore[override]
+    def get_lr(self) -> List[float]:
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter

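For reference, the ramps that _LinearLR and _ExponentialLR describe in their docstrings can be written out directly. A small self-contained sketch of the assumed interpolation formulas (illustrative only, not copied from this file's get_lr bodies):

def linear_lr(base_lr: float, end_lr: float, curr_iter: int, num_iter: int) -> float:
    # straight-line interpolation from base_lr to end_lr
    r = curr_iter / num_iter
    return base_lr + r * (end_lr - base_lr)


def exponential_lr(base_lr: float, end_lr: float, curr_iter: int, num_iter: int) -> float:
    # geometric interpolation from base_lr to end_lr
    r = curr_iter / num_iter
    return base_lr * (end_lr / base_lr) ** r


print(linear_lr(1e-5, 1.0, 50, 100))       # halfway through the linear ramp -> ~0.5
print(exponential_lr(1e-5, 1.0, 50, 100))  # halfway through the exponential ramp -> ~3.2e-3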