25 changes: 24 additions & 1 deletion docs/source-pytorch/advanced/model_parallel.rst
@@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas


native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4)
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)


Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about the native support.

----


Activation Checkpointing
========================

Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in
selected layers. The tradeoff is that the computational cost of the backward pass increases, as the dropped activations
need to be recomputed.

Enable checkpointing on large layers (like Transformer blocks) by providing the layer class to the strategy:

.. code-block:: python

from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

fsdp = DDPFullyShardedNativeStrategy(
activation_checkpointing=MyTransformerBlock, # or pass a list with multiple types
)
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
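
``MyTransformerBlock`` above stands in for your own module class; any ``torch.nn.Module`` subclass can be passed. A
minimal sketch of what such a block might look like (the layer names and sizes below are illustrative only):

.. code-block:: python

    import torch.nn as nn


    class MyTransformerBlock(nn.Module):
        def __init__(self, dim: int = 512, num_heads: int = 8):
            super().__init__()
            self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.feed_forward = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)

        def forward(self, x):
            # when this class is passed to ``activation_checkpointing``, the activations computed inside
            # the block are dropped after the forward pass and recomputed during the backward pass
            attn_out, _ = self.attention(x, x, x)
            x = self.norm1(x + attn_out)
            return self.norm2(x + self.feed_forward(x))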


----


.. _deepspeed_advanced:

*********
45 changes: 39 additions & 6 deletions src/lightning_lite/strategies/fsdp.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union

import torch
from torch import Tensor
@@ -35,7 +36,7 @@
)
from lightning_lite.utilities.distributed import group as _group
from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from lightning_lite.utilities.rank_zero import rank_zero_only
from lightning_lite.utilities.seed import reset_seed

@@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
computation overlapping. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer class or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention and the feed-forward network).
Enabling this can free up a significant amount of memory at the cost of speed, since activations in
these layers need to be recomputed during backpropagation.
\**kwargs: Optional keyword arguments passed to the FSDP context manager which will configure the FSDP class
when wrapping modules.
"""
@@ -94,6 +99,7 @@ def __init__(
cpu_offload: Optional["CPUOffload"] = None,
backward_prefetch: Optional["BackwardPrefetch"] = None,
mixed_precision: Optional["MixedPrecision"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
**kwargs: Any,
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
@@ -112,6 +118,13 @@ def __init__(
self._backward_sync_control = _FSDPBackwardSyncControl()
self._ddp_kwargs = kwargs

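# validate the torch version and normalize `activation_checkpointing` to a list of layer classes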
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
activation_checkpointing = activation_checkpointing or []
self._activation_checkpointing = (
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
)

self.cpu_offload = cpu_offload
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
@@ -175,13 +188,12 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

if (
any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
and "auto_wrap_policy" in self._ddp_kwargs
if "auto_wrap_policy" in self._ddp_kwargs and any(
isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
):
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
del self._ddp_kwargs["auto_wrap_policy"]
return FullyShardedDataParallel(
wrapped_module = FullyShardedDataParallel(
module=module,
cpu_offload=self.cpu_offload,
backward_prefetch=self.backward_prefetch,
@@ -190,6 +202,12 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
**self._ddp_kwargs,
)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)

return wrapped_module

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Set up an optimizer for a model wrapped with FSDP.

@@ -291,6 +309,21 @@ def _set_world_ranks(self) -> None:
rank_zero_only.rank = self.cluster_environment.global_rank()


def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)

check_fn = lambda submodule: isinstance(submodule, tuple(layers))
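# wrap matching submodules with the non-reentrant checkpoint implementation (the variant recommended for use with FSDP)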
wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)


class _FSDPBackwardSyncControl(_BackwardSyncControl):
@contextmanager
def no_backward_sync(self, module: Module) -> Generator:
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))


- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))


### Changed

- Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))
31 changes: 24 additions & 7 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -13,14 +13,15 @@
# limitations under the License.
import contextlib
import logging
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, Dict, Generator, List, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp import _optimizer_has_flat_params
from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
from lightning_lite.utilities.distributed import (
_get_default_process_group_backend_for_device,
_init_dist_connection,
@@ -38,7 +39,7 @@
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision=bf16`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer class or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention and the feed-forward network).
Enabling this can free up a significant amount of memory at the cost of speed, since activations in
these layers need to be recomputed during backpropagation.
\**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.

"""
@@ -118,6 +123,7 @@ def __init__(
cpu_offload: Optional[CPUOffload] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
**kwargs: Any,
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
@@ -139,6 +145,12 @@ def __init__(
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False
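# validate the torch version and normalize `activation_checkpointing` to a list of layer classes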
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
activation_checkpointing = activation_checkpointing or []
self._activation_checkpointing = (
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
)
self.kwargs = kwargs

@property
@@ -209,15 +221,14 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
:class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
# If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
assert self.lightning_module is not None
if (
any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
and "auto_wrap_policy" in self.kwargs
if "auto_wrap_policy" in self.kwargs and any(
isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
):
del self.kwargs["auto_wrap_policy"]

log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

return FullyShardedDataParallel(
wrapped_module = FullyShardedDataParallel(
module=model,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
@@ -227,6 +238,12 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
**self.kwargs,
)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)

return wrapped_module

def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None
self.accelerator.setup(trainer)
43 changes: 42 additions & 1 deletion tests/tests_lite/strategies/test_fsdp.py
@@ -13,7 +13,7 @@
# limitations under the License.

from unittest import mock
from unittest.mock import MagicMock, Mock
from unittest.mock import ANY, MagicMock, Mock

import pytest
import torch
@@ -77,3 +77,44 @@ def test_fsdp_no_backward_sync():
pass

module.no_sync.assert_called_once()


@RunIf(min_torch="1.12")
@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
def test_fsdp_activation_checkpointing_support():
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
FSDPStrategy(activation_checkpointing=Mock())


@RunIf(min_torch="1.13")
def test_fsdp_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""

class Block1(nn.Linear):
pass

class Block2(nn.Linear):
pass

class Model(nn.Module):
def __init__(self):
super().__init__()
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)

strategy = FSDPStrategy(activation_checkpointing=Block1)
assert strategy._activation_checkpointing == [Block1]

strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert strategy._activation_checkpointing == [Block1, Block2]

strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch(
"torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
) as fsdp_mock, mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as ckpt_mock:
strategy.setup_module(Model())
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
7 changes: 5 additions & 2 deletions tests/tests_lite/strategies/test_fsdp_integration.py
@@ -72,7 +72,7 @@ def _assert_save_equality(lite, model, ckpt_path):

# model parameters are identical after loading
for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()):
assert torch.equal(current_param.float().cpu(), loaded_param.cpu())
assert torch.allclose(current_param.float().cpu(), loaded_param.cpu())


def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:
@@ -84,7 +84,10 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_fsdp_train_save_load(manual_wrapping, precision):
"""Test FSDP training, saving and loading with different wrapping and precision settings."""
strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
strategy = FSDPStrategy(
auto_wrap_policy=_custom_auto_wrap_policy,
activation_checkpointing=[torch.nn.Linear],
)
lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision)
lite.launch()

47 changes: 47 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -1,8 +1,11 @@
import os
from typing import Any, Dict, Optional
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch
import torch.nn as nn

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -259,3 +262,47 @@ def configure_optimizers(self):
model = NoFlatParametersModel()
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
trainer.fit(model)


@RunIf(min_torch="1.12")
@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False)
def test_fully_sharded_native_activation_checkpointing_support():
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
DDPFullyShardedNativeStrategy(activation_checkpointing=Mock())


@RunIf(min_torch="1.13")
def test_fully_sharded_native_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""

class Block1(nn.Linear):
pass

class Block2(nn.Linear):
pass

class Model(BoringModel):
def __init__(self):
super().__init__()
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)

strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1)
assert strategy._activation_checkpointing == [Block1]

strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2])
assert strategy._activation_checkpointing == [Block1, Block2]

model = Model()
strategy._parallel_devices = [torch.device("cuda", 0)]
strategy._lightning_module = model
strategy._process_group = Mock()
with mock.patch(
"pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel"
) as fsdp_mock, mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as ckpt_mock:
strategy._setup_model(model)
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)