
Commit ed52823

Authored by mauvilsa, Borda, and carmocca
LightningCLI support for optimizers and schedulers via dependency injection (#15869)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 38acba0 commit ed52823

File tree

5 files changed: +140 -46 lines


docs/source-pytorch/cli/lightning_cli_advanced_3.rst

Lines changed: 63 additions & 33 deletions

@@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
 It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.


-Optimizers
-^^^^^^^^^^
+Fixed optimizer and scheduler
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 In some cases, fixing the optimizer and/or learning scheduler might be desired instead of allowing multiple. For this,
 you can manually add the arguments for specific classes by subclassing the CLI. The following code snippet shows how to
@@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec

     $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2

-The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group. An example
-can be when someone wants to add support for multiple optimizers:
+
+Multiple optimizers and schedulers
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+By default, the CLIs support multiple optimizers and/or learning schedulers, automatically implementing
+``configure_optimizers``. This behavior can be disabled by providing ``auto_configure_optimizers=False`` on
+instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. This would be required, for example, to support multiple
+optimizers, selecting a particular optimizer class for each. Similar to multiple submodules, this can be done via
+`dependency injection <https://en.wikipedia.org/wiki/Dependency_injection>`__. Unlike the submodules, it is not possible
+to expect an instance of a class, because optimizers require the module's parameters to optimize, which are only
+available after instantiation of the module. Learning schedulers are in a similar situation, requiring an optimizer
+instance. For these cases, dependency injection involves providing a function that instantiates the respective class
+when called.
+
+An example of a model that uses two optimizers is the following:

 .. code-block:: python

-    from pytorch_lightning.cli import instantiate_class
+    from typing import Callable, Iterable
+    from torch.optim import Optimizer
+
+
+    OptimizerCallable = Callable[[Iterable], Optimizer]


     class MyModel(LightningModule):
-        def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
+        def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
             super().__init__()
-            self.optimizer1_init = optimizer1_init
-            self.optimizer2_init = optimizer2_init
+            self.optimizer1 = optimizer1
+            self.optimizer2 = optimizer2

         def configure_optimizers(self):
-            optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
-            optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
+            optimizer1 = self.optimizer1(self.parameters())
+            optimizer2 = self.optimizer2(self.parameters())
             return [optimizer1, optimizer2]


-    class MyLightningCLI(LightningCLI):
-        def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
-            parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
+    cli = LightningCLI(MyModel, auto_configure_optimizers=False)

+Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a single argument, some
+learnable parameters, and returns an optimizer instance. With this, from the command line it is possible to select the
+class and init arguments for each of the optimizers, as follows:

-    cli = MyLightningCLI(MyModel)
+.. code-block:: bash

-The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries.
-The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in
-``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the
-``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``.
+    $ python trainer.py fit \
+        --model.optimizer1=Adam \
+        --model.optimizer1.lr=0.01 \
+        --model.optimizer2=AdamW \
+        --model.optimizer2.lr=0.0001

-With shorthand notation:
+In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For
+convenience, this type alias and one for learning schedulers are available in the ``cli`` module. An example of a model
+that uses dependency injection for an optimizer and a learning scheduler is:

-.. code-block:: bash
+.. code-block:: python

-    $ python trainer.py fit \
-        --optimizer1=Adam \
-        --optimizer1.lr=0.01 \
-        --optimizer2=AdamW \
-        --optimizer2.lr=0.0001
+    from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI

-You can also pass the class path directly, for example, if the optimizer hasn't been imported:

-.. code-block:: bash
+    class MyModel(LightningModule):
+        def __init__(
+            self,
+            optimizer: OptimizerCallable = torch.optim.Adam,
+            scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+        ):
+            super().__init__()
+            self.optimizer = optimizer
+            self.scheduler = scheduler

-    $ python trainer.py fit \
-        --optimizer1=torch.optim.Adam \
-        --optimizer1.lr=0.01 \
-        --optimizer2=torch.optim.AdamW \
-        --optimizer2.lr=0.0001
+        def configure_optimizers(self):
+            optimizer = self.optimizer(self.parameters())
+            scheduler = self.scheduler(optimizer)
+            return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+
+    cli = LightningCLI(MyModel, auto_configure_optimizers=False)
+
+Note that for this example, classes are used as defaults. This is compatible with the type hints, since they are also
+callables that receive the same first argument and return an instance of the class. Classes that have more than one
+required argument will not work as a default. For these cases a lambda function can be used, e.g. ``optimizer:
+OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``.


 Run from Python
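
The lambda-default note at the end of this added docs section is easy to miss, so here is a minimal, self-contained sketch of that pattern. It is an illustration, not part of the commit: the SketchModel name and the SGD/StepLR hyperparameters are made up, and BoringModel is Lightning's demo model (also used by the test file in this commit).

# Minimal sketch (not from this commit): lambda defaults for optimizer and
# scheduler classes whose constructors need extra required arguments.
import torch
from pytorch_lightning.cli import LightningCLI, LRSchedulerCallable, OptimizerCallable
from pytorch_lightning.demos.boring_classes import BoringModel


class SketchModel(BoringModel):
    def __init__(
        self,
        optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
        scheduler: LRSchedulerCallable = lambda o: torch.optim.lr_scheduler.StepLR(o, step_size=10),
    ):
        super().__init__()
        self.optimizer = optimizer
        self.scheduler = scheduler

    def configure_optimizers(self):
        # The injected callables run only here, once the module's parameters
        # (and then the optimizer) actually exist.
        optimizer = self.optimizer(self.parameters())
        scheduler = self.scheduler(optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


if __name__ == "__main__":
    LightningCLI(SketchModel, auto_configure_optimizers=False)

The lambda defaults can still be overridden from the command line in the same way shown in the diff above, for example with --model.optimizer=Adam --model.optimizer.lr=0.001.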

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion

@@ -5,5 +5,5 @@
 matplotlib>3.1, <3.6.2
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
-jsonargparse[signatures]>=4.17.0, <4.18.0
+jsonargparse[signatures]>=4.18.0, <4.19.0
 rich>=10.14.0, !=10.15.0.a, <13.0.0
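
The callable-type CLI support above depends on this jsonargparse bump. As a small aside (an illustration, not part of the commit), the RequirementCache helper that cli.py already imports can check the new pin at runtime; the variable name below is hypothetical.

# Illustrative only: verify the bumped jsonargparse requirement is satisfied.
from lightning_utilities.core.imports import RequirementCache

_JSONARGPARSE_OK = RequirementCache("jsonargparse[signatures]>=4.18.0")
print(bool(_JSONARGPARSE_OK))  # True if the installed version satisfies the pin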

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -30,8 +30,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))


-- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
+- Added `LightningCLI` support for optimizer and learning schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))
+

+- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))

 - Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))


src/pytorch_lightning/cli.py

Lines changed: 21 additions & 11 deletions

@@ -15,7 +15,7 @@
 import sys
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -24,6 +24,7 @@

 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -49,19 +50,22 @@
     locals()["Namespace"] = object


-ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
-
-
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
         super().__init__(optimizer, *args, **kwargs)
         self.monitor = monitor


 # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
-LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
+LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]
+
+
+# Type aliases intended for convenience of CLI developers
+ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
+OptimizerCallable = Callable[[Iterable], Optimizer]
+LRSchedulerCallable = Callable[[Optimizer], Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]]


 class LightningArgumentParser(ArgumentParser):
@@ -274,6 +278,7 @@ def __init__(
         subclass_mode_data: bool = False,
         args: ArgsType = None,
         run: bool = True,
+        auto_configure_optimizers: bool = True,
         auto_registry: bool = False,
         **kwargs: Any,  # Remove with deprecations of v1.10
     ) -> None:
@@ -326,6 +331,7 @@ def __init__(
         self.trainer_defaults = trainer_defaults or {}
         self.seed_everything_default = seed_everything_default
         self.parser_kwargs = parser_kwargs or {}  # type: ignore[var-annotated]  # github.com/python/mypy/issues/6463
+        self.auto_configure_optimizers = auto_configure_optimizers

         self._handle_deprecated_params(kwargs)

@@ -447,10 +453,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
         self.add_core_arguments_to_parser(parser)
         self.add_arguments_to_parser(parser)
         # add default optimizer args if necessary
-        if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
-            parser.add_optimizer_args((Optimizer,))
-        if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
-            parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
+        if self.auto_configure_optimizers:
+            if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
+                parser.add_optimizer_args((Optimizer,))
+            if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
+                parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
         self.link_optimizers_and_lr_schedulers(parser)

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
@@ -602,6 +609,9 @@ def configure_optimizers(
     def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
         """Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
         if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
+        if not self.auto_configure_optimizers:
+            return
+
         parser = self._parser(subcommand)

         def get_automatic(
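
To make the effect of the new auto_configure_optimizers flag concrete, here is a brief sketch (an illustration, not code from this commit) contrasting the default behavior, where _add_arguments still adds the generic --optimizer/--lr_scheduler groups, with the opt-out used by the docs and by the test below. It uses Lightning's BoringModel demo and passes args programmatically, which the ArgsType alias permits.

# Illustration only: how the new flag changes what the parser receives.
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.demos.boring_classes import BoringModel

# Default (auto_configure_optimizers=True): the parser still gains the generic
# --optimizer and --lr_scheduler groups, and configure_optimizers may be
# auto-implemented when a single optimizer is selected.
cli_default = LightningCLI(BoringModel, run=False, args=[])

# Opt-out: no --optimizer/--lr_scheduler groups are added, and the model's own
# configure_optimizers is left untouched.
cli_manual = LightningCLI(BoringModel, run=False, args=[], auto_configure_optimizers=False)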

tests/tests_pytorch/test_cli.py

Lines changed: 52 additions & 0 deletions

@@ -36,7 +36,9 @@
     instantiate_class,
     LightningArgumentParser,
     LightningCLI,
+    LRSchedulerCallable,
     LRSchedulerTypeTuple,
+    OptimizerCallable,
     SaveConfigCallback,
 )
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
@@ -706,6 +708,56 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
     assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


+def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
+    class TestModel(BoringModel):
+        def __init__(
+            self,
+            optim1: OptimizerCallable = torch.optim.Adam,
+            optim2: OptimizerCallable = torch.optim.Adagrad,
+            scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+        ):
+            super().__init__()
+            self.optim1 = optim1
+            self.optim2 = optim2
+            self.scheduler = scheduler
+
+        def configure_optimizers(self):
+            optim1 = self.optim1(self.parameters())
+            optim2 = self.optim2(self.parameters())
+            scheduler = self.scheduler(optim2)
+            return (
+                {"optimizer": optim1},
+                {"optimizer": optim2, "lr_scheduler": scheduler},
+            )
+
+    out = StringIO()
+    with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
+        LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
+    out = out.getvalue()
+    assert "--optimizer" not in out
+    assert "--lr_scheduler" not in out
+    assert "--model.optim1" in out
+    assert "--model.optim2" in out
+    assert "--model.scheduler" in out
+
+    cli_args = [
+        "--model.optim1=Adagrad",
+        "--model.optim2=SGD",
+        "--model.optim2.lr=0.007",
+        "--model.scheduler=ExponentialLR",
+        "--model.scheduler.gamma=0.3",
+    ]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
+
+    init = cli.model.configure_optimizers()
+    assert isinstance(init[0]["optimizer"], torch.optim.Adagrad)
+    assert isinstance(init[1]["optimizer"], torch.optim.SGD)
+    assert isinstance(init[1]["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR)
+    assert init[1]["optimizer"].param_groups[0]["lr"] == 0.007
+    assert init[1]["lr_scheduler"].gamma == 0.3
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
