From 350c369bf04bbc2dabdeab4b21810655c7a49cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 26 Nov 2022 02:11:40 +0100 Subject: [PATCH 01/16] initial --- src/lightning_lite/strategies/fsdp.py | 31 ++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 8053992d18525..95d7606208f8f 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -13,7 +13,8 @@ # limitations under the License. from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union +import functools +from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Type, Union import torch from torch import Tensor @@ -35,7 +36,7 @@ ) from lightning_lite.utilities.distributed import group as _group from lightning_lite.utilities.distributed import ReduceOp -from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13 from lightning_lite.utilities.rank_zero import rank_zero_only from lightning_lite.utilities.seed import reset_seed @@ -78,6 +79,9 @@ class FSDPStrategy(ParallelStrategy, _Sharded): computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + activation_checkpointing: A list of layer classes for which you want to enable activation checkpointing. This is typically + your transformer block (including attention + feed-forward). Enabling this can free up a significant amount of + memory at the cost of speed since activations in these layers need to be recomputed during backprop. \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class when wrapping modules. 
""" @@ -94,6 +98,7 @@ def __init__( cpu_offload: Optional["CPUOffload"] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, + activation_checkpointing: Optional[List[Type[Module]]] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -110,6 +115,7 @@ def __init__( self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout self._backward_sync_control = _FSDPBackwardSyncControl() + self._activation_checkpointing = activation_checkpointing or [] self._ddp_kwargs = kwargs self.cpu_offload = cpu_offload @@ -181,7 +187,7 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": ): # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` del self._ddp_kwargs["auto_wrap_policy"] - return FullyShardedDataParallel( + wrapped_module = FullyShardedDataParallel( module=module, cpu_offload=self.cpu_offload, backward_prefetch=self.backward_prefetch, @@ -189,6 +195,10 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": device_id=self.root_device.index, **self._ddp_kwargs, ) + # activation checkpointing needs to be set up after wrapping the model + if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: + _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + return wrapped_module def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: """Set up an optimizer for a model wrapped with FSDP. @@ -291,6 +301,21 @@ def _set_world_ranks(self) -> None: rank_zero_only.rank = self.cluster_environment.global_rank() +def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper + ) + check_fn = lambda submodule: isinstance(submodule, layers) + wrapper = functools.partial( + checkpoint_wrapper, + offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + apply_activation_checkpointing( + module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn + ) + + class _FSDPBackwardSyncControl(_BackwardSyncControl): @contextmanager def no_backward_sync(self, module: Module) -> Generator: From 19de47437c89e60d85cab6b2df4fa5faa04baecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 26 Nov 2022 02:19:18 +0100 Subject: [PATCH 02/16] input type --- src/lightning_lite/strategies/fsdp.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 95d7606208f8f..76975f1f39d2b 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -79,9 +79,9 @@ class FSDPStrategy(ParallelStrategy, _Sharded): computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. - activation_checkpointing: A list of layer classes for which you want to enable activation checkpointing. This is typically - your transformer block (including attention + feed-forward). 
Enabling this can free up a significant amount of - memory at the cost of speed since activations in these layers need to be recomputed during backprop. + activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation checkpointing. + This is typically your transformer block (including attention + feed-forward). Enabling this can free up a significant + amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class when wrapping modules. """ @@ -98,7 +98,7 @@ def __init__( cpu_offload: Optional["CPUOffload"] = None, backward_prefetch: Optional["BackwardPrefetch"] = None, mixed_precision: Optional["MixedPrecision"] = None, - activation_checkpointing: Optional[List[Type[Module]]] = None, + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -115,7 +115,8 @@ def __init__( self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout self._backward_sync_control = _FSDPBackwardSyncControl() - self._activation_checkpointing = activation_checkpointing or [] + activation_checkpointing = activation_checkpointing or [] + self._activation_checkpointing = list(activation_checkpointing) if not isinstance(activation_checkpointing, list) else activation_checkpointing self._ddp_kwargs = kwargs self.cpu_offload = cpu_offload @@ -305,7 +306,7 @@ def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper ) - check_fn = lambda submodule: isinstance(submodule, layers) + check_fn = lambda submodule: isinstance(submodule, tuple(layers)) wrapper = functools.partial( checkpoint_wrapper, offload_to_cpu=False, From afba10fd9cc2a5e80e3f5ffbdf7a8111bb371f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 26 Nov 2022 02:20:54 +0100 Subject: [PATCH 03/16] input type --- src/lightning_lite/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 76975f1f39d2b..bdbdaa31d2e6c 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -116,7 +116,7 @@ def __init__( self._timeout: Optional[timedelta] = timeout self._backward_sync_control = _FSDPBackwardSyncControl() activation_checkpointing = activation_checkpointing or [] - self._activation_checkpointing = list(activation_checkpointing) if not isinstance(activation_checkpointing, list) else activation_checkpointing + self._activation_checkpointing = [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing self._ddp_kwargs = kwargs self.cpu_offload = cpu_offload From bb0113b357af7fa0d6f6f49db769945225a67b87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Nov 2022 01:33:19 +0000 Subject: [PATCH 04/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/strategies/fsdp.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py 
b/src/lightning_lite/strategies/fsdp.py index bdbdaa31d2e6c..d6b71bd029dde 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from contextlib import contextmanager from datetime import timedelta -import functools -from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Type, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union import torch from torch import Tensor @@ -116,7 +116,9 @@ def __init__( self._timeout: Optional[timedelta] = timeout self._backward_sync_control = _FSDPBackwardSyncControl() activation_checkpointing = activation_checkpointing or [] - self._activation_checkpointing = [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing + self._activation_checkpointing = ( + [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing + ) self._ddp_kwargs = kwargs self.cpu_offload = cpu_offload @@ -304,17 +306,18 @@ def _set_world_ranks(self) -> None: def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None: from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper + apply_activation_checkpointing, + checkpoint_wrapper, + CheckpointImpl, ) + check_fn = lambda submodule: isinstance(submodule, tuple(layers)) wrapper = functools.partial( checkpoint_wrapper, offload_to_cpu=False, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) - apply_activation_checkpointing( - module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn - ) + apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) class _FSDPBackwardSyncControl(_BackwardSyncControl): From 6c89b0f91a9462644c354236c4c8de66573c62fe Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 03:41:11 +0100 Subject: [PATCH 05/16] tests --- src/lightning_lite/strategies/fsdp.py | 12 ++++--- tests/tests_lite/strategies/test_fsdp.py | 41 +++++++++++++++++++++++- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index cd12f306ed8c6..6e506fd6b3683 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -115,11 +115,16 @@ def __init__( self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout self._backward_sync_control = _FSDPBackwardSyncControl() + self._ddp_kwargs = kwargs + + if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: + raise ValueError( + "Activation checkpointing requires torch >= 1.13.0. 
HINT: `pip install -U torch`" + ) activation_checkpointing = activation_checkpointing or [] self._activation_checkpointing = ( [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing ) - self._ddp_kwargs = kwargs self.cpu_offload = cpu_offload self.backward_prefetch = backward_prefetch @@ -184,9 +189,8 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel - if ( - any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()) - and "auto_wrap_policy" in self._ddp_kwargs + if "auto_wrap_policy" in self._ddp_kwargs and any( + isinstance(mod, FullyShardedDataParallel) for mod in module.modules() ): # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` del self._ddp_kwargs["auto_wrap_policy"] diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 8f609d53c253a..535c037b999e4 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import ANY, MagicMock, Mock import pytest import torch @@ -77,3 +77,42 @@ def test_fsdp_no_backward_sync(): pass module.no_sync.assert_called_once() + + +@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False) +def test_fsdp_activation_checkpointing_support(): + """Test that we error out if activation checkpointing requires a newer PyTorch version.""" + with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"): + FSDPStrategy(activation_checkpointing=Mock()) + + +@RunIf(min_torch="1.13") +def test_fsdp_activation_checkpointing(): + """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" + class Block1(nn.Linear): + pass + + class Block2(nn.Linear): + pass + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5)) + self.layer1 = Block2(2, 2) + self.layer2 = nn.Linear(3, 3) + + strategy = FSDPStrategy(activation_checkpointing=Block1) + assert strategy._activation_checkpointing == [Block1] + + strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2]) + assert strategy._activation_checkpointing == [Block1, Block2] + + strategy._parallel_devices = [torch.device("cuda", 0)] + with mock.patch( + "torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel" + ) as fsdp_mock, mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as ckpt_mock: + strategy.setup_module(Model()) + ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) From 9d038094b76d926c22ea86e13c00f2acb816995f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 03:42:10 +0100 Subject: [PATCH 06/16] checkpointing --- tests/tests_lite/strategies/test_fsdp_integration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 052133e265e4c..3e1c5217f5a9e 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -83,7 +83,10 @@ 
def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): """Test FSDP training, saving and loading with different wrapping and precision settings.""" - strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy) + strategy = FSDPStrategy( + auto_wrap_policy=_custom_auto_wrap_policy, + activation_checkpointing=[torch.nn.Linear], + ) lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision) lite.launch() From 10e56759ad2fc8b8192f9297a2e43ff8dba5b031 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Nov 2022 02:42:35 +0000 Subject: [PATCH 07/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning_lite/strategies/fsdp.py | 4 +--- tests/tests_lite/strategies/test_fsdp.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 6e506fd6b3683..0fcd20a6c5803 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -118,9 +118,7 @@ def __init__( self._ddp_kwargs = kwargs if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: - raise ValueError( - "Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`" - ) + raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`") activation_checkpointing = activation_checkpointing or [] self._activation_checkpointing = ( [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 535c037b999e4..55a0f2a0c92e2 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -89,6 +89,7 @@ def test_fsdp_activation_checkpointing_support(): @RunIf(min_torch="1.13") def test_fsdp_activation_checkpointing(): """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" + class Block1(nn.Linear): pass From 3f251b921e089761baca4168a0d9cb331e34b1d7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 04:08:12 +0100 Subject: [PATCH 08/16] fsdp in pl --- src/lightning_lite/strategies/fsdp.py | 9 ++-- .../strategies/fully_sharded_native.py | 31 ++++++++++--- .../test_ddp_fully_sharded_native.py | 46 +++++++++++++++++++ 3 files changed, 76 insertions(+), 10 deletions(-) diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py index 0fcd20a6c5803..0a77164a356cc 100644 --- a/src/lightning_lite/strategies/fsdp.py +++ b/src/lightning_lite/strategies/fsdp.py @@ -79,9 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded): computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. - activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation checkpointing. - This is typically your transformer block (including attention + feed-forward). 
Enabling this can free up a significant - amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. + activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation + checkpointing. This is typically your transformer block (including attention + feed-forward). + Enabling this can free up a significant amount of memory at the cost of speed since activations in + these layers need to be recomputed during backpropagation. \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class when wrapping modules. """ @@ -200,9 +201,11 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": device_id=self.root_device.index, **self._ddp_kwargs, ) + # activation checkpointing needs to be set up after wrapping the model if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + return wrapped_module def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 69110db45507f..38ed803235ece 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -13,14 +13,15 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Type, Union import torch from torch import Tensor +from torch.nn import Module import pytorch_lightning as pl from lightning_lite.plugins import CheckpointIO, ClusterEnvironment -from lightning_lite.strategies.fsdp import _optimizer_has_flat_params +from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing from lightning_lite.utilities.distributed import ( _get_default_process_group_backend_for_device, _init_dist_connection, @@ -38,7 +39,7 @@ from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13 from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy): Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation + checkpointing. This is typically your transformer block (including attention + feed-forward). + Enabling this can free up a significant amount of memory at the cost of speed since activations in + these layers need to be recomputed during backpropagation. \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules. 
""" @@ -118,6 +123,7 @@ def __init__( cpu_offload: Optional[CPUOffload] = None, backward_prefetch: Optional[BackwardPrefetch] = None, mixed_precision: Optional[MixedPrecision] = None, + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -139,6 +145,12 @@ def __init__( self.backward_prefetch = backward_prefetch self.mixed_precision = mixed_precision self._rank_0_will_call_children_scripts: bool = False + if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: + raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`") + activation_checkpointing = activation_checkpointing or [] + self._activation_checkpointing = ( + [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing + ) self.kwargs = kwargs @property @@ -209,15 +221,14 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` assert self.lightning_module is not None - if ( - any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()) - and "auto_wrap_policy" in self.kwargs + if "auto_wrap_policy" in self.kwargs and any( + isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules() ): del self.kwargs["auto_wrap_policy"] log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}") - return FullyShardedDataParallel( + wrapped_module = FullyShardedDataParallel( module=model, process_group=self.process_group, cpu_offload=self.cpu_offload, @@ -227,6 +238,12 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: **self.kwargs, ) + # activation checkpointing needs to be set up after wrapping the model + if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: + _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + + return wrapped_module + def setup(self, trainer: "pl.Trainer") -> None: assert self.accelerator is not None self.accelerator.setup(trainer) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 5bb6b84d9e0f0..f6ad8fd1180cc 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -1,8 +1,11 @@ import os from typing import Any, Dict, Optional +from unittest import mock +from unittest.mock import ANY, Mock import pytest import torch +import torch.nn as nn from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint @@ -259,3 +262,46 @@ def configure_optimizers(self): model = NoFlatParametersModel() with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): trainer.fit(model) + + +@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False) +def test_fully_sharded_native_activation_checkpointing_support(): + """Test that we error out if activation checkpointing requires a newer PyTorch version.""" + with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"): + DDPFullyShardedNativeStrategy(activation_checkpointing=Mock()) + + +@RunIf(min_torch="1.13") +def 
test_fsdp_activation_checkpointing(): + """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" + + class Block1(nn.Linear): + pass + + class Block2(nn.Linear): + pass + + class Model(BoringModel): + def __init__(self): + super().__init__() + self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5)) + self.layer1 = Block2(2, 2) + self.layer2 = nn.Linear(3, 3) + + strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1) + assert strategy._activation_checkpointing == [Block1] + + strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2]) + assert strategy._activation_checkpointing == [Block1, Block2] + + model = Model() + strategy._parallel_devices = [torch.device("cuda", 0)] + strategy._lightning_module = model + strategy._process_group = Mock() + with mock.patch( + "pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel" + ) as fsdp_mock, mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as ckpt_mock: + strategy._setup_model(Model()) + ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) From c9e2e19b2b19eaea5dbe7e97d782ed7d6bddf4f8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 04:09:19 +0100 Subject: [PATCH 09/16] model --- tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index f6ad8fd1180cc..f869ae981caa4 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -303,5 +303,5 @@ def __init__(self): ) as fsdp_mock, mock.patch( "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" ) as ckpt_mock: - strategy._setup_model(Model()) + strategy._setup_model(model) ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) From 1993733d1027d31345bd6880a616bbe27858d26b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 04:09:46 +0100 Subject: [PATCH 10/16] name --- tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index f869ae981caa4..6fdf614621b7f 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -272,7 +272,7 @@ def test_fully_sharded_native_activation_checkpointing_support(): @RunIf(min_torch="1.13") -def test_fsdp_activation_checkpointing(): +def test_fully_sharded_native_activation_checkpointing(): """Test that the FSDP strategy can apply activation checkpointing to the given layers.""" class Block1(nn.Linear): From 44f2971832ba37266e7d3d636598f2323361ae9e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 04:27:41 +0100 Subject: [PATCH 11/16] docs --- .../advanced/model_parallel.rst | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index a9922f4274154..ae9a83a7101d7 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ 
b/docs/source-pytorch/advanced/model_parallel.rst @@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True)) - trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4) + trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4) Check out `this tutorial `__ to learn more about the native support. ---- + +Activation Checkpointing +======================== + +Activation checkpointing reduces GPU memory usage by avoiding the storing of intermediate activation tensors in +selected layers. The tradoff is that computation cost for the backpropagation increases, as the dropped activations +need to be recomputed. + +Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy: + +.. code-block:: python + + from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy + + fsdp = DDPFullyShardedNativeStrategy( + activation_checkpointing=MyTransformerBlock, # or pass a list with multiple types + ) + trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4) + + +---- + + .. _deepspeed_advanced: ********* From 0ea7468a907c7d394086195b11110ba4049cffa3 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 04:30:05 +0100 Subject: [PATCH 12/16] changelog --- src/pytorch_lightning/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index b93c90e3d4c7e..b021207289b7e 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814)) + +- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) + + ### Changed - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347)) From 05b8ad7761e06dd0e99d11cf575e5c62b46fb7f8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 04:31:09 +0100 Subject: [PATCH 13/16] typos --- docs/source-pytorch/advanced/model_parallel.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index ae9a83a7101d7..b184463c1e88d 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -439,8 +439,8 @@ Check out `this tutorial Date: Sat, 26 Nov 2022 11:41:36 +0100 Subject: [PATCH 14/16] all_close --- tests/tests_lite/strategies/test_fsdp_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/strategies/test_fsdp_integration.py b/tests/tests_lite/strategies/test_fsdp_integration.py index 3e1c5217f5a9e..193c8ed8c0dfd 100644 --- a/tests/tests_lite/strategies/test_fsdp_integration.py +++ b/tests/tests_lite/strategies/test_fsdp_integration.py @@ -71,7 +71,7 @@ def _assert_save_equality(lite, model, ckpt_path): # model parameters are identical after loading for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()): - assert torch.equal(current_param.float().cpu(), loaded_param.cpu()) + assert torch.allclose(current_param.float().cpu(), loaded_param.cpu()) def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool: From e8bb5aa6ecfaab233e1063dc8f831d0335dcc3ac Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 26 Nov 2022 11:43:15 +0100 Subject: [PATCH 15/16] fix version check --- tests/tests_lite/strategies/test_fsdp.py | 1 + tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/tests_lite/strategies/test_fsdp.py b/tests/tests_lite/strategies/test_fsdp.py index 55a0f2a0c92e2..62880ea6e5ce9 100644 --- a/tests/tests_lite/strategies/test_fsdp.py +++ b/tests/tests_lite/strategies/test_fsdp.py @@ -79,6 +79,7 @@ def test_fsdp_no_backward_sync(): module.no_sync.assert_called_once() +@RunIf(min_torch="1.12") @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False) def test_fsdp_activation_checkpointing_support(): """Test that we error out if activation checkpointing requires a newer PyTorch version.""" diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py index 6fdf614621b7f..a9b47aad1dca5 100644 --- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py +++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py @@ -264,6 +264,7 @@ def configure_optimizers(self): trainer.fit(model) +@RunIf(min_torch="1.12") @mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False) def test_fully_sharded_native_activation_checkpointing_support(): """Test that we error out if activation checkpointing requires a newer PyTorch version.""" From 5e2cb259e64f18a9c2405f8984da0db5d8b791de Mon Sep 
17 00:00:00 2001
From: awaelchli
Date: Tue, 6 Dec 2022 07:10:45 +0100
Subject: [PATCH 16/16] remove default

---
 src/lightning_lite/strategies/fsdp.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/lightning_lite/strategies/fsdp.py b/src/lightning_lite/strategies/fsdp.py
index 0a77164a356cc..9e19a5b77c359 100644
--- a/src/lightning_lite/strategies/fsdp.py
+++ b/src/lightning_lite/strategies/fsdp.py
@@ -319,7 +319,6 @@ def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers:
     check_fn = lambda submodule: isinstance(submodule, tuple(layers))
     wrapper = functools.partial(
         checkpoint_wrapper,
-        offload_to_cpu=False,
         checkpoint_impl=CheckpointImpl.NO_REENTRANT,
     )
     apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
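
Taken together, the series adds an `activation_checkpointing` argument to both `FSDPStrategy` (Lite) and `DDPFullyShardedNativeStrategy` (PyTorch Lightning): a single layer class or a list of layer classes whose activations are dropped after the forward pass and recomputed during backpropagation. The PL-side usage is shown in the `model_parallel.rst` diff above; a minimal Lite-side sketch follows (the public import paths are assumed from the module layout in the diffs, `TransformerBlock` is a hypothetical placeholder, and torch >= 1.13 is required):

    # Sketch only: import paths assumed from the module layout in this series,
    # TransformerBlock is a stand-in for a real transformer block.
    import torch.nn as nn
    from lightning_lite import LightningLite
    from lightning_lite.strategies import FSDPStrategy


    class TransformerBlock(nn.Module):
        # placeholder for the large block whose activations should not be kept in memory
        def __init__(self, dim: int = 32):
            super().__init__()
            self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

        def forward(self, x):
            return self.ff(x)


    # A single class or a list of classes is accepted; checkpointing is applied right
    # after the model is wrapped in FullyShardedDataParallel inside `setup_module`.
    strategy = FSDPStrategy(activation_checkpointing=TransformerBlock)
    lite = LightningLite(accelerator="cuda", devices=2, strategy=strategy)
    lite.launch()

Passing classes rather than module instances is what lets `_setup_activation_checkpointing` select submodules with an `isinstance` check after FSDP wrapping, which is why the setup happens in `setup_module` / `_setup_model` rather than in `__init__`.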