@@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.
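For illustration, a configuration with two levels of ``class_path`` might look roughly like the following sketch (the submodule key ``encoder``, the class paths and the init argument are hypothetical):

.. code-block:: yaml

    model:
      class_path: mycode.MyMainModel
      init_args:
        encoder:
          class_path: mycode.MyEncoder
          init_args:
            hidden_features: 64

Here the outer ``class_path`` selects the model class and the inner one selects the submodule's class, each with its own ``init_args``.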
- Optimizers
- ^^^^^^^^^^
+ Fixed optimizer and scheduler
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In some cases, fixing the optimizer and/or learning scheduler might be desired instead of allowing multiple. For this,
you can manually add the arguments for specific classes by subclassing the CLI. The following code snippet shows how to
@@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec
    $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
- The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group. An example
- can be when someone wants to add support for multiple optimizers:
+
+ Multiple optimizers and schedulers
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ By default, the CLIs support multiple optimizers and/or learning schedulers, automatically implementing
+ ``configure_optimizers``. This behavior can be disabled by providing ``auto_configure_optimizers=False`` on
+ instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. This would be required, for example, to support multiple
+ optimizers, selecting a particular optimizer class for each. Similar to multiple submodules, this can be done via
+ `dependency injection <https://en.wikipedia.org/wiki/Dependency_injection>`__. Unlike the submodules, it is not possible
+ to expect an instance of a class, because optimizers require the module's parameters to optimize, which are only
+ available after instantiation of the module. Learning schedulers are in a similar situation, requiring an optimizer
+ instance. For these cases, dependency injection involves providing a function that instantiates the respective class
+ when called.
+
+ An example of a model that uses two optimizers is the following:
.. code-block:: python
-     from pytorch_lightning.cli import instantiate_class
+     from typing import Callable, Iterable
+     from torch.optim import Optimizer
+
+
+     OptimizerCallable = Callable[[Iterable], Optimizer]


    class MyModel(LightningModule):
-         def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
+         def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
            super().__init__()
-             self.optimizer1_init = optimizer1_init
-             self.optimizer2_init = optimizer2_init
+             self.optimizer1 = optimizer1
+             self.optimizer2 = optimizer2

        def configure_optimizers(self):
-             optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
-             optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
+             optimizer1 = self.optimizer1(self.parameters())
+             optimizer2 = self.optimizer2(self.parameters())
            return [optimizer1, optimizer2]


-     class MyLightningCLI(LightningCLI):
-         def add_arguments_to_parser(self, parser):
-             parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
-             parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
+     cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
+ Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a single argument, some
+ learnable parameters, and returns an optimizer instance. With this, from the command line it is possible to select the
+ class and init arguments for each of the optimizers, as follows:
-     cli = MyLightningCLI(MyModel)
+ .. code-block:: bash
- The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries.
- The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in
- ``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the
- ``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``.
+     $ python trainer.py fit \
+         --model.optimizer1=Adam \
+         --model.optimizer1.lr=0.01 \
+         --model.optimizer2=AdamW \
+         --model.optimizer2.lr=0.0001
- With shorthand notation:
+ In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For
+ convenience, this type alias and one for learning schedulers are available in the ``cli`` module. An example of a model
+ that uses dependency injection for an optimizer and a learning scheduler is:
- .. code-block:: bash
+ .. code-block:: python
-     $ python trainer.py fit \
-         --optimizer1=Adam \
-         --optimizer1.lr=0.01 \
-         --optimizer2=AdamW \
-         --optimizer2.lr=0.0001
+     from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI
- You can also pass the class path directly, for example, if the optimizer hasn't been imported:
- .. code-block:: bash
+     class MyModel(LightningModule):
+         def __init__(
+             self,
+             optimizer: OptimizerCallable = torch.optim.Adam,
+             scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+         ):
+             super().__init__()
+             self.optimizer = optimizer
+             self.scheduler = scheduler

-     $ python trainer.py fit \
-         --optimizer1=torch.optim.Adam \
-         --optimizer1.lr=0.01 \
-         --optimizer2=torch.optim.AdamW \
-         --optimizer2.lr=0.0001
+         def configure_optimizers(self):
+             optimizer = self.optimizer(self.parameters())
+             scheduler = self.scheduler(optimizer)
+             return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+
+     cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
+
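+ As a rough usage sketch (following the same command line pattern as the multi-optimizer example above; the optimizer
+ and scheduler classes and values shown here are only illustrative), these defaults could be overridden with:
+
+ .. code-block:: bash
+
+     $ python trainer.py fit \
+         --model.optimizer=SGD \
+         --model.optimizer.lr=0.01 \
+         --model.scheduler=ExponentialLR \
+         --model.scheduler.gamma=0.9
+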
+ Note that for this example, classes are used as defaults. This is compatible with the type hints, since they are also
+ callables that receive the same first argument and return an instance of the class. Classes that have more than one
+ required argument will not work as a default. For these cases a lambda function can be used, e.g. ``optimizer:
+ OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``.
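+
+ For illustration only, such a lambda default might be written roughly as follows (the choice of SGD and the learning
+ rate value are arbitrary):
+
+ .. code-block:: python
+
+     class MyModel(LightningModule):
+         def __init__(
+             self,
+             # lambda default: builds an SGD optimizer from the module's parameters when called
+             optimizer: OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01),
+         ):
+             super().__init__()
+             self.optimizer = optimizer
+
+         def configure_optimizers(self):
+             # the injected callable receives the parameters and returns the optimizer instance
+             return self.optimizer(self.parameters())
+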
Run from Python