Commit 7b19a02

rohitgr7 and carmocca committed
Add opt_idx to scheduler config if not assigned by user (#11247)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 512ec4f commit 7b19a02

File tree: 5 files changed, +53 −12 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `PrecisionPlugin.teardown` method ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/issues/10990))


+- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))
+
+
+- Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/issues/11247))
+
+

 ### Changed
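
In user terms, the first entry means a `configure_optimizers` that returns several optimizers no longer needs to spell out `opt_idx` in each scheduler dict: Lightning derives it by matching every scheduler to the optimizer it wraps, and the second entry covers the case where a hand-written `opt_idx` contradicts that match. A minimal sketch of code that relies on this behaviour (the module, layer name and hyperparameters are illustrative, not taken from the commit):

    # Hypothetical LightningModule method, assuming `import torch` and a `self.layer` submodule.
    # Neither scheduler dict sets "opt_idx"; after this change Lightning fills in 0 for the
    # first scheduler and 1 for the second, based on which optimizer each one wraps.
    def configure_optimizers(self):
        opt_a = torch.optim.Adam(self.layer.parameters(), lr=1e-3)
        opt_b = torch.optim.SGD(self.layer.parameters(), lr=1e-2)
        sched_a = {"scheduler": torch.optim.lr_scheduler.StepLR(opt_a, step_size=1)}
        sched_b = {"scheduler": torch.optim.lr_scheduler.StepLR(opt_b, step_size=1)}
        return [opt_a, opt_b], [sched_a, sched_b]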

pytorch_lightning/core/optimizer.py

Lines changed: 19 additions & 7 deletions
@@ -184,7 +184,7 @@ def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[Lis
         _configure_schedulers_automatic_opt if model.automatic_optimization else _configure_schedulers_manual_opt
     )
     lr_schedulers = _configure_schedulers(lr_schedulers, monitor)
-    _validate_scheduler_optimizer(optimizers, lr_schedulers)
+    _set_scheduler_opt_idx(optimizers, lr_schedulers)
     return optimizers, lr_schedulers, optimizer_frequencies


@@ -337,15 +337,27 @@ def _get_default_scheduler_config() -> Dict[str, Any]:
         "reduce_on_plateau": False,  # most often not ReduceLROnPlateau scheduler
         "monitor": None,  # value to monitor for ReduceLROnPlateau
         "strict": True,  # enforce that the monitor exists for ReduceLROnPlateau
-        "opt_idx": None,  # necessary to store opt_idx when optimizer frequencies are specified
+        "opt_idx": None,  # opt_idx assigned internally if not assigned by user
     }


-def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
-    if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
-        raise MisconfigurationException(
-            "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
-        )
+def _set_scheduler_opt_idx(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
+    for sch in lr_schedulers:
+
+        for opt_idx, opt in enumerate(optimizers):
+            if sch["scheduler"].optimizer is opt:
+                if sch["opt_idx"] is not None and sch["opt_idx"] != opt_idx:
+                    raise MisconfigurationException(
+                        "`opt_idx` set inside scheduler config does not match with the index"
+                        " of the respective optimizer returned from `configure_optimizers`."
+                    )
+
+                sch["opt_idx"] = opt_idx
+                break
+        else:
+            raise MisconfigurationException(
+                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
+            )


 def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
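
The replacement helper keeps the old "scheduler attached to an unknown optimizer" check (the for/else branch) and adds the index assignment on top, matching each scheduler to its optimizer by object identity. A standalone sketch of that matching, using plain `torch` objects rather than Lightning internals (names and values are illustrative):

    import torch

    model = torch.nn.Linear(4, 2)
    opt_a = torch.optim.SGD(model.parameters(), lr=0.1)
    opt_b = torch.optim.Adam(model.parameters(), lr=0.01)
    optimizers = [opt_a, opt_b]
    # One scheduler config shaped like `_get_default_scheduler_config()`, opt_idx not set yet.
    lr_schedulers = [{"scheduler": torch.optim.lr_scheduler.StepLR(opt_b, step_size=1), "opt_idx": None}]

    for sch in lr_schedulers:
        for opt_idx, opt in enumerate(optimizers):
            if sch["scheduler"].optimizer is opt:  # identity check, same idea as in the helper above
                sch["opt_idx"] = opt_idx
                break

    assert lr_schedulers[0]["opt_idx"] == 1  # the StepLR wraps opt_b, the second optimizer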

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 1 deletion
@@ -468,7 +468,7 @@ def _update_learning_rates(
             opt_indices = []

         for lr_scheduler in self.trainer.lr_schedulers:
-            if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
+            if lr_scheduler["opt_idx"] not in opt_indices:
                 continue

             if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
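
The dropped `isinstance` guard was only needed while `opt_idx` could still be `None`; now that `_set_scheduler_opt_idx` runs during initialization, every scheduler config carries an integer index, so membership in `opt_indices` alone decides whether a scheduler belongs to an optimizer that was just stepped. A toy illustration of that filter, with made-up configs:

    # Hypothetical scheduler configs after opt_idx has been assigned.
    lr_schedulers = [
        {"opt_idx": 0, "reduce_on_plateau": False},
        {"opt_idx": 1, "reduce_on_plateau": True},
    ]
    opt_indices = [1]  # only the second optimizer was used in this interval

    active = [sch for sch in lr_schedulers if sch["opt_idx"] in opt_indices]
    assert [sch["opt_idx"] for sch in active] == [1]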

pytorch_lightning/tuner/lr_finder.py

Lines changed: 7 additions & 2 deletions
@@ -24,7 +24,11 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
+from pytorch_lightning.core.optimizer import (
+    _get_default_scheduler_config,
+    _init_optimizers_and_lr_schedulers,
+    _set_scheduler_opt_idx,
+)
 from pytorch_lightning.loggers.base import DummyLogger
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem

@@ -124,11 +128,12 @@ def func(trainer):
             args = (optimizer, self.lr_max, self.num_training)
             scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
             sched_config = _get_default_scheduler_config()
-            sched_config.update({"scheduler": scheduler, "interval": "step"})
+            sched_config.update({"scheduler": scheduler, "interval": "step", "opt_idx": 0})

             trainer.strategy.optimizers = [optimizer]
             trainer.strategy.lr_schedulers = [sched_config]
             trainer.strategy.optimizer_frequencies = []
+            _set_scheduler_opt_idx(trainer.optimizers, trainer.lr_schedulers)

         return func
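
The LR finder installs exactly one temporary optimizer and scheduler, so its config can hard-code `opt_idx: 0`; the added `_set_scheduler_opt_idx` call simply runs the same consistency check as the regular initialization path. Nothing changes on the user side, and the tuner is still driven the usual way. A rough usage sketch (the directory and `model` are placeholders, not from the commit):

    # Hypothetical call site: `model` is any LightningModule with a single optimizer.
    trainer = Trainer(default_root_dir="lr_find_example", max_epochs=1)
    lr_finder = trainer.tuner.lr_find(model)
    print(lr_finder.suggestion())  # suggested starting learning rate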

tests/trainer/optimization/test_optimizers.py

Lines changed: 20 additions & 2 deletions
@@ -140,7 +140,7 @@ def configure_optimizers(self):
         frequency=1,
         reduce_on_plateau=True,
         strict=True,
-        opt_idx=None,
+        opt_idx=0,
         name=None,
     )

@@ -182,7 +182,7 @@ def test_optimizer_return_options(tmpdir):
         monitor=None,
         strict=True,
         name=None,
-        opt_idx=None,
+        opt_idx=0,
     )

     # opt tuple of 2 lists

@@ -511,6 +511,24 @@ def configure_optimizers(self):
     trainer.fit(model)


+def test_invalid_opt_idx_in_scheduler(tmpdir):
+    """Test exception when incorrect opt_idx is set in lr_scheduler config."""
+
+    class InvalidOptimizerModel(BoringModel):
+        def configure_optimizers(self):
+            opt1 = optim.SGD(self.layer.parameters(), lr=0.1)
+            opt2 = optim.SGD(self.layer.parameters(), lr=0.1)
+            lr_scheduler = {"scheduler": optim.lr_scheduler.StepLR(opt2, step_size=1), "opt_idx": 0}
+            return [opt1, opt2], [lr_scheduler]
+
+    model = InvalidOptimizerModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.raises(
+        MisconfigurationException, match="`opt_idx` .* does not match with the index of the respective optimizer"
+    ):
+        trainer.fit(model)
+
+
 def test_invalid_optimizer_dict_raises(tmpdir):
     """Test exception when lr_scheduler dict has no scheduler."""
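
For contrast with the new negative test, a user-supplied index that does agree with the optimizer's position passes the check in `_set_scheduler_opt_idx`. A minimal counterpart, not part of the commit:

    # Hypothetical valid configuration: opt_idx=1 matches opt2's position in the returned list,
    # so the opt_idx validation accepts it instead of raising.
    class ValidOptimizerModel(BoringModel):
        def configure_optimizers(self):
            opt1 = optim.SGD(self.layer.parameters(), lr=0.1)
            opt2 = optim.SGD(self.layer.parameters(), lr=0.1)
            lr_scheduler = {"scheduler": optim.lr_scheduler.StepLR(opt2, step_size=1), "opt_idx": 1}
            return [opt1, opt2], [lr_scheduler]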
