Merged

30 commits
0088397
main removals
awaelchli Feb 9, 2023
fda5eeb
docs
awaelchli Feb 9, 2023
63d64ee
update example
awaelchli Feb 9, 2023
36f29b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2023
d437782
fix mypy
awaelchli Feb 10, 2023
879722f
removal
awaelchli Feb 10, 2023
2ea152f
remove
awaelchli Feb 10, 2023
588f669
removal
awaelchli Feb 10, 2023
dde285f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2023
496c2b0
changelog
awaelchli Feb 10, 2023
4e27b0e
Merge branch 'master' into removal/argparse
awaelchli Feb 11, 2023
78c386b
merge
awaelchli Feb 11, 2023
21b9d19
missing imports
awaelchli Feb 11, 2023
4b9e66f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2023
3d11a05
blabla
awaelchli Feb 11, 2023
d63acac
Merge remote-tracking branch 'origin/removal/argparse' into removal/a…
awaelchli Feb 11, 2023
d2a526d
x
awaelchli Feb 11, 2023
6c38857
Merge branch 'master' into removal/argparse
carmocca Feb 12, 2023
eea51a6
Update docs/source-pytorch/common/trainer.rst
awaelchli Feb 13, 2023
151b18e
Update docs/source-pytorch/common/trainer.rst
awaelchli Feb 13, 2023
2940480
collapse model-specific args
awaelchli Feb 13, 2023
5d979c4
remove _ARGPARSE_CLS
awaelchli Feb 13, 2023
0625c77
remove more parsing utils
awaelchli Feb 13, 2023
4a7c34a
simple argparse docs
awaelchli Feb 13, 2023
4956e62
restore
awaelchli Feb 13, 2023
170c9de
Merge branch 'master' into removal/argparse
Borda Feb 13, 2023
3c5e4d8
Merge branch 'master' into removal/argparse
Borda Feb 13, 2023
6a9c61d
no need for doctest
awaelchli Feb 13, 2023
8caa93b
fix reference
awaelchli Feb 13, 2023
93d2aa2
Merge branch 'master' into removal/argparse
awaelchli Feb 13, 2023
1 change: 0 additions & 1 deletion docs/source-pytorch/api_references.rst
@@ -250,7 +250,6 @@ utilities
:toctree: api
:nosignatures:

argparse
data
deepspeed
distributed
7 changes: 0 additions & 7 deletions docs/source-pytorch/cli/lightning_cli.rst
@@ -115,13 +115,6 @@ Miscellaneous
:button_link: lightning_cli_faq.html
:height: 150

.. displayitem::
:header: Legacy CLIs
:description: Documentation for the legacy argparse-based CLIs
:col_css: col-md-6
:button_link: ../common/hyperparameters.html
:height: 150

.. raw:: html

</div>
4 changes: 2 additions & 2 deletions docs/source-pytorch/common/checkpointing_basic.rst
@@ -35,8 +35,8 @@ Inside a Lightning checkpoint you'll find:
- State of all learning rate schedulers
- State of all callbacks (for stateful callbacks)
- State of datamodule (for stateful datamodules)
- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
- The hyperparameters used for that datamodule if passed in as hparams (Argparse.Namespace)
- The hyperparameters (init arguments) with which the model was created (see the sketch after this list)
- The hyperparameters (init arguments) with which the datamodule was created
- State of Loops
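
A minimal sketch (class and argument names are illustrative) of how those init arguments typically get captured, by calling ``save_hyperparameters()`` in the module's ``__init__``:

.. code-block:: python

    import torch
    from lightning.pytorch import LightningModule


    class LitModel(LightningModule):
        def __init__(self, layer_1_dim: int = 128, learning_rate: float = 1e-3):
            super().__init__()
            # Record the init arguments so they end up in the checkpoint
            # and are restored by `LitModel.load_from_checkpoint(...)`
            self.save_hyperparameters()
            self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)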

----
205 changes: 24 additions & 181 deletions docs/source-pytorch/common/hyperparameters.rst
@@ -1,209 +1,52 @@
:orphan:

.. testsetup:: *
Configure hyperparameters from the CLI
--------------------------------------

from argparse import ArgumentParser, Namespace
You can use any CLI tool you want with Lightning.
For beginners, we recommend using Python's built-in argument parser.

sys.argv = ["foo"]

Configure hyperparameters from the CLI (legacy)
-----------------------------------------------
----

.. warning::

This is the documentation for the use of Python's ``argparse`` to implement a CLI. This approach is no longer
recommended, and people are encouraged to use the new `LightningCLI <../cli/lightning_cli.html>`_ class instead.


Lightning has utilities to interact seamlessly with the command line ``ArgumentParser``
and plays well with the hyperparameter optimization framework of your choice.

----------

ArgumentParser
^^^^^^^^^^^^^^
Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser

.. testcode::

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--layer_1_dim", type=int, default=128)
args = parser.parse_args()

This allows you to call your program like so:

.. code-block:: bash

python trainer.py --layer_1_dim 64

----------

Argparser Best Practices
^^^^^^^^^^^^^^^^^^^^^^^^
It is best practice to layer your arguments in three sections.
The :class:`~argparse.ArgumentParser` is a built-in feature in Python that lets you build CLI programs.
You can use it to make hyperparameters and other training settings available from the command line:

1. Trainer args (``accelerator``, ``devices``, ``num_nodes``, etc...)
2. Model specific arguments (``layer_dim``, ``num_layers``, ``learning_rate``, etc...)
3. Program arguments (``data_path``, ``cluster_email``, etc...)

|

We can do this as follows. First, in your ``LightningModule``, define the arguments
specific to that module. Remember that data splits or data paths may also be specific to
a module (i.e.: if your project has a model that trains on Imagenet and another on CIFAR-10).

.. testcode::

class LitModel(LightningModule):
@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitModel")
parser.add_argument("--encoder_layers", type=int, default=12)
parser.add_argument("--data_path", type=str, default="/some/path")
return parent_parser

Now in your main trainer file, add the ``Trainer`` args, the program args, and add the model args

.. testcode::
.. code-block:: python

# ----------------
# trainer_main.py
# ----------------
from argparse import ArgumentParser

parser = ArgumentParser()

# add PROGRAM level args
parser.add_argument("--conda_env", type=str, default="some_name")
parser.add_argument("--notification_email", type=str, default="[email protected]")

# add model specific args
parser = LitModel.add_model_specific_args(parser)
# Trainer arguments
parser.add_argument("--devices", type=int, default=2)

# add all the available trainer options to argparse
# ie: now --accelerator --devices --num_nodes ... --fast_dev_run all work in the cli
parser = Trainer.add_argparse_args(parser)
# Hyperparameters for the model
parser.add_argument("--layer_1_dim", type=int, default=128)

# Parse the user inputs and defaults (returns an argparse.Namespace)
args = parser.parse_args()

Now you can run your program like so:

.. code-block:: bash

python trainer_main.py --accelerator 'gpu' --devices 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
# Use the parsed arguments in your program
trainer = Trainer(devices=args.devices)
model = MyModel(layer_1_dim=args.layer_1_dim)

Finally, make sure to start the training like so:

.. code-block:: python

# init the trainer like this
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
This allows you to call your program like so:

# NOT like this
trainer = Trainer(accelerator=hparams.accelerator, devices=hparams.devices, ...)
.. code-block:: bash

# init the model with Namespace directly
model = LitModel(args)
python trainer.py --layer_1_dim 64 --devices 1

# or init the model with all the key-value pairs
dict_args = vars(args)
model = LitModel(**dict_args)
----

----------

Trainer args
LightningCLI
^^^^^^^^^^^^
To recap, add ALL possible trainer flags to the argparser and init the ``Trainer`` this way

.. code-block:: python

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
hparams = parser.parse_args()

trainer = Trainer.from_argparse_args(hparams)

# or if you need to pass in callbacks
trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=..., callbacks=[...])

----------

Multiple Lightning Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^

We often have multiple Lightning Modules where each one has different arguments. Instead of
polluting the ``main.py`` file, the ``LightningModule`` lets you define arguments for each one.

.. testcode::

class LitMNIST(LightningModule):
def __init__(self, layer_1_dim, **kwargs):
super().__init__()
self.layer_1 = nn.Linear(28 * 28, layer_1_dim)

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitMNIST")
parser.add_argument("--layer_1_dim", type=int, default=128)
return parent_parser

.. testcode::

class GoodGAN(LightningModule):
def __init__(self, encoder_layers, **kwargs):
super().__init__()
self.encoder = Encoder(layers=encoder_layers)

@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("GoodGAN")
parser.add_argument("--encoder_layers", type=int, default=12)
return parent_parser


Now we can allow each model to inject the arguments it needs in the ``main.py``

.. code-block:: python

def main(args):
dict_args = vars(args)

# pick model
if args.model_name == "gan":
model = GoodGAN(**dict_args)
elif args.model_name == "mnist":
model = LitMNIST(**dict_args)

trainer = Trainer.from_argparse_args(args)
trainer.fit(model)


if __name__ == "__main__":
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)

# figure out which model to use
parser.add_argument("--model_name", type=str, default="gan", help="gan or mnist")

# THIS LINE IS KEY TO PULL THE MODEL NAME
temp_args, _ = parser.parse_known_args()

# let the model add what it wants
if temp_args.model_name == "gan":
parser = GoodGAN.add_model_specific_args(parser)
elif temp_args.model_name == "mnist":
parser = LitMNIST.add_model_specific_args(parser)

args = parser.parse_args()

# train
main(args)

and now we can train MNIST or the GAN using the command line interface!

.. code-block:: bash

$ python main.py --model_name gan --encoder_layers 24
$ python main.py --model_name mnist --layer_1_dim 128
Python's argument parser works well for simple use cases, but it becomes cumbersome to maintain in larger projects.
For example, every time you add, change, or delete an argument from your model, you have to add, edit, or remove the corresponding ``parser.add_argument`` code.
The :doc:`Lightning CLI <../cli/lightning_cli>` integrates seamlessly with the Trainer and LightningModule and generates the CLI arguments for you automatically!
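
As a rough illustration of that workflow, here is a minimal sketch (the model and datamodule classes are placeholders) in which all CLI arguments are derived from the ``__init__`` signatures:

.. code-block:: python

    # main.py -- hypothetical sketch
    from lightning.pytorch.cli import LightningCLI

    from my_project import MyDataModule, MyModel  # placeholder imports


    def cli_main():
        # Exposes --model.*, --data.*, and --trainer.* flags automatically,
        # without any manual parser.add_argument calls
        LightningCLI(MyModel, MyDataModule)


    if __name__ == "__main__":
        cli_main()

.. code-block:: bash

    python main.py fit --model.layer_1_dim 64 --trainer.devices 2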
36 changes: 6 additions & 30 deletions docs/source-pytorch/common/trainer.rst
@@ -118,38 +118,14 @@ So you can run it like so:

.. note::

Pro-tip: You don't need to define all flags manually. Lightning can add them automatically
Pro-tip: You don't need to define all flags manually.
You can let the :doc:`LightningCLI <../cli/lightning_cli>` create the Trainer and model with arguments supplied from the CLI.
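
For instance, assuming a hypothetical ``main.py`` that instantiates ``LightningCLI`` with your model, every Trainer argument becomes a ``--trainer.*`` option and can also be supplied through a YAML config:

.. code-block:: bash

    python main.py fit --trainer.accelerator gpu --trainer.devices 2 --trainer.max_steps 10
    # or collect the same settings in a config file
    python main.py fit --config config.yaml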

.. code-block:: python

from argparse import ArgumentParser


def main(args):
model = LightningModule()
trainer = Trainer.from_argparse_args(args)
trainer.fit(model)


if __name__ == "__main__":
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

main(args)

So you can run it like so:

.. code-block:: bash

python main.py --accelerator 'gpu' --devices 2 --max_steps 10 --limit_train_batches 10 --any_trainer_arg x

.. note::
If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown, including
running accelerator callback ``on_train_end`` to clean up memory. The trainer object will also set
an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs.
If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown. The trainer object will also set
an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs by overriding :meth:`lightning.pytorch.Callback.on_exception`.
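
A minimal sketch of such a callback (the teardown function is a placeholder for project-specific logic):

.. code-block:: python

    from lightning.pytorch.callbacks import Callback


    def release_compute_resources():
        print("releasing compute resources")  # placeholder for real teardown logic


    class CleanupCallback(Callback):
        def on_exception(self, trainer, pl_module, exception):
            # Called when the Trainer catches an exception, including KeyboardInterrupt (Ctrl+C)
            release_compute_resources()

        def on_train_end(self, trainer, pl_module):
            # Tear down here only for runs that were not interrupted
            if not trainer.interrupted:
                release_compute_resources()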

------------

32 changes: 7 additions & 25 deletions examples/pl_domain_templates/generative_adversarial_net.py
@@ -127,20 +127,6 @@ def __init__(

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

@staticmethod
def add_argparse_args(parent_parser: ArgumentParser, *, use_argument_group=True):
if use_argument_group:
parser = parent_parser.add_argument_group("GAN")
parser_out = parent_parser
else:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser_out = parser
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
return parser_out

def forward(self, z):
return self.generator(z)

@@ -226,8 +212,8 @@ def main(args: Namespace) -> None:
# ------------------------
# If using distributed training, PyTorch recommends DistributedDataParallel.
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
dm = MNISTDataModule.from_argparse_args(args)
trainer = Trainer.from_argparse_args(args)
dm = MNISTDataModule()
trainer = Trainer(accelerator="gpu", devices=1)

# ------------------------
# 3 START TRAINING
@@ -239,15 +225,11 @@ def main(args: Namespace) -> None:
cli_lightning_logo()
parser = ArgumentParser()

# Add program level args, if any.
# ------------------------
# Add LightningDataLoader args
parser = MNISTDataModule.add_argparse_args(parser)
# Add model specific args
parser = GAN.add_argparse_args(parser)
# Add trainer args
parser = Trainer.add_argparse_args(parser)
# Parse all arguments
# Hyperparameters
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
args = parser.parse_args()

main(args)