40 changes: 21 additions & 19 deletions pytorch_lightning/core/optimizer.py
@@ -184,14 +184,14 @@ def _init_optimizers_and_lr_schedulers(
         optim_conf = _MockOptimizer()

     optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
-    lr_schedulers = (
+    lr_scheduler_configs = (
         _configure_schedulers_automatic_opt(lr_schedulers, monitor)
         if model.automatic_optimization
         else _configure_schedulers_manual_opt(lr_schedulers)
     )
-    _set_scheduler_opt_idx(optimizers, lr_schedulers)
-    _validate_scheduler_api(lr_schedulers, model)
-    return optimizers, lr_schedulers, optimizer_frequencies
+    _set_scheduler_opt_idx(optimizers, lr_scheduler_configs)
+    _validate_scheduler_api(lr_scheduler_configs, model)
+    return optimizers, lr_scheduler_configs, optimizer_frequencies


 def _configure_optimizers(
@@ -259,8 +259,9 @@ def _configure_optimizers(


 def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
-    """Convert each scheduler into dict structure with relevant information, when using automatic optimization."""
-    lr_schedulers = []
+    """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
+    optimization."""
+    lr_scheduler_configs = []
     for scheduler in schedulers:
         if isinstance(scheduler, dict):
             # check provided keys
Expand Down Expand Up @@ -296,24 +297,25 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
scheduler = LRSchedulerConfig(**scheduler)
config = LRSchedulerConfig(**scheduler)
elif isinstance(scheduler, ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
scheduler = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
else:
scheduler = LRSchedulerConfig(scheduler)
lr_schedulers.append(scheduler)
return lr_schedulers
config = LRSchedulerConfig(scheduler)
lr_scheduler_configs.append(config)
return lr_scheduler_configs


def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
"""Convert each scheduler into dict structure with relevant information, when using manual optimization."""
lr_schedulers = []
"""Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
optimization."""
lr_scheduler_configs = []
for scheduler in schedulers:
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
@@ -326,15 +328,15 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
                     category=RuntimeWarning,
                 )

-            scheduler = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
+            config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
         else:
-            scheduler = LRSchedulerConfig(scheduler)
-        lr_schedulers.append(scheduler)
-    return lr_schedulers
+            config = LRSchedulerConfig(scheduler)
+        lr_scheduler_configs.append(config)
+    return lr_scheduler_configs


-def _validate_scheduler_api(lr_schedulers: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
-    for config in lr_schedulers:
+def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
+    for config in lr_scheduler_configs:
         scheduler = config.scheduler
         if not isinstance(scheduler, _SupportsStateDict):
             raise TypeError(
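For context (not part of this diff), here is a minimal sketch of the user-facing return value that `_configure_schedulers_automatic_opt` normalizes into `LRSchedulerConfig` entries; `LitModel`, the optimizer and scheduler choices, and the "val_loss" metric name are illustrative assumptions.

import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        # Each "lr_scheduler" dict becomes one LRSchedulerConfig entry; a "monitor" key is
        # required whenever the scheduler is a ReduceLROnPlateau.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "monitor": "val_loss"},
        }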
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/deepspeed.py
@@ -487,7 +487,7 @@ def _initialize_deepspeed_train(self, model):
             lr_scheduler = LRSchedulerConfig(deepspeed_scheduler)
         else:
             lr_scheduler.scheduler = deepspeed_scheduler
-        self.lr_schedulers = [lr_scheduler]
+        self.lr_scheduler_configs = [lr_scheduler]
         self.model = model

     @contextlib.contextmanager
@@ -578,7 +578,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         # via `_initialize_deepspeed_train`
         # empty optimizers, schedulers and frequencies
         self.optimizers = []
-        self.lr_schedulers = []
+        self.lr_scheduler_configs = []
         self.optimizer_frequencies = []

     @property
2 changes: 1 addition & 1 deletion pytorch_lightning/strategies/horovod.py
@@ -101,7 +101,7 @@ def _unpack_lightning_optimizer(opt):
                 param_group["lr"] *= self.world_size

         # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
-        lr_scheduler_configs = self.lr_schedulers
+        lr_scheduler_configs = self.lr_scheduler_configs
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
11 changes: 6 additions & 5 deletions pytorch_lightning/strategies/strategy.py
@@ -53,8 +53,7 @@ def __init__(
         self.precision_plugin = precision_plugin
         self._optimizers: List[Optimizer] = []
         self._lightning_optimizers: Dict[int, LightningOptimizer] = {}
-        # TODO: rename to `lr_scheduler_configs` to match the property in the `Trainer`
-        self.lr_schedulers: List[LRSchedulerConfig] = []
+        self.lr_scheduler_configs: List[LRSchedulerConfig] = []
         self.optimizer_frequencies: List[int] = []
         if is_overridden("post_dispatch", self, parent=Strategy):
             rank_zero_deprecation(
@@ -117,7 +116,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         """
         if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
             return
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
+        self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
             self.lightning_module
         )

@@ -134,10 +133,12 @@ def setup(self, trainer: "pl.Trainer") -> None:

     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""
-        model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers)
+        model, optimizers, lr_scheduler_configs = self.precision_plugin.connect(
+            self.model, self.optimizers, self.lr_scheduler_configs
+        )
         self.model = model
         self.optimizers = optimizers
-        self.lr_schedulers = schedulers
+        self.lr_scheduler_configs = lr_scheduler_configs

     def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the appropriate device if needed."""
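A hedged sketch (not from this PR) of how a custom strategy that previously assigned `self.lr_schedulers` inside `setup_optimizers` would look after the rename; `MyStrategy` is hypothetical and the `SingleDeviceStrategy` base is chosen only for illustration, with the `TrainerFn` guard from the original method omitted for brevity.

from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.strategies import SingleDeviceStrategy


class MyStrategy(SingleDeviceStrategy):
    def setup_optimizers(self, trainer):
        # The second element of the returned tuple is now stored on
        # `lr_scheduler_configs` instead of the old `lr_schedulers` attribute.
        self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = (
            _init_optimizers_and_lr_schedulers(self.lightning_module)
        )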
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -2015,7 +2015,7 @@ def lightning_optimizers(self) -> Dict[int, LightningOptimizer]:

     @property
     def lr_scheduler_configs(self) -> List[LRSchedulerConfig]:
-        return self.strategy.lr_schedulers
+        return self.strategy.lr_scheduler_configs

     @property
     def lr_schedulers(self) -> List[Dict[str, Any]]:
@@ -2026,7 +2026,7 @@ def lr_schedulers(self) -> List[Dict[str, Any]]:
         )
         from dataclasses import asdict

-        return [asdict(config) for config in self.strategy.lr_schedulers]
+        return [asdict(config) for config in self.strategy.lr_scheduler_configs]

     @property
     def optimizer_frequencies(self) -> List[int]:
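A brief usage sketch (not part of the diff), assuming `trainer` is an already-configured `pl.Trainer`: the `lr_scheduler_configs` property returns the `LRSchedulerConfig` objects directly, while the deprecated `Trainer.lr_schedulers` property keeps returning plain dicts built with `dataclasses.asdict`.

# assumes `trainer` has been set up for fitting, so the strategy holds the configs
configs = trainer.lr_scheduler_configs  # List[LRSchedulerConfig]
for config in configs:
    print(type(config.scheduler).__name__, config.interval, config.opt_idx)

legacy = trainer.lr_schedulers  # deprecated: List[Dict[str, Any]] produced by asdict(config)
for entry in legacy:
    print(entry["interval"], entry["opt_idx"])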
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
@@ -125,7 +125,7 @@ def func(trainer):
             scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)

             trainer.strategy.optimizers = [optimizer]
-            trainer.strategy.lr_schedulers = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
+            trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
             trainer.strategy.optimizer_frequencies = []
             _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
