From 0faa278323dd22ce3a035abc4410ebc0b7ff381b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 11:31:16 +0100 Subject: [PATCH 1/9] simplify enabling cpu offload in fsdp strategy --- src/lightning_lite/strategies/fsdp.py | 21 ++++++++++++------- .../strategies/fully_sharded_native.py | 20 +++++++----------- tests/tests_lite/strategies/test_fsdp.py | 17 +++++++++++++-- .../test_ddp_fully_sharded_native.py | 15 ++++++++++++- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 46a36bf95b763..9ca87dbc145b7 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -68,11 +68,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded): `this tutorial `__ for more information. Arguments: - cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It - can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. + You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device - to work with the optimizer. This API is subject to change. Default is ``None`` in which case there - will be no offloading. + to work with the optimizer. This API is subject to change. Default: no offoading backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows users to enable two different backward prefetching algorithms to help backward communication and computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. @@ -91,7 +90,7 @@ def __init__( precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Optional["CPUOffload"] = None, + cpu_offload: Optional[Union[bool, "CPUOffload"]] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, **kwargs: Any, @@ -112,7 +111,7 @@ def __init__( self._backward_sync_control = _FSDPBackwardSyncControl() self._ddp_kwargs = kwargs - self.cpu_offload = cpu_offload + self.cpu_offload = _init_cpu_offload(cpu_offload) self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision @@ -269,7 +268,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None: "fsdp_full_shard_offload", cls, description="Native FSDP with Full Sharding and CPU Offloading", - cpu_offload=CPUOffload(offload_params=True), + cpu_offload=True, ) def _setup_distributed(self) -> None: @@ -308,6 +307,14 @@ def no_backward_sync(self, module: Module) -> Generator: yield +def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload": + from torch.distributed.fsdp import CPUOffload + + return ( + cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=cpu_offload) + ) + + def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: from torch.distributed.fsdp import FlatParameter diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 69110db45507f..6c48dc09929af 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fsdp import _optimizer_has_flat_params +from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _init_cpu_offload from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, @@ -83,14 +83,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): `this tutorial `__ for more information. Arguments: - cpu_offload: - CPU offloading config. Currently, only parameter and gradient CPU - offload is supported. It can be enabled via passing in - ``cpu_offload=CPUOffload(offload_params=True)``. Note that this - currently implicitly enables gradient offloading to CPU in order for - params and grads to be on same device to work with optimizer. This - API is subject to change. Default is ``None`` in which case there - will be no offloading. + cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. + You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently + implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device + to work with the optimizer. This API is subject to change. Default: no offoading backward_prefetch: This is an experimental feature that is subject to change in the the near future. It allows users to enable two different backward_prefetch @@ -115,7 +111,7 @@ def __init__( checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, process_group_backend: Optional[str] = None, - cpu_offload: Optional[CPUOffload] = None, + cpu_offload: Optional[Union[bool, "CPUOffload"]] = None, backward_prefetch: Optional[BackwardPrefetch] = None, mixed_precision: Optional[MixedPrecision] = None, **kwargs: Any, @@ -135,7 +131,7 @@ def __init__( self._process_group = None self.num_nodes = 1 self._process_group_backend = process_group_backend - self.cpu_offload = cpu_offload + self.cpu_offload = _init_cpu_offload(cpu_offload) self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision self._rank_0_will_call_children_scripts: bool = False @@ -386,6 +382,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None: "fsdp_native_full_shard_offload", cls, description="Native FSDP with Full Sharding and CPU Offloading", - cpu_offload=CPUOffload(offload_params=True), + cpu_offload=True, ) cls._registered_strategies.append("fsdp_native_full_shard_offload") diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 8f609d53c253a..03c458822d651 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -26,7 +26,7 @@ from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision, CPUOffload @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) @@ -36,13 +36,26 @@ def test_fsdp_support(*_): @RunIf(min_torch="1.12") -def test_fsdp_custom_mixed_precision(*_): +def test_fsdp_custom_mixed_precision(): """Test that passing a custom mixed precision config works.""" config = MixedPrecision() strategy = FSDPStrategy(mixed_precision=config) assert strategy.mixed_precision_config == config +@RunIf(min_torch="1.12") +def test_fsdp_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = FSDPStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = FSDPStrategy(cpu_offload=config) + assert strategy.cpu_offload == config + + @RunIf(min_torch="1.12") def test_fsdp_setup_optimizer_validation(): """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.""" diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 5bb6b84d9e0f0..dd04626ca446f 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -14,7 +14,7 @@ from tests_pytorch.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision, CPUOffload from torch.distributed.fsdp.wrap import wrap @@ -259,3 +259,16 @@ def configure_optimizers(self): model = NoFlatParametersModel() with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): trainer.fit(model) + + +@RunIf(min_torch="1.12") +def test_fully_sharded_native_strategy_cpu_offload(): + """Test the different ways cpu offloading can be enabled.""" + # bool + strategy = DDPFullyShardedNativeStrategy(cpu_offload=True) + assert strategy.cpu_offload == CPUOffload(offload_params=True) + + # dataclass + config = CPUOffload() + strategy = DDPFullyShardedNativeStrategy(cpu_offload=config) + assert strategy.cpu_offload == config From e92a9334da63186337963d174c9544e518f768d5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 11:32:22 +0100 Subject: [PATCH 2/9] update docs --- docs/source-pytorch/advanced/model_parallel.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index a9922f4274154..9f5e99d5fdb40 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas from pytorch_lightning import Trainer from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)) + native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True) trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4) From 88f142e22ec5ec6f809770b09ec0f553913defed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Nov 2022 10:35:22 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/strategies/fsdp.py | 4 +--- src/pytorch_lightning/strategies/fully_sharded_native.py | 2 +- tests/tests_lite/strategies/test_fsdp.py | 2 +- .../tests_pytorch/strategies/test_ddp_fully_sharded_native.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 9ca87dbc145b7..8236f92ce9deb 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -310,9 +310,7 @@ def no_backward_sync(self, module: Module) -> Generator: def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload": from torch.distributed.fsdp import CPUOffload - return ( - cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=cpu_offload) - ) + return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=cpu_offload) def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 6c48dc09929af..83ac26683359c 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _init_cpu_offload +from lightning_lite.strategies.fsdp import _init_cpu_offload, _optimizer_has_flat_params from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 03c458822d651..72f86adb43373 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -26,7 +26,7 @@ from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision, CPUOffload + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index dd04626ca446f..b92a2a3c4ab78 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -14,7 +14,7 @@ from tests_pytorch.helpers.runif import RunIf if _TORCH_GREATER_EQUAL_1_12: - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision, CPUOffload + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision from torch.distributed.fsdp.wrap import wrap From 8a5c3bf2a5f9c7e60b4ef00b0c3454fcd2ca45bc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 11:35:52 +0100 Subject: [PATCH 4/9] changelog --- src/pytorch_lightning/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index b93c90e3d4c7e..41a37f41d075d 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814)) + +- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832)) + + ### Changed - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) From 0a83951230ece66de88bf6e2e118e089e9c3057a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 11:39:45 +0100 Subject: [PATCH 5/9] precommit --- src/lightning_lite/strategies/fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 9ca87dbc145b7..55cc8667cf57f 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -257,7 +257,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def register_strategies(cls, strategy_registry: Dict) -> None: if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): return - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload strategy_registry.register( "fsdp", From 4e2d0698ed21e3fd50e048a5ee90003913fa37d0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 11:46:50 +0100 Subject: [PATCH 6/9] fix mypy issue --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 4067407bf68da..150114cf4695b 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -309,7 +309,7 @@ def no_backward_sync(self, module: Module) -> Generator: def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload": from torch.distributed.fsdp import CPUOffload - return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=cpu_offload) + return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: From 78dc79b3b4a61dcb10bc1f08dd616f3c52fb16f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 1 Dec 2022 18:53:59 -0500 Subject: [PATCH 7/9] Update src/lightning_lite/strategies/fsdp.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 150114cf4695b..8cdc89f3c0879 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -90,7 +90,7 @@ def __init__( precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Optional[Union[bool, "CPUOffload"]] = None, + cpu_offload: Union[bool, "CPUOffload", None] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, **kwargs: Any, From fb50fa2367899942390bfbf525a66a8482bf6489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 1 Dec 2022 18:54:07 -0500 Subject: [PATCH 8/9] Update src/pytorch_lightning/strategies/fully_sharded_native.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/pytorch_lightning/strategies/fully_sharded_native.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 83ac26683359c..7650db669fd7e 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -111,7 +111,7 @@ def __init__( checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, process_group_backend: Optional[str] = None, - cpu_offload: Optional[Union[bool, "CPUOffload"]] = None, + cpu_offload: Union[bool, "CPUOffload", None] = None, backward_prefetch: Optional[BackwardPrefetch] = None, mixed_precision: Optional[MixedPrecision] = None, **kwargs: Any, From 58786cdd2cbc930c2659b1953248a16c34e8680b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Dec 2022 02:18:40 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/fully_sharded_native.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index f9824743bfa98..f96b90f7b8b7c 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -21,7 +21,11 @@ import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing, _init_cpu_offload +from lightning_lite.strategies.fsdp import ( + _init_cpu_offload, + _optimizer_has_flat_params, + _setup_activation_checkpointing, +) from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection,