6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -58,6 +58,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))


- Added `opt_idx` to the scheduler config if it is not assigned by the user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))


- Added a `MisconfigurationException` raised when a user-provided `opt_idx` in the scheduler config does not match the index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))



### Changed

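For orientation, here is a minimal sketch of what the new `opt_idx` handling means on the user side. It is not part of the diff; the module, layers, and optimizer names are illustrative only.

```python
# Hypothetical example (not from the PR): two optimizers, one scheduler.
from torch import nn, optim

class TwoOptimizerModule:  # stand-in for a LightningModule
    def __init__(self):
        self.generator = nn.Linear(4, 4)
        self.discriminator = nn.Linear(4, 4)

    def configure_optimizers(self):
        opt_g = optim.SGD(self.generator.parameters(), lr=0.01)
        opt_d = optim.SGD(self.discriminator.parameters(), lr=0.01)
        # "opt_idx" may now be omitted: it is filled in automatically with the
        # index of the optimizer the scheduler is attached to (1 for opt_d).
        # Setting it explicitly to a wrong value, e.g. "opt_idx": 0, raises a
        # MisconfigurationException during fit.
        sched_d = {"scheduler": optim.lr_scheduler.StepLR(opt_d, step_size=1), "opt_idx": 1}
        return [opt_g, opt_d], [sched_d]
```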
26 changes: 19 additions & 7 deletions pytorch_lightning/core/optimizer.py
@@ -184,7 +184,7 @@ def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[Lis
_configure_schedulers_automatic_opt if model.automatic_optimization else _configure_schedulers_manual_opt
)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor)
-    _validate_scheduler_optimizer(optimizers, lr_schedulers)
+    _set_scheduler_opt_idx(optimizers, lr_schedulers)
return optimizers, lr_schedulers, optimizer_frequencies


@@ -337,15 +337,27 @@ def _get_default_scheduler_config() -> Dict[str, Any]:
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
"opt_idx": None, # opt_idx assigned internally if not assigned by user
}


-def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
-    if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
-        raise MisconfigurationException(
-            "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
-        )
+def _set_scheduler_opt_idx(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
+    for sch in lr_schedulers:
+
+        for opt_idx, opt in enumerate(optimizers):
+            if sch["scheduler"].optimizer is opt:
+                if sch["opt_idx"] is not None and sch["opt_idx"] != opt_idx:
+                    raise MisconfigurationException(
+                        "`opt_idx` set inside scheduler config does not match with the index"
+                        " of the respective optimizer returned from `configure_optimizers`."
+                    )
+
+                sch["opt_idx"] = opt_idx
+                break
+        else:
+            raise MisconfigurationException(
+                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
+            )


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
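The replacement `_set_scheduler_opt_idx` leans on Python's `for`/`else`: the `else` branch runs only when the inner loop finishes without a `break`, i.e. when no optimizer matched the scheduler. Below is a standalone sketch of that matching logic, using placeholder objects rather than real optimizers and schedulers.

```python
# Minimal illustration of the for/else matching used above; all objects are fakes.
class FakeScheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer

opt_a, opt_b = object(), object()
optimizers = [opt_a, opt_b]
config = {"scheduler": FakeScheduler(opt_b), "opt_idx": None}

for opt_idx, opt in enumerate(optimizers):
    if config["scheduler"].optimizer is opt:  # identity check, as in the diff
        config["opt_idx"] = opt_idx  # record the index of the matching optimizer
        break
else:
    # Reached only if the loop never hit `break`: the scheduler is attached to
    # an optimizer that was not returned from `configure_optimizers`.
    raise ValueError("scheduler attached to an unknown optimizer")

assert config["opt_idx"] == 1
```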
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -468,7 +468,7 @@ def _update_learning_rates(
opt_indices = []

for lr_scheduler in self.trainer.lr_schedulers:
if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
if lr_scheduler["opt_idx"] not in opt_indices:
continue

if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
9 changes: 7 additions & 2 deletions pytorch_lightning/tuner/lr_finder.py
@@ -24,7 +24,11 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
+from pytorch_lightning.core.optimizer import (
+    _get_default_scheduler_config,
+    _init_optimizers_and_lr_schedulers,
+    _set_scheduler_opt_idx,
+)
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -124,11 +128,12 @@ def func(trainer):
args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
sched_config = _get_default_scheduler_config()
-            sched_config.update({"scheduler": scheduler, "interval": "step"})
+            sched_config.update({"scheduler": scheduler, "interval": "step", "opt_idx": 0})

trainer.strategy.optimizers = [optimizer]
trainer.strategy.lr_schedulers = [sched_config]
trainer.strategy.optimizer_frequencies = []
+            _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_schedulers)

return func

22 changes: 20 additions & 2 deletions tests/trainer/optimization/test_optimizers.py
@@ -140,7 +140,7 @@ def configure_optimizers(self):
frequency=1,
reduce_on_plateau=True,
strict=True,
-        opt_idx=None,
+        opt_idx=0,
name=None,
)

@@ -182,7 +182,7 @@ def test_optimizer_return_options(tmpdir):
monitor=None,
strict=True,
name=None,
-        opt_idx=None,
+        opt_idx=0,
)

# opt tuple of 2 lists
@@ -511,6 +511,24 @@ def configure_optimizers(self):
trainer.fit(model)


def test_invalid_opt_idx_in_scheduler(tmpdir):
"""Test exception when incorrect opt_idx is set in lr_scheduler config."""

class InvalidOptimizerModel(BoringModel):
def configure_optimizers(self):
opt1 = optim.SGD(self.layer.parameters(), lr=0.1)
opt2 = optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = {"scheduler": optim.lr_scheduler.StepLR(opt2, step_size=1), "opt_idx": 0}
return [opt1, opt2], [lr_scheduler]

model = InvalidOptimizerModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.raises(
MisconfigurationException, match="`opt_idx` .* does not match with the index of the respective optimizer"
):
trainer.fit(model)


def test_invalid_optimizer_dict_raises(tmpdir):
"""Test exception when lr_scheduler dict has no scheduler."""
