diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index 9ed881dc8ed85..c30f4c2019425 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -250,7 +250,6 @@ utilities
     :toctree: api
     :nosignatures:
 
-    argparse
     data
     deepspeed
     distributed
diff --git a/docs/source-pytorch/cli/lightning_cli.rst b/docs/source-pytorch/cli/lightning_cli.rst
index 5b0bf73754db3..e3220a6f0ed7d 100644
--- a/docs/source-pytorch/cli/lightning_cli.rst
+++ b/docs/source-pytorch/cli/lightning_cli.rst
@@ -115,13 +115,6 @@ Miscellaneous
     :button_link: lightning_cli_faq.html
     :height: 150
 
-.. displayitem::
-    :header: Legacy CLIs
-    :description: Documentation for the legacy argparse-based CLIs
-    :col_css: col-md-6
-    :button_link: ../common/hyperparameters.html
-    :height: 150
-
 .. raw:: html
 
diff --git a/docs/source-pytorch/common/checkpointing_basic.rst b/docs/source-pytorch/common/checkpointing_basic.rst
index 1be425de391ca..ab45d93c35e46 100644
--- a/docs/source-pytorch/common/checkpointing_basic.rst
+++ b/docs/source-pytorch/common/checkpointing_basic.rst
@@ -35,8 +35,8 @@ Inside a Lightning checkpoint you'll find:
 - State of all learning rate schedulers
 - State of all callbacks (for stateful callbacks)
 - State of datamodule (for stateful datamodules)
-- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
-- The hyperparameters used for that datamodule if passed in as hparams (Argparse.Namespace)
+- The hyperparameters (init arguments) with which the model was created
+- The hyperparameters (init arguments) with which the datamodule was created
 - State of Loops
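+
+For example, the init arguments end up in the checkpoint when the model calls
+``self.save_hyperparameters()`` in its ``__init__``. A minimal sketch, with ``layer_1_dim`` as a
+hypothetical init argument:
+
+.. code-block:: python
+
+    class LitModel(LightningModule):
+        def __init__(self, layer_1_dim=128):
+            super().__init__()
+            # stores layer_1_dim under self.hparams and writes it into every checkpoint
+            self.save_hyperparameters()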
 
 ----
diff --git a/docs/source-pytorch/common/hyperparameters.rst b/docs/source-pytorch/common/hyperparameters.rst
index b5d9b509a8208..ce356277bd054 100644
--- a/docs/source-pytorch/common/hyperparameters.rst
+++ b/docs/source-pytorch/common/hyperparameters.rst
@@ -1,209 +1,52 @@
 :orphan:
 
-.. testsetup:: *
+Configure hyperparameters from the CLI
+--------------------------------------
 
-    from argparse import ArgumentParser, Namespace
+You can use any CLI tool you want with Lightning.
+For beginners, we recommend using Python's built-in argument parser.
 
-    sys.argv = ["foo"]
 
-Configure hyperparameters from the CLI (legacy)
------------------------------------------------
+----
 
-.. warning::
-
-    This is the documentation for the use of Python's ``argparse`` to implement a CLI. This approach is no longer
-    recommended, and people are encouraged to use the new `LightningCLI <../cli/lightning_cli.html>`_ class instead.
-
-
-Lightning has utilities to interact seamlessly with the command line ``ArgumentParser``
-and plays well with the hyperparameter optimization framework of your choice.
-
-----------
 
 ArgumentParser
 ^^^^^^^^^^^^^^
-Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser
 
-.. testcode::
-
-    from argparse import ArgumentParser
-
-    parser = ArgumentParser()
-    parser.add_argument("--layer_1_dim", type=int, default=128)
-    args = parser.parse_args()
-
-This allows you to call your program like so:
-
-.. code-block:: bash
-
-    python trainer.py --layer_1_dim 64
-
-----------
-
-Argparser Best Practices
-^^^^^^^^^^^^^^^^^^^^^^^^
-It is best practice to layer your arguments in three sections.
+The :class:`~argparse.ArgumentParser` is a built-in feature in Python that lets you build CLI programs.
+You can use it to make hyperparameters and other training settings available from the command line:
 
-1. Trainer args (``accelerator``, ``devices``, ``num_nodes``, etc...)
-2. Model specific arguments (``layer_dim``, ``num_layers``, ``learning_rate``, etc...)
-3. Program arguments (``data_path``, ``cluster_email``, etc...)
-
-|
-
-We can do this as follows. First, in your ``LightningModule``, define the arguments
-specific to that module. Remember that data splits or data paths may also be specific to
-a module (i.e.: if your project has a model that trains on Imagenet and another on CIFAR-10).
-
-.. testcode::
-
-    class LitModel(LightningModule):
-        @staticmethod
-        def add_model_specific_args(parent_parser):
-            parser = parent_parser.add_argument_group("LitModel")
-            parser.add_argument("--encoder_layers", type=int, default=12)
-            parser.add_argument("--data_path", type=str, default="/some/path")
-            return parent_parser
-
-Now in your main trainer file, add the ``Trainer`` args, the program args, and add the model args
-
-.. testcode::
+.. code-block:: python
 
-    # ----------------
-    # trainer_main.py
-    # ----------------
     from argparse import ArgumentParser
 
     parser = ArgumentParser()
 
-    # add PROGRAM level args
-    parser.add_argument("--conda_env", type=str, default="some_name")
-    parser.add_argument("--notification_email", type=str, default="will@email.com")
-
-    # add model specific args
-    parser = LitModel.add_model_specific_args(parser)
+    # Trainer arguments
+    parser.add_argument("--devices", type=int, default=2)
 
-    # add all the available trainer options to argparse
-    # ie: now --accelerator --devices --num_nodes ... --fast_dev_run all work in the cli
-    parser = Trainer.add_argparse_args(parser)
+    # Hyperparameters for the model
+    parser.add_argument("--layer_1_dim", type=int, default=128)
 
+    # Parse the user inputs and defaults (returns an argparse.Namespace)
     args = parser.parse_args()
 
-Now you can call run your program like so:
-
-.. code-block:: bash
-
-    python trainer_main.py --accelerator 'gpu' --devices 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
+    # Use the parsed arguments in your program
+    trainer = Trainer(devices=args.devices)
+    model = MyModel(layer_1_dim=args.layer_1_dim)
 
-Finally, make sure to start the training like so:
-
-.. code-block:: python
-
-    # init the trainer like this
-    trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
+This allows you to call your program like so:
 
-    # NOT like this
-    trainer = Trainer(accelerator=hparams.accelerator, devices=hparams.devices, ...)
+.. code-block:: bash
 
-    # init the model with Namespace directly
-    model = LitModel(args)
+    python trainer.py --layer_1_dim 64 --devices 1
 
-    # or init the model with all the key-value pairs
-    dict_args = vars(args)
-    model = LitModel(**dict_args)
+----
 
-----------
-
-Trainer args
+LightningCLI
 ^^^^^^^^^^^^
 
-To recap, add ALL possible trainer flags to the argparser and init the ``Trainer`` this way
-
-.. code-block:: python
-
-    parser = ArgumentParser()
-    parser = Trainer.add_argparse_args(parser)
-    hparams = parser.parse_args()
-
-    trainer = Trainer.from_argparse_args(hparams)
-
-    # or if you need to pass in callbacks
-    trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=..., callbacks=[...])
-
-----------
-
-Multiple Lightning Modules
-^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We often have multiple Lightning Modules where each one has different arguments. Instead of
-polluting the ``main.py`` file, the ``LightningModule`` lets you define arguments for each one.
-
-.. testcode::
-
-    class LitMNIST(LightningModule):
-        def __init__(self, layer_1_dim, **kwargs):
-            super().__init__()
-            self.layer_1 = nn.Linear(28 * 28, layer_1_dim)
-
-        @staticmethod
-        def add_model_specific_args(parent_parser):
-            parser = parent_parser.add_argument_group("LitMNIST")
-            parser.add_argument("--layer_1_dim", type=int, default=128)
-            return parent_parser
-
-.. testcode::
-
-    class GoodGAN(LightningModule):
-        def __init__(self, encoder_layers, **kwargs):
-            super().__init__()
-            self.encoder = Encoder(layers=encoder_layers)
-
-        @staticmethod
-        def add_model_specific_args(parent_parser):
-            parser = parent_parser.add_argument_group("GoodGAN")
-            parser.add_argument("--encoder_layers", type=int, default=12)
-            return parent_parser
-
-
-Now we can allow each model to inject the arguments it needs in the ``main.py``
-
-.. code-block:: python
-
-    def main(args):
-        dict_args = vars(args)
-
-        # pick model
-        if args.model_name == "gan":
-            model = GoodGAN(**dict_args)
-        elif args.model_name == "mnist":
-            model = LitMNIST(**dict_args)
-
-        trainer = Trainer.from_argparse_args(args)
-        trainer.fit(model)
-
-
-    if __name__ == "__main__":
-        parser = ArgumentParser()
-        parser = Trainer.add_argparse_args(parser)
-
-        # figure out which model to use
-        parser.add_argument("--model_name", type=str, default="gan", help="gan or mnist")
-
-        # THIS LINE IS KEY TO PULL THE MODEL NAME
-        temp_args, _ = parser.parse_known_args()
-
-        # let the model add what it wants
-        if temp_args.model_name == "gan":
-            parser = GoodGAN.add_model_specific_args(parser)
-        elif temp_args.model_name == "mnist":
-            parser = LitMNIST.add_model_specific_args(parser)
-
-        args = parser.parse_args()
-
-        # train
-        main(args)
-
-and now we can train MNIST or the GAN using the command line interface!
-
-.. code-block:: bash
-
-    $ python main.py --model_name gan --encoder_layers 24
-    $ python main.py --model_name mnist --layer_1_dim 128
+Python's argument parser works well for simple use cases, but it can become cumbersome to maintain for larger projects.
+For example, every time you add, change, or delete an argument from your model, you have to add, edit, or remove the corresponding ``parser.add_argument`` code.
+The :doc:`Lightning CLI <../cli/lightning_cli>` integrates seamlessly with the Trainer and LightningModule and generates the CLI arguments for you automatically.
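+
+As a minimal sketch (assuming ``MyModel`` is your ``LightningModule`` and ``MyDataModule`` is your
+``LightningDataModule``), a full CLI can be generated with:
+
+.. code-block:: python
+
+    from lightning.pytorch.cli import LightningCLI
+
+    # init arguments of the Trainer, model, and datamodule become CLI options automatically
+    cli = LightningCLI(MyModel, MyDataModule)
+
+.. code-block:: bash
+
+    python main.py fit --model.layer_1_dim 64 --trainer.devices 1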
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index da5864f40043b..233d947919d9f 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -118,38 +118,14 @@ So you can run it like so:
 
 .. note::
 
-    Pro-tip: You don't need to define all flags manually. Lightning can add them automatically
+    Pro-tip: You don't need to define all flags manually.
+    You can let the :doc:`LightningCLI <../cli/lightning_cli>` create the Trainer and model with arguments supplied from the CLI.
 
-.. code-block:: python
-
-    from argparse import ArgumentParser
-
-
-    def main(args):
-        model = LightningModule()
-        trainer = Trainer.from_argparse_args(args)
-        trainer.fit(model)
-
-
-    if __name__ == "__main__":
-        parser = ArgumentParser()
-        parser = Trainer.add_argparse_args(parser)
-        args = parser.parse_args()
-
-        main(args)
-
-So you can run it like so:
-
-.. code-block:: bash
-
-    python main.py --accelerator 'gpu' --devices 2 --max_steps 10 --limit_train_batches 10 --any_trainer_arg x
-
-.. note::
-
-    If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
-    The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown, including
-    running accelerator callback ``on_train_end`` to clean up memory. The trainer object will also set
-    an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
-    resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs.
+If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
+The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown. The trainer object will also set
+an attribute ``interrupted`` to ``True`` in such cases. If, for example, you have a callback that shuts down compute
+resources, you can run the shutdown logic only for uninterrupted runs by overriding :meth:`lightning.pytorch.Callback.on_exception`.
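+
+A hypothetical cleanup callback could look like this, where ``shutdown_compute`` stands in for your
+own teardown logic:
+
+.. code-block:: python
+
+    from lightning.pytorch.callbacks import Callback
+
+
+    class CleanupCallback(Callback):
+        def on_exception(self, trainer, pl_module, exception):
+            # a KeyboardInterrupt means the run was interrupted manually: skip the teardown
+            if not isinstance(exception, KeyboardInterrupt):
+                shutdown_compute()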
 
 ------------
 
diff --git a/examples/pl_domain_templates/generative_adversarial_net.py b/examples/pl_domain_templates/generative_adversarial_net.py
index 0dbd34a3620d4..9dfae5ae0866f 100644
--- a/examples/pl_domain_templates/generative_adversarial_net.py
+++ b/examples/pl_domain_templates/generative_adversarial_net.py
@@ -127,20 +127,6 @@ def __init__(
 
         self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
 
-    @staticmethod
-    def add_argparse_args(parent_parser: ArgumentParser, *, use_argument_group=True):
-        if use_argument_group:
-            parser = parent_parser.add_argument_group("GAN")
-            parser_out = parent_parser
-        else:
-            parser = ArgumentParser(parents=[parent_parser], add_help=False)
-            parser_out = parser
-        parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
-        parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
-        parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
-        parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
-        return parser_out
-
     def forward(self, z):
         return self.generator(z)
 
@@ -226,8 +212,8 @@ def main(args: Namespace) -> None:
     # ------------------------
     # If use distributed training PyTorch recommends to use DistributedDataParallel.
    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
-    dm = MNISTDataModule.from_argparse_args(args)
-    trainer = Trainer.from_argparse_args(args)
+    dm = MNISTDataModule()
+    trainer = Trainer(accelerator="gpu", devices=1)
 
     # ------------------------
     # 3 START TRAINING
 
@@ -239,15 +225,11 @@ def main(args: Namespace) -> None:
     cli_lightning_logo()
 
     parser = ArgumentParser()
-    # Add program level args, if any.
-    # ------------------------
-    # Add LightningDataLoader args
-    parser = MNISTDataModule.add_argparse_args(parser)
-    # Add model specific args
-    parser = GAN.add_argparse_args(parser)
-    # Add trainer args
-    parser = Trainer.add_argparse_args(parser)
-    # Parse all arguments
+    # Hyperparameters
+    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
+    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
+    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
+    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
     args = parser.parse_args()
 
     main(args)
diff --git a/examples/pl_domain_templates/reinforce_learn_Qnet.py b/examples/pl_domain_templates/reinforce_learn_Qnet.py
index 0b4e3b3954e95..7873eadd4fed1 100644
--- a/examples/pl_domain_templates/reinforce_learn_Qnet.py
+++ b/examples/pl_domain_templates/reinforce_learn_Qnet.py
@@ -365,33 +365,10 @@ def get_device(self, batch) -> str:
         """Retrieve device currently being used by minibatch."""
         return batch[0].device.index if self.on_gpu else "cpu"
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):  # pragma: no-cover
-        parser = parent_parser.add_argument_group("DQNLightning")
-        parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
-        parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
-        parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
-        parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
-        parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
-        parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
-        parser.add_argument(
-            "--warm_start_steps",
-            type=int,
-            default=1000,
-            help="how many samples do we use to fill our buffer at the start of training",
-        )
-        parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying")
-        parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
-        parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
-        parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
-        return parent_parser
-
 
 def main(args) -> None:
     model = DQNLightning(**vars(args))
-
-    trainer = Trainer(accelerator="gpu", devices=1, strategy="dp", val_check_interval=100)
-
+    trainer = Trainer(accelerator="gpu", devices=1, val_check_interval=100)
     trainer.fit(model)
 
 
@@ -399,8 +376,23 @@ def main(args) -> None:
     cli_lightning_logo()
     seed_everything(0)
 
-    parser = argparse.ArgumentParser(add_help=False)
-    parser = DQNLightning.add_model_specific_args(parser)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
+    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
+    parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
+    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
+    parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
+    parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
+    parser.add_argument(
+        "--warm_start_steps",
+        type=int,
+        default=1000,
+        help="how many samples do we use to fill our buffer at the start of training",
+    )
+    parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying")
+    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
+    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
+    parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
     args = parser.parse_args()
 
     main(args)
diff --git a/examples/pl_domain_templates/reinforce_learn_ppo.py b/examples/pl_domain_templates/reinforce_learn_ppo.py
index d8dc6b360c6b2..5fe4a3ab9fd02 100644
--- a/examples/pl_domain_templates/reinforce_learn_ppo.py
+++ b/examples/pl_domain_templates/reinforce_learn_ppo.py
@@ -428,36 +428,10 @@ def train_dataloader(self) -> DataLoader:
         """Get train loader."""
         return self._dataloader()
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):  # pragma: no-cover
-        parser = parent_parser.add_argument_group("PPOLightning")
-        parser.add_argument("--env", type=str, default="CartPole-v0")
-        parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
-        parser.add_argument("--lam", type=float, default=0.95, help="advantage discount factor")
-        parser.add_argument("--lr_actor", type=float, default=3e-4, help="learning rate of actor network")
-        parser.add_argument("--lr_critic", type=float, default=1e-3, help="learning rate of critic network")
-        parser.add_argument("--max_episode_len", type=int, default=1000, help="capacity of the replay buffer")
-        parser.add_argument("--batch_size", type=int, default=512, help="batch_size when training network")
-        parser.add_argument(
-            "--steps_per_epoch",
-            type=int,
-            default=2048,
-            help="how many action-state pairs to rollout for trajectory collection per epoch",
-        )
-        parser.add_argument(
-            "--nb_optim_iters", type=int, default=4, help="how many steps of gradient descent to perform on each batch"
-        )
-        parser.add_argument(
-            "--clip_ratio", type=float, default=0.2, help="hyperparameter for clipping in the policy objective"
-        )
-
-        return parent_parser
-
 
 def main(args) -> None:
     model = PPOLightning(**vars(args))
-
-    trainer = Trainer.from_argparse_args(args)
+    trainer = Trainer(accelerator="cpu", devices=1, val_check_interval=100)
     trainer.fit(model)
 
 
@@ -465,10 +439,26 @@ def main(args) -> None:
     cli_lightning_logo()
     seed_everything(0)
 
-    parent_parser = argparse.ArgumentParser(add_help=False)
-    parent_parser = Trainer.add_argparse_args(parent_parser)
-
-    parser = PPOLightning.add_model_specific_args(parent_parser)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--env", type=str, default="CartPole-v0")
+    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
+    parser.add_argument("--lam", type=float, default=0.95, help="advantage discount factor")
+    parser.add_argument("--lr_actor", type=float, default=3e-4, help="learning rate of actor network")
+    parser.add_argument("--lr_critic", type=float, default=1e-3, help="learning rate of critic network")
+    parser.add_argument("--max_episode_len", type=int, default=1000, help="capacity of the replay buffer")
+    parser.add_argument("--batch_size", type=int, default=512, help="batch_size when training network")
+    parser.add_argument(
+        "--steps_per_epoch",
+        type=int,
+        default=2048,
+        help="how many action-state pairs to rollout for trajectory collection per epoch",
+    )
+    parser.add_argument(
+        "--nb_optim_iters", type=int, default=4, help="how many steps of gradient descent to perform on each batch"
+    )
+    parser.add_argument(
+        "--clip_ratio", type=float, default=0.2, help="hyperparameter for clipping in the policy objective"
+    )
     args = parser.parse_args()
 
     main(args)
diff --git a/examples/pl_domain_templates/semantic_segmentation.py b/examples/pl_domain_templates/semantic_segmentation.py
index 4b4d4e8883caa..3b5c59baa3b8a 100644
--- a/examples/pl_domain_templates/semantic_segmentation.py
+++ b/examples/pl_domain_templates/semantic_segmentation.py
@@ -383,19 +383,6 @@ def train_dataloader(self):
     def val_dataloader(self):
         return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False)
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):  # pragma: no-cover
-        parser = parent_parser.add_argument_group("SegModel")
-        parser.add_argument("--data_path", type=str, help="path where dataset is stored")
-        parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
-        parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
-        parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
-        parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
-        parser.add_argument(
-            "--bilinear", action="store_true", default=False, help="whether to use bilinear interpolation or transposed"
-        )
-        return parent_parser
-
 
 def main(hparams: Namespace):
     # ------------------------
@@ -416,7 +403,7 @@ def main(hparams: Namespace):
     # ------------------------
     # 3 INIT TRAINER
     # ------------------------
-    trainer = Trainer.from_argparse_args(hparams)
+    trainer = Trainer()
 
     # ------------------------
     # 5 START TRAINING
@@ -426,8 +413,16 @@ def main(hparams: Namespace):
 
 if __name__ == "__main__":
     cli_lightning_logo()
-    parser = ArgumentParser(add_help=False)
-    parser = SegModel.add_model_specific_args(parser)
+
+    parser = ArgumentParser()
+    parser.add_argument("--data_path", type=str, help="path where dataset is stored")
+    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
+    parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
+    parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
+    parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
+    parser.add_argument(
+        "--bilinear", action="store_true", default=False, help="whether to use bilinear interpolation or transposed"
+    )
 
     hparams = parser.parse_args()
 
     main(hparams)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 6733c73191ce0..3a7f16e19172a 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -271,6 +271,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the unused `lightning.pytorch.utilities.metrics.metrics_to_scalars` function ([#16681](https://github.com/Lightning-AI/lightning/pull/16681))
 
+- Removed legacy argparse utilities ([#16708](https://github.com/Lightning-AI/lightning/pull/16708))
+  * Removed `LightningDataModule` methods: `add_argparse_args()`, `from_argparse_args()`, `parse_argparser()`, `get_init_arguments_and_types()`
+  * Removed class methods from Trainer: `default_attributes()`, `from_argparse_args()`, `parse_argparser()`, `match_env_arguments()`, `add_argparse_args()`
+  * Removed functions from `lightning.pytorch.utilities.argparse`: `from_argparse_args()`, `parse_argparser()`, `parse_env_variables()`, `get_init_arguments_and_types()`, `add_argparse_args()`
+  * Removed functions from `lightning.pytorch.utilities.parsing`: `str_to_bool()`, `str_to_bool_or_int()`, `str_to_bool_or_str()`
+
+
 - Removed support for passing a scheduling dictionary to `Trainer(accumulate_grad_batches=...)` ([#16729](https://github.com/Lightning-AI/lightning/pull/16729))
 
diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py
index 2ccf4cce6fcd2..629df0f690dfb 100644
--- a/src/lightning/pytorch/core/datamodule.py
+++ b/src/lightning/pytorch/core/datamodule.py
@@ -13,8 +13,7 @@
 # limitations under the License.
 """LightningDataModule for loading DataLoaders with ease."""
 import inspect
-from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, IO, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, IO, Mapping, Optional, Sequence, Union
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 from typing_extensions import Self
@@ -24,13 +23,7 @@
 from lightning.pytorch.core.hooks import DataHooks
 from lightning.pytorch.core.mixins import HyperparametersMixin
 from lightning.pytorch.core.saving import _load_from_checkpoint
-from lightning.pytorch.utilities.argparse import (
-    add_argparse_args,
-    from_argparse_args,
-    get_init_arguments_and_types,
-    parse_argparser,
-)
-from lightning.pytorch.utilities.types import _ADD_ARGPARSE_RETURN, EVAL_DATALOADERS, TRAIN_DATALOADERS
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 
 
 class LightningDataModule(DataHooks, HyperparametersMixin):
@@ -72,49 +65,6 @@ def __init__(self) -> None:
         # Pointer to the trainer object
         self.trainer: Optional["pl.Trainer"] = None
 
-    @classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> _ADD_ARGPARSE_RETURN:
-        """Extends existing argparse by default `LightningDataModule` attributes.
-
-        Example::
-
-            parser = ArgumentParser(add_help=False)
-            parser = LightningDataModule.add_argparse_args(parser)
-        """
-        return add_argparse_args(cls, parent_parser, **kwargs)
-
-    @classmethod
-    def from_argparse_args(
-        cls, args: Union[Namespace, ArgumentParser], **kwargs: Any
-    ) -> Union["pl.LightningDataModule", "pl.Trainer"]:
-        """Create an instance from CLI arguments.
-
-        Args:
-            args: The parser or namespace to take arguments from. Only known arguments will be
-                parsed and passed to the :class:`~lightning.pytorch.core.datamodule.LightningDataModule`.
-            **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
-                These must be valid DataModule arguments.
-
-        Example::
-
-            module = LightningDataModule.from_argparse_args(args)
-        """
-        return from_argparse_args(cls, args, **kwargs)
-
-    @classmethod
-    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
-        return parse_argparser(cls, arg_parser)
-
-    @classmethod
-    def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
-        r"""Scans the DataModule signature and returns argument names, types and default values.
-
-        Returns:
-            List with tuples of 3 values:
-            (argument name, set with argument types, argument default value).
-        """
-        return get_init_arguments_and_types(cls)
-
     @classmethod
     def from_datasets(
         cls,
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index cc2c8d26fd8e6..ab6eb13ef4b07 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -20,12 +20,10 @@
 # - WILLIAM FALCON
 
 """Trainer to automate the training."""
-import inspect
 import logging
 import math
 import os
 import warnings
-from argparse import _ArgumentGroup, ArgumentParser, Namespace
 from copy import deepcopy
 from datetime import timedelta
 from typing import Any, Dict, Iterable, List, Optional, Type, Union
@@ -75,13 +73,7 @@
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from lightning.pytorch.trainer.supporters import _LITERAL_SUPPORTED_MODES, CombinedLoader
 from lightning.pytorch.utilities import GradClipAlgorithmType, parsing
-from lightning.pytorch.utilities.argparse import (
-    _defaults_from_env_vars,
-    add_argparse_args,
-    from_argparse_args,
-    parse_argparser,
-    parse_env_variables,
-)
+from lightning.pytorch.utilities.argparse import _defaults_from_env_vars
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -1615,31 +1607,6 @@ def save_checkpoint(
         )
         self._checkpoint_connector.save_checkpoint(filepath, weights_only=weights_only, storage_options=storage_options)
 
-    """
-    Parsing properties
-    """
-
-    @classmethod
-    def default_attributes(cls) -> dict:
-        init_signature = inspect.signature(cls)
-        return {k: v.default for k, v in init_signature.parameters.items()}
-
-    @classmethod
-    def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> Any:
-        return from_argparse_args(cls, args, **kwargs)
-
-    @classmethod
-    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
-        return parse_argparser(cls, arg_parser)
-
-    @classmethod
-    def match_env_arguments(cls) -> Namespace:
-        return parse_env_variables(cls)
-
-    @classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> Union[_ArgumentGroup, ArgumentParser]:
-        return add_argparse_args(cls, parent_parser, **kwargs)
-
     """
     State properties
     """
diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py
index 44dd26e41dc32..888a3b3755e1e 100644
--- a/src/lightning/pytorch/utilities/argparse.py
+++ b/src/lightning/pytorch/utilities/argparse.py
@@ -15,101 +15,32 @@
 import inspect
 import os
-from argparse import _ArgumentGroup, ArgumentParser, Namespace
+from argparse import Namespace
 from ast import literal_eval
 from contextlib import suppress
 from functools import wraps
-from typing import Any, Callable, cast, Dict, List, Tuple, Type, TypeVar, Union
-
-import lightning.pytorch as pl
-from lightning.pytorch.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str
-from lightning.pytorch.utilities.types import _ADD_ARGPARSE_RETURN
+from typing import Any, Callable, cast, Type, TypeVar
 
 _T = TypeVar("_T", bound=Callable[..., Any])
-_ARGPARSE_CLS = Union[Type["pl.LightningDataModule"], Type["pl.Trainer"]]
-
-
-def from_argparse_args(
-    cls: _ARGPARSE_CLS,
-    args: Union[Namespace, ArgumentParser],
-    **kwargs: Any,
-) -> Union["pl.LightningDataModule", "pl.Trainer"]:
-    """Create an instance from CLI arguments. Eventually use variables from OS environment which are defined as
-    ``"PL_<CLASS-NAME>_<CLASS_ARGUMENT_NAME>"``.
-
-    Args:
-        cls: Lightning class
-        args: The parser or namespace to take arguments from. Only known arguments will be
-            parsed and passed to the :class:`Trainer`.
-        **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
-            These must be valid Trainer arguments.
-
-    Examples:
-
-        >>> from lightning.pytorch import Trainer
-        >>> parser = ArgumentParser(add_help=False)
-        >>> parser = Trainer.add_argparse_args(parser)
-        >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
-        >>> args = Trainer.parse_argparser(parser.parse_args(""))
-        >>> trainer = Trainer.from_argparse_args(args, logger=False)
-    """
-    if isinstance(args, ArgumentParser):
-        args = cls.parse_argparser(args)
-
-    params = vars(args)
-
-    # we only want to pass in valid Trainer args, the rest may be user specific
-    valid_kwargs = inspect.signature(cls.__init__).parameters
-    trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params}
-    trainer_kwargs.update(**kwargs)
-
-    return cls(**trainer_kwargs)
-
-
-def parse_argparser(cls: _ARGPARSE_CLS, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
-    """Parse CLI arguments, required for custom bool types."""
-    args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
-
-    types_default = {arg: (arg_types, arg_default) for arg, arg_types, arg_default in get_init_arguments_and_types(cls)}
-
-    modified_args = {}
-    for k, v in vars(args).items():
-        if k in types_default and v is None:
-            # We need to figure out if the None is due to using nargs="?" or if it comes from the default value
-            arg_types, arg_default = types_default[k]
-            if bool in arg_types and isinstance(arg_default, bool):
-                # Value has been passed as a flag => It is currently None, so we need to set it to True
-                # We always set to True, regardless of the default value.
-                # Users must pass False directly, but when passing nothing True is assumed.
-                # i.e. the only way to disable something that defaults to True is to use the long form:
-                # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
-                # which then becomes True here.
-
-                v = True
-        modified_args[k] = v
-    return Namespace(**modified_args)
 
 
-def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
+def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
     """Parse environment arguments if they are defined.
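+
+    The variable name is derived from ``template``: for ``Trainer``, the init argument ``devices``
+    is read from the environment variable ``PL_TRAINER_DEVICES``.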
 
     Examples:
 
     >>> from lightning.pytorch import Trainer
-    >>> parse_env_variables(Trainer)
+    >>> _parse_env_variables(Trainer)
     Namespace()
     >>> import os
     >>> os.environ["PL_TRAINER_DEVICES"] = '42'
     >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
-    >>> parse_env_variables(Trainer)
+    >>> _parse_env_variables(Trainer)
     Namespace(devices=42)
     >>> del os.environ["PL_TRAINER_DEVICES"]
     """
-    cls_arg_defaults = get_init_arguments_and_types(cls)
-
     env_args = {}
-    for arg_name, _, _ in cls_arg_defaults:
+    for arg_name in inspect.signature(cls).parameters:
         env = template % {"cls_name": cls.__name__.upper(), "cls_argument": arg_name.upper()}
         val = os.environ.get(env)
         if not (val is None or val == ""):
@@ -121,226 +52,16 @@ def parse_env_variables(cls: _ARGPARSE_CLS, template: str = "PL_%(c
     return Namespace(**env_args)
 
 
-def get_init_arguments_and_types(cls: _ARGPARSE_CLS) -> List[Tuple[str, Tuple, Any]]:
-    r"""Scans the class signature and returns argument names, types and default values.
-
-    Returns:
-        List with tuples of 3 values:
-        (argument name, set with argument types, argument default value).
-
-    Examples:
-
-        >>> from lightning.pytorch import Trainer
-        >>> args = get_init_arguments_and_types(Trainer)
-
-    """
-    cls_default_params = inspect.signature(cls).parameters
-    name_type_default = []
-    for arg in cls_default_params:
-        arg_type = cls_default_params[arg].annotation
-        arg_default = cls_default_params[arg].default
-        try:
-            if type(arg_type).__name__ == "_LiteralGenericAlias":
-                # Special case: Literal[a, b, c, ...]
-                arg_types = tuple({type(a) for a in arg_type.__args__})
-            elif "typing.Literal" in str(arg_type) or "typing_extensions.Literal" in str(arg_type):
-                # Special case: Union[Literal, ...]
-                arg_types = tuple({type(a) for union_args in arg_type.__args__ for a in union_args.__args__})
-            else:
-                # Special case: ComposedType[type0, type1, ...]
-                arg_types = tuple(arg_type.__args__)
-        except (AttributeError, TypeError):
-            arg_types = (arg_type,)
-
-        name_type_default.append((arg, arg_types, arg_default))
-
-    return name_type_default
-
-
-def _get_abbrev_qualified_cls_name(cls: _ARGPARSE_CLS) -> str:
-    assert isinstance(cls, type), repr(cls)
-    if cls.__module__.startswith("lightning.pytorch."):
-        # Abbreviate.
-        return f"pl.{cls.__name__}"
-    # Fully qualified.
-    return f"{cls.__module__}.{cls.__qualname__}"
-
-
-def add_argparse_args(
-    cls: _ARGPARSE_CLS,
-    parent_parser: ArgumentParser,
-    *,
-    use_argument_group: bool = True,
-) -> _ADD_ARGPARSE_RETURN:
-    r"""Extends existing argparse by default attributes for ``cls``.
-
-    Args:
-        cls: Lightning class
-        parent_parser:
-            The custom cli arguments parser, which will be extended by
-            the class's default arguments.
-        use_argument_group:
-            By default, this is True, and uses ``add_argument_group`` to add
-            a new group.
-            If False, this will use old behavior.
-
-    Returns:
-        If use_argument_group is True, returns ``parent_parser`` to keep old
-        workflows. If False, will return the new ArgumentParser object.
-
-    Only arguments of the allowed types (str, float, int, bool) will
-    extend the ``parent_parser``.
-
-    Raises:
-        RuntimeError:
-            If ``parent_parser`` is not an ``ArgumentParser`` instance
-
-    Examples:
-
-        >>> # Option 1: Default usage.
-        >>> import argparse
-        >>> from lightning.pytorch import Trainer
-        >>> parser = argparse.ArgumentParser()
-        >>> parser = Trainer.add_argparse_args(parser)
-        >>> args = parser.parse_args([])
-
-        >>> # Option 2: Disable use_argument_group (old behavior).
-        >>> import argparse
-        >>> from lightning.pytorch import Trainer
-        >>> parser = argparse.ArgumentParser()
-        >>> parser = Trainer.add_argparse_args(parser, use_argument_group=False)
-        >>> args = parser.parse_args([])
-    """
-    if isinstance(parent_parser, _ArgumentGroup):
-        raise RuntimeError("Please only pass an `ArgumentParser` instance.")
-    if use_argument_group:
-        group_name = _get_abbrev_qualified_cls_name(cls)
-        parser: _ADD_ARGPARSE_RETURN = parent_parser.add_argument_group(group_name)
-    else:
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
-
-    ignore_arg_names = ["self", "args", "kwargs"]
-
-    allowed_types = (str, int, float, bool)
-
-    # Get symbols from cls or init function.
-    for symbol in (cls, cls.__init__):
-        args_and_types = get_init_arguments_and_types(symbol)  # type: ignore[arg-type]
-        args_and_types = [x for x in args_and_types if x[0] not in ignore_arg_names]
-        if len(args_and_types) > 0:
-            break
-
-    args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")
-
-    for arg, arg_types, arg_default in args_and_types:
-        arg_types = tuple(at for at in allowed_types if at in arg_types)
-        if not arg_types:
-            # skip argument with not supported type
-            continue
-        arg_kwargs: Dict[str, Any] = {}
-        if bool in arg_types:
-            arg_kwargs.update(nargs="?", const=True)
-            # if the only arg type is bool
-            if len(arg_types) == 1:
-                use_type: Callable[[str], Union[bool, int, float, str]] = str_to_bool
-            elif int in arg_types:
-                use_type = str_to_bool_or_int
-            elif str in arg_types:
-                use_type = str_to_bool_or_str
-            else:
-                # filter out the bool as we need to use more general
-                use_type = [at for at in arg_types if at is not bool][0]
-        else:
-            use_type = arg_types[0]
-
-        if arg == "devices":
-            use_type = _devices_allowed_type
-
-        # hack for types in (int, float)
-        if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
-            use_type = _int_or_float_type
-
-        # hack for track_grad_norm
-        if arg == "track_grad_norm":
-            use_type = float
-
-        # hack for precision
-        if arg == "precision":
-            use_type = _precision_allowed_type
-
-        parser.add_argument(
-            f"--{arg}",
-            dest=arg,
-            default=arg_default,
-            type=use_type,
-            help=args_help.get(arg),
-            required=(arg_default == inspect._empty),
-            **arg_kwargs,
-        )
-
-    if use_argument_group:
-        return parent_parser
-    return parser
-
-
-def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
-    arg_block_indent = None
-    current_arg = ""
-    parsed = {}
-    for line in docstring.split("\n"):
-        stripped = line.lstrip()
-        if not stripped:
-            continue
-        line_indent = len(line) - len(stripped)
-        if stripped.startswith(("Args:", "Arguments:", "Parameters:")):
-            arg_block_indent = line_indent + 4
-        elif arg_block_indent is None:
-            continue
-        elif line_indent < arg_block_indent:
-            break
-        elif line_indent == arg_block_indent:
-            current_arg, arg_description = stripped.split(":", maxsplit=1)
-            parsed[current_arg] = arg_description.lstrip()
-        elif line_indent > arg_block_indent:
-            parsed[current_arg] += f" {stripped}"
-    return parsed
-
-
-def _devices_allowed_type(x: str) -> Union[int, str]:
-    if "," in x:
-        return str(x)
-    return int(x)
-
-
-def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
-    if "." in str(x):
-        return float(x)
-    return int(x)
-
-
-def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
-    """
-    >>> _precision_allowed_type("32")
-    32
-    >>> _precision_allowed_type("bf16")
-    'bf16'
-    """
-    try:
-        return int(x)
-    except ValueError:
-        return x
-
-
 def _defaults_from_env_vars(fn: _T) -> _T:
     @wraps(fn)
     def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
         cls = self.__class__  # get the class
         if args:  # in case any args passed move them to kwargs
-            # parse only the argument names
-            cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
+            # parse the argument names
            cls_arg_names = inspect.signature(cls).parameters
             # convert args to kwargs
             kwargs.update(dict(zip(cls_arg_names, args)))
-        env_variables = vars(parse_env_variables(cls))
+        env_variables = vars(_parse_env_variables(cls))
         # update the kwargs by env variables
         kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py
index 56a3873ade33c..0763a268d156a 100644
--- a/src/lightning/pytorch/utilities/parsing.py
+++ b/src/lightning/pytorch/utilities/parsing.py
@@ -26,62 +26,6 @@
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
 
-def str_to_bool_or_str(val: str) -> Union[str, bool]:
-    """Possibly convert a string representation of truth to bool. Returns the input otherwise. Based on the python
-    implementation distutils.utils.strtobool.
-
-    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
-    """
-    lower = val.lower()
-    if lower in ("y", "yes", "t", "true", "on", "1"):
-        return True
-    if lower in ("n", "no", "f", "false", "off", "0"):
-        return False
-    return val
-
-
-def str_to_bool(val: str) -> bool:
-    """Convert a string representation of truth to bool.
-
-    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
-    are 'n', 'no', 'f', 'false', 'off', and '0'.
-
-    Raises:
-        ValueError:
-            If ``val`` isn't in one of the aforementioned true or false values.
-
-    >>> str_to_bool('YES')
-    True
-    >>> str_to_bool('FALSE')
-    False
-    """
-    val_converted = str_to_bool_or_str(val)
-    if isinstance(val_converted, bool):
-        return val_converted
-    raise ValueError(f"invalid truth value {val_converted}")
-
-
-def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
-    """Convert a string representation to truth of bool if possible, or otherwise try to convert it to an int.
-
-    >>> str_to_bool_or_int("FALSE")
-    False
-    >>> str_to_bool_or_int("1")
-    True
-    >>> str_to_bool_or_int("2")
-    2
-    >>> str_to_bool_or_int("abc")
-    'abc'
-    """
-    val_converted = str_to_bool_or_str(val)
-    if isinstance(val_converted, bool):
-        return val_converted
-    try:
-        return int(val_converted)
-    except ValueError:
-        return val_converted
-
-
 def is_picklable(obj: object) -> bool:
     """Tests if an object can be pickled."""
 
diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py
index 3f45998a62ced..bccba5ad9f2c1 100644
--- a/src/lightning/pytorch/utilities/types.py
+++ b/src/lightning/pytorch/utilities/types.py
@@ -16,7 +16,6 @@
 - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
-from argparse import _ArgumentGroup, ArgumentParser
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Generator, List, Optional, Protocol, runtime_checkable, Sequence, Type, Union
@@ -43,7 +42,6 @@
     Dict[str, Sequence[DataLoader]],
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
-_ADD_ARGPARSE_RETURN = Union[_ArgumentGroup, ArgumentParser]
 
 
 @runtime_checkable
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 3e18e0fcecf23..513fe6b496cb6 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pickle
-from argparse import ArgumentParser, Namespace
+from argparse import Namespace
 from dataclasses import dataclass
 from typing import Any, Dict
 from unittest import mock
@@ -132,23 +132,6 @@ def __init__(self, data_dir: str):
         self.data_dir = data_dir
 
 
-def test_dm_add_argparse_args(tmpdir):
-    parser = ArgumentParser()
-    parser = DataDirDataModule.add_argparse_args(parser)
-    args = parser.parse_args(["--data_dir", str(tmpdir)])
-    assert args.data_dir == str(tmpdir)
-
-
-def test_dm_init_from_argparse_args(tmpdir):
-    parser = ArgumentParser()
-    parser = DataDirDataModule.add_argparse_args(parser)
-    args = parser.parse_args(["--data_dir", str(tmpdir)])
-    dm = DataDirDataModule.from_argparse_args(args)
-    dm.prepare_data()
-    dm.setup("fit")
-    assert dm.data_dir == args.data_dir == str(tmpdir)
-
-
 def test_dm_pickle_after_init():
     dm = BoringDataModule()
     pickle.dumps(dm)
diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py
index cdf13b0aac742..790f100fe3f58 100644
--- a/tests/tests_pytorch/models/test_tpu.py
+++ b/tests/tests_pytorch/models/test_tpu.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from argparse import ArgumentParser
 from functools import partial
 from unittest import mock
 
@@ -257,28 +256,6 @@ def test_accelerator_set_when_using_tpu(devices):
     assert isinstance(Trainer(accelerator="tpu", devices=devices).accelerator, TPUAccelerator)
 
 
-@pytest.mark.parametrize(
-    ["cli_args", "expected"],
-    [
-        pytest.param("--accelerator=tpu --devices=8", {"accelerator": "tpu", "devices": 8}, id="tpu-8"),
-        pytest.param("--accelerator=tpu --devices=1,", {"accelerator": "tpu", "devices": "1,"}, id="tpu-1,"),
-    ],
-)
-@RunIf(tpu=True, standalone=True)
-@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
-def test_tpu_devices_with_argparse(cli_args, expected):
-    """Test passing devices for TPU accelerator in command line."""
-    cli_args = cli_args.split(" ") if cli_args else []
-    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
-        parser = ArgumentParser(add_help=False)
-        parser = Trainer.add_argparse_args(parent_parser=parser)
-        args = Trainer.parse_argparser(parser)
-
-    for k, v in expected.items():
-        assert getattr(args, k) == v
-    assert Trainer.from_argparse_args(args)
-
-
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_if_test_works_with_checkpoint_false(tmpdir):
diff --git a/tests/tests_pytorch/trainer/test_trainer_cli.py b/tests/tests_pytorch/trainer/test_trainer_cli.py
deleted file mode 100644
index 0ed20d9856590..0000000000000
--- a/tests/tests_pytorch/trainer/test_trainer_cli.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import inspect
-import pickle
-from argparse import ArgumentParser, Namespace
-from unittest import mock
-
-import pytest
-
-import tests_pytorch.helpers.utils as tutils
-from lightning.pytorch import Trainer
-from lightning.pytorch.utilities import argparse
-
-
-@mock.patch("argparse.ArgumentParser.parse_args")
-def test_default_args(mock_argparse, tmpdir):
-    """Tests default argument parser for Trainer."""
-    mock_argparse.return_value = Namespace(**Trainer.default_attributes())
-
-    # logger file to get meta
-    logger = tutils.get_default_logger(tmpdir)
-
-    parser = ArgumentParser(add_help=False)
-    args = parser.parse_args()
-    args.logger = logger
-
-    args.max_epochs = 5
-    trainer = Trainer.from_argparse_args(args)
-
-    assert isinstance(trainer, Trainer)
-    assert trainer.max_epochs == 5
-
-
-@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
-def test_add_argparse_args_redefined(cli_args: list):
-    """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
-    parser = ArgumentParser(add_help=False)
-    parser = Trainer.add_argparse_args(parent_parser=parser)
-
-    args = parser.parse_args(cli_args)
-
-    # make sure we can pickle args
-    pickle.dumps(args)
-
-    # Check few deprecated args are not in namespace:
-    for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
-        assert depr_name not in args
-
-    trainer = Trainer.from_argparse_args(args=args)
-    pickle.dumps(trainer)
-
-    assert isinstance(trainer, Trainer)
-
-
-@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
-def test_add_argparse_args(cli_args: list):
-    """Simple test ensuring Trainer.add_argparse_args works."""
-    parser = ArgumentParser(add_help=False)
-    parser = Trainer.add_argparse_args(parser)
-    args = parser.parse_args(cli_args)
-    assert Trainer.from_argparse_args(args)
-
-    parser = ArgumentParser(add_help=False)
-    parser = Trainer.add_argparse_args(parser, use_argument_group=False)
-    args = parser.parse_args(cli_args)
-    assert Trainer.from_argparse_args(args)
-
-
-def test_get_init_arguments_and_types():
-    """Asserts a correctness of the `get_init_arguments_and_types` Trainer classmethod."""
-    args = argparse.get_init_arguments_and_types(Trainer)
-    parameters = inspect.signature(Trainer).parameters
-    assert len(parameters) == len(args)
-    for arg in args:
-        assert parameters[arg[0]].default == arg[2]
-
-    kwargs = {arg[0]: arg[2] for arg in args}
-    trainer = Trainer(**kwargs)
-    assert isinstance(trainer, Trainer)
-
-
-@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
-def test_add_argparse_args_redefined_error(cli_args: list, monkeypatch):
-    """Asserts thar an error raised in case of passing not default cli arguments."""
-
-    class _UnkArgError(Exception):
-        pass
-
-    def _raise():
-        raise _UnkArgError
-
-    parser = ArgumentParser(add_help=False)
-    parser = Trainer.add_argparse_args(parent_parser=parser)
-
-    monkeypatch.setattr(parser, "exit", lambda *args: _raise(), raising=True)
-
-    with pytest.raises(_UnkArgError):
-        parser.parse_args(cli_args)
-
-
-@pytest.mark.parametrize(
-    ["cli_args", "expected"],
-    [
-        (
-            "",
-            {
-                # These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
-                # They should not be changed by the argparse interface.
- "min_steps": None, - "accelerator": None, - "profiler": None, - }, - ), - ], -) -def test_argparse_args_parsing(cli_args, expected): - """Test multi type argument with bool.""" - cli_args = cli_args.split(" ") if cli_args else [] - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): - parser = ArgumentParser(add_help=False) - parser = Trainer.add_argparse_args(parent_parser=parser) - args = Trainer.parse_argparser(parser) - - for k, v in expected.items(): - assert getattr(args, k) == v - assert Trainer.from_argparse_args(args) - - -@pytest.mark.parametrize( - "cli_args,expected", - [("", False), ("--fast_dev_run=0", False), ("--fast_dev_run=True", True), ("--fast_dev_run 2", 2)], -) -def test_argparse_args_parsing_fast_dev_run(cli_args, expected): - """Test multi type argument with bool.""" - cli_args = cli_args.split(" ") if cli_args else [] - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): - parser = ArgumentParser(add_help=False) - parser = Trainer.add_argparse_args(parent_parser=parser) - args = Trainer.parse_argparser(parser) - assert args.fast_dev_run is expected - - -@pytest.mark.parametrize( - ["cli_args", "expected_parsed"], - [("", None), ("--accelerator gpu --devices 1", 1), ("--accelerator gpu --devices 0,", "0,")], -) -def test_argparse_args_parsing_devices(cli_args, expected_parsed, cuda_count_1): - """Test multi type argument with bool.""" - cli_args = cli_args.split(" ") if cli_args else [] - with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): - parser = ArgumentParser(add_help=False) - parser = Trainer.add_argparse_args(parent_parser=parser) - args = Trainer.parse_argparser(parser) - - assert args.devices == expected_parsed - assert Trainer.from_argparse_args(args) - - -@pytest.mark.parametrize( - ["cli_args", "extra_args"], - [ - ({}, {}), - ({"logger": False}, {}), - ({"logger": False}, {"logger": True}), - ({"logger": False}, {"enable_checkpointing": True}), - ], -) -def test_init_from_argparse_args(cli_args, extra_args): - unknown_args = dict(unknown_arg=0) - - # unknown args in the argparser/namespace should be ignored - with mock.patch("lightning.pytorch.Trainer.__init__", autospec=True, return_value=None) as init: - trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args) - expected = dict(cli_args) - expected.update(extra_args) # extra args should override any cli arg - init.assert_called_with(trainer, **expected) - - # passing in unknown manual args should throw an error - with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"): - Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args) diff --git a/tests/tests_pytorch/utilities/test_argparse.py b/tests/tests_pytorch/utilities/test_argparse.py deleted file mode 100644 index b7da6e1a34a25..0000000000000 --- a/tests/tests_pytorch/utilities/test_argparse.py +++ /dev/null @@ -1,254 +0,0 @@ -import io -from argparse import ArgumentParser, Namespace -from typing import Generic, TypeVar -from unittest.mock import MagicMock - -import pytest - -from lightning.pytorch import Trainer -from lightning.pytorch.utilities.argparse import ( - _devices_allowed_type, - _get_abbrev_qualified_cls_name, - _int_or_float_type, - _parse_args_from_docstring, - _precision_allowed_type, - add_argparse_args, - from_argparse_args, - parse_argparser, -) - - -class ArgparseExample: - def __init__(self, a: int, b: str = "", c: bool = False): - self.a = a - self.b = b - self.c = c - - -def 
test_from_argparse_args(): - args = Namespace(a=1, b="test", c=True, d="not valid") - my_instance = from_argparse_args(ArgparseExample, args) - assert my_instance.a == 1 - assert my_instance.b == "test" - assert my_instance.c - - parser = ArgumentParser() - mock_trainer = MagicMock() - _ = from_argparse_args(mock_trainer, parser) - mock_trainer.parse_argparser.assert_called_once_with(parser) - - -def test_parse_argparser(): - args = Namespace(a=1, b="test", c=None, d="not valid") - new_args = parse_argparser(ArgparseExample, args) - assert new_args.a == 1 - assert new_args.b == "test" - assert new_args.c - assert new_args.d == "not valid" - - -def test_parse_args_from_docstring_normal(): - args_help = _parse_args_from_docstring( - """Constrain image dataset - - Args: - root: Root directory of dataset where ``MNIST/processed/training.pt`` - and ``MNIST/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - normalize: mean and std deviation of the MNIST dataset. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - num_samples: number of examples per selected class/digit - digits: list selected MNIST digits/classes - - Examples: - >>> dataset = TrialMNIST(download=True) - >>> len(dataset) - 300 - >>> sorted(set([d.item() for d in dataset.targets])) - [0, 1, 2] - >>> torch.bincount(dataset.targets) - tensor([100, 100, 100]) - """ - ) - - expected_args = ["root", "train", "normalize", "download", "num_samples", "digits"] - assert len(args_help.keys()) == len(expected_args) - assert all(x == y for x, y in zip(args_help.keys(), expected_args)) - assert ( - args_help["root"] == "Root directory of dataset where ``MNIST/processed/training.pt``" - " and ``MNIST/processed/test.pt`` exist." - ) - assert args_help["normalize"] == "mean and std deviation of the MNIST dataset." - - -def test_parse_args_from_docstring_empty(): - args_help = _parse_args_from_docstring( - """Constrain image dataset - - Args: - - Returns: - - Examples: - """ - ) - assert len(args_help.keys()) == 0 - - -def test_get_abbrev_qualified_cls_name(): - assert _get_abbrev_qualified_cls_name(Trainer) == "pl.Trainer" - - class NestedClass: - pass - - assert not __name__.startswith("lightning.pytorch.") - expected_name = f"{__name__}.test_get_abbrev_qualified_cls_name..NestedClass" - assert _get_abbrev_qualified_cls_name(NestedClass) == expected_name - - -class AddArgparseArgsExampleClass: - """ - Args: - my_parameter: A thing. - """ - - def __init__(self, my_parameter: int = 0): - pass - - -class AddArgparseArgsExampleClassViaInit: - def __init__(self, my_parameter: int = 0): - """ - Args: - my_parameter: A thing. - """ - pass - - -class AddArgparseArgsExampleClassNoDoc: - def __init__(self, my_parameter: int = 0): - pass - - -class AddArgparseArgsExampleClassGeneric: - T = TypeVar("T") - - class SomeClass(Generic[T]): - pass - - def __init__(self, invalid_class: SomeClass): - pass - - -class AddArgparseArgsExampleClassNoDefault: - """ - Args: - my_parameter: A thing. 
- """ - - def __init__(self, my_parameter: int): - pass - - -def extract_help_text(parser): - help_str_buffer = io.StringIO() - parser.print_help(file=help_str_buffer) - help_str_buffer.seek(0) - return help_str_buffer.read() - - -@pytest.mark.parametrize( - ["cls", "name"], - [ - [AddArgparseArgsExampleClass, "AddArgparseArgsExampleClass"], - [AddArgparseArgsExampleClassViaInit, "AddArgparseArgsExampleClassViaInit"], - [AddArgparseArgsExampleClassNoDoc, "AddArgparseArgsExampleClassNoDoc"], - [AddArgparseArgsExampleClassNoDefault, "AddArgparseArgsExampleClassNoDefault"], - ], -) -def test_add_argparse_args(cls, name): - """Tests that ``add_argparse_args`` handles argument groups correctly, and can be parsed.""" - parser = ArgumentParser() - parser_main = parser.add_argument_group("main") - parser_main.add_argument("--main_arg", type=str, default="") - parser_old = parser # For testing. - parser = add_argparse_args(cls, parser) - assert parser is parser_old - - # Check nominal argument groups. - help_text = extract_help_text(parser) - assert "main:" in help_text - assert "--main_arg" in help_text - assert f"{name}:" in help_text - assert "--my_parameter" in help_text - if cls is not AddArgparseArgsExampleClassNoDoc: - assert "A thing" in help_text - - fake_argv = ["--main_arg=abc", "--my_parameter=2"] - args = parser.parse_args(fake_argv) - assert args.main_arg == "abc" - assert args.my_parameter == 2 - - fake_argv = ["--main_arg=abc"] - if cls is AddArgparseArgsExampleClassNoDefault: - with pytest.raises(SystemExit): - parser.parse_args(fake_argv) - else: - args = parser.parse_args(fake_argv) - assert args.main_arg == "abc" - assert args.my_parameter == 0 - - -def test_negative_add_argparse_args(): - with pytest.raises(RuntimeError, match="Please only pass an `ArgumentParser` instance."): - parser = ArgumentParser() - add_argparse_args(AddArgparseArgsExampleClass, parser.add_argument_group("bad workflow")) - - -def test_add_argparse_args_no_argument_group(): - """Tests that ``add_argparse_args(..., use_argument_group=False)`` (old workflow) handles argument groups - correctly, and can be parsed.""" - parser = ArgumentParser() - parser.add_argument("--main_arg", type=str, default="") - parser_old = parser # For testing. - parser = add_argparse_args(AddArgparseArgsExampleClass, parser, use_argument_group=False) - assert parser is not parser_old - - # Check arguments. 
-    help_text = extract_help_text(parser)
-    assert "--main_arg" in help_text
-    assert "--my_parameter" in help_text
-    assert "AddArgparseArgsExampleClass:" not in help_text
-
-    fake_argv = ["--main_arg=abc", "--my_parameter=2"]
-    args = parser.parse_args(fake_argv)
-    assert args.main_arg == "abc"
-    assert args.my_parameter == 2
-
-
-def test_devices_allowed_type():
-    assert _devices_allowed_type("1,2") == "1,2"
-    assert _devices_allowed_type("1") == 1
-
-
-def test_int_or_float_type():
-    assert isinstance(_int_or_float_type("0.0"), float)
-    assert isinstance(_int_or_float_type("0"), int)
-
-
-@pytest.mark.parametrize(["arg", "expected"], [["--precision=16", 16], ["--precision=bf16", "bf16"]])
-def test_precision_parsed_correctly(arg, expected):
-    """Test to ensure that the precision flag is passed correctly when adding argparse args."""
-    parser = ArgumentParser()
-    parser = Trainer.add_argparse_args(parser)
-    fake_argv = [arg]
-    args = parser.parse_args(fake_argv)
-    assert args.precision == expected
-
-
-def test_precision_type():
-    assert _precision_allowed_type("bf16") == "bf16"
-    assert _precision_allowed_type("16") == 16
diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py
index 6169741226a77..ba7e7692263da 100644
--- a/tests/tests_pytorch/utilities/test_parsing.py
+++ b/tests/tests_pytorch/utilities/test_parsing.py
@@ -28,9 +28,6 @@
     lightning_hasattr,
     lightning_setattr,
     parse_class_init_keys,
-    str_to_bool,
-    str_to_bool_or_int,
-    str_to_bool_or_str,
 )

 unpicklable_function = lambda: None
@@ -141,7 +138,7 @@ def test_lightning_getattr():
         lightning_getattr(m, "this_attr_not_exist")


-def test_lightning_setattr(tmpdir):
+def test_lightning_setattr():
     """Test that the lightning_setattr works in all cases."""
     models, _ = model_and_trainer_cases()
     *__, model5, model6, model7, model8 = models
@@ -165,46 +162,7 @@ def test_lightning_setattr(tmpdir):
         lightning_setattr(m, "this_attr_not_exist", None)


-def test_str_to_bool_or_str():
-    true_cases = ["y", "yes", "t", "true", "on", "1"]
-    false_cases = ["n", "no", "f", "false", "off", "0"]
-    other_cases = ["yyeess", "noooo", "lightning"]
-
-    for case in true_cases:
-        assert str_to_bool_or_str(case) is True
-
-    for case in false_cases:
-        assert str_to_bool_or_str(case) is False
-
-    for case in other_cases:
-        assert str_to_bool_or_str(case) == case
-
-
-def test_str_to_bool():
-    true_cases = ["y", "yes", "t", "true", "on", "1"]
-    false_cases = ["n", "no", "f", "false", "off", "0"]
-    other_cases = ["yyeess", "noooo", "lightning"]
-
-    for case in true_cases:
-        assert str_to_bool(case) is True
-
-    for case in false_cases:
-        assert str_to_bool(case) is False
-
-    for case in other_cases:
-        with pytest.raises(ValueError):
-            str_to_bool(case)
-
-
-def test_str_to_bool_or_int():
-    assert str_to_bool_or_int("0") is False
-    assert str_to_bool_or_int("1") is True
-    assert str_to_bool_or_int("true") is True
-    assert str_to_bool_or_int("2") == 2
-    assert str_to_bool_or_int("abc") == "abc"
-
-
-def test_is_picklable(tmpdir):
+def test_is_picklable():
     # See the full list of picklable types at
     # https://docs.python.org/3/library/pickle.html#pickle-picklable
     class UnpicklableClass:
@@ -221,7 +179,7 @@ class UnpicklableClass:
         assert is_picklable(case) is False


-def test_clean_namespace(tmpdir):
+def test_clean_namespace():
     # See the full list of picklable types at
     # https://docs.python.org/3/library/pickle.html#pickle-picklable
     class UnpicklableClass:
@@ -235,7 +193,7 @@ class UnpicklableClass:
     assert test_case == {"1": None, "2": True, "3": 123}
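The ``str_to_bool`` helpers whose tests were deleted just above, together with the ``_devices_allowed_type`` and ``_precision_allowed_type`` coercers removed with ``test_argparse.py``, were small string-coercion utilities. Their contracts, reconstructed purely from the deleted assertions (underscore prefixes dropped; the implementations are our sketch, not the removed source):

```python
def str_to_bool_or_str(value: str):
    """Truthy/falsey spellings become bools; anything else passes through."""
    lower = value.lower()
    if lower in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lower in ("n", "no", "f", "false", "off", "0"):
        return False
    return value


def str_to_bool(value: str) -> bool:
    result = str_to_bool_or_str(value)
    if isinstance(result, str):
        raise ValueError(f"invalid truth value {value!r}")
    return result


def str_to_bool_or_int(value: str):
    # "0"/"1" become bools first, other numerals stay ints ("2" -> 2),
    # and non-numeric strings pass through unchanged ("abc" -> "abc").
    result = str_to_bool_or_str(value)
    if isinstance(result, bool):
        return result
    try:
        return int(result)
    except ValueError:
        return result


def devices_allowed_type(value: str):
    # Comma-separated device lists stay strings ("1,2", "0,"); plain
    # numbers parse to int, matching test_devices_allowed_type above.
    return value if "," in value else int(value)


def precision_allowed_type(value: str):
    # "16" -> 16 but "bf16" stays a string, per test_precision_type.
    try:
        return int(value)
    except ValueError:
        return value
```

The accepted truthy/falsey spellings are the same set ``distutils.util.strtobool`` uses.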
{"1": None, "2": True, "3": 123} -def test_parse_class_init_keys(tmpdir): +def test_parse_class_init_keys(): class Class: def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): pass @@ -243,7 +201,7 @@ def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): assert parse_class_init_keys(Class) == ("self", "my_args", "my_kwargs") -def test_get_init_args(tmpdir): +def test_get_init_args(): class AutomaticArgsModel: def __init__(self, anyarg, anykw=42, **kwargs): super().__init__() @@ -280,7 +238,7 @@ def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs): assert my_class.result[1] == {"anyarg": "test1", "childarg": "test2", "anykw": 32, "childkw": 22, "otherkw": 123} -def test_attribute_dict(tmpdir): +def test_attribute_dict(): # Test initialization inputs = {"key1": 1, "key2": "abc"} ad = AttributeDict(inputs) @@ -298,7 +256,7 @@ def test_attribute_dict(tmpdir): assert ad.key1 == 123 -def test_flatten_dict(tmpdir): +def test_flatten_dict(): d = {"1": 1, "_": {"2": 2, "_": {"3": 3, "4": 4}}} expected = {"1": 1, "2": 2, "3": 3, "4": 4}