50 changes: 34 additions & 16 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -146,6 +146,31 @@ have to ``wrap`` layers manually as in the case of manual wrapping.
trainer.fit(model)


You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and passing it to the ``strategy`` argument of the ``Trainer``.

.. code-block:: python

    import functools

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    # equivalent to passing `"fsdp_cpu_offload"`
    fsdp = FSDPStrategy(cpu_offload=True)
    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)

    # configure the wrapping condition (`MyTransformerBlock` is your own layer class)
    if torch.__version__ >= "2.1":
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        my_policy = ModuleWrapPolicy({MyTransformerBlock})
    else:
        from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

        my_policy = functools.partial(
            lambda_auto_wrap_policy, lambda_fn=lambda module: isinstance(module, torch.nn.Linear)
        )

    fsdp = FSDPStrategy(auto_wrap_policy=my_policy)
    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)


Read more `here <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/#auto-wrapping>`__.
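
If you do not need to customize anything, the registered string aliases can be passed to ``strategy`` directly. A minimal sketch (assuming a multi-GPU machine) using the ``"fsdp"`` and ``"fsdp_cpu_offload"`` aliases:

.. code-block:: python

    from lightning.pytorch import Trainer

    # shorthand for FSDPStrategy() with default settings
    trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=4)

    # shorthand for FSDPStrategy(cpu_offload=True)
    trainer = Trainer(strategy="fsdp_cpu_offload", accelerator="gpu", devices=4)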


@@ -198,20 +223,6 @@ Here's an example that uses ``wrap`` to create your model:
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
trainer.fit(model)


You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and pass that to the ``strategy`` argument inside the ``Trainer``.

.. code-block:: python

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy


fsdp = FSDPStrategy(cpu_offload=True)
# equivalent to passing `"fsdp_cpu_offload"`
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)


Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about it.

----
@@ -224,13 +235,20 @@ Activation checkpointing reduces GPU memory usage by avoiding the storage of int
selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations
need to be recomputed.

Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
Enable checkpointing on large layers (like Transformers) by providing a policy:

.. code-block:: python

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import FSDPStrategy

    # `MyTransformerBlock` is your own transformer layer class
    if torch.__version__ >= "2.1":
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        my_policy = ModuleWrapPolicy({MyTransformerBlock})
        fsdp = FSDPStrategy(activation_checkpointing_policy=my_policy)
    else:
        fsdp = FSDPStrategy(activation_checkpointing=MyTransformerBlock)  # or pass a list with multiple types

    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)
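
The Fabric version of the strategy accepts the same ``activation_checkpointing_policy`` argument (see the ``lightning.fabric`` changelog entry below). A minimal sketch of the equivalent Fabric setup, assuming torch >= 2.1 and your own ``MyTransformerBlock`` class:

.. code-block:: python

    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy

    my_policy = ModuleWrapPolicy({MyTransformerBlock})
    fsdp = FSDPStrategy(activation_checkpointing_policy=my_policy)
    fabric = Fabric(accelerator="gpu", devices=4, strategy=fsdp)
    fabric.launch()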


3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -85,6 +85,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for all half-precision modes in FSDP precision plugin ([#17807](https://github.com/Lightning-AI/lightning/pull/17807))


- Added `FSDPStrategy(activation_checkpointing_policy=...)` to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) ([#18045](https://github.com/Lightning-AI/lightning/pull/18045))


- Added a callback for spike-detection ([#18014](https://github.com/Lightning-AI/lightning/pull/18014))


112 changes: 75 additions & 37 deletions src/lightning/fabric/strategies/fsdp.py
@@ -49,25 +49,25 @@
_TORCH_GREATER_EQUAL_1_12,
_TORCH_GREATER_EQUAL_1_13,
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH

_SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD = False
if _TORCH_GREATER_EQUAL_2_0 and torch.distributed.is_available():
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles
from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle

_SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD = True

if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision

from lightning.fabric.wrappers import _FabricModule

if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.wrap import _FSDPPolicy

_POLICY = Union[Callable[[Module, bool, int], bool], _FSDPPolicy]
else:
_POLICY = Callable[[Module, bool, int], bool] # type: ignore[misc]

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
_METADATA_FILENAME = "meta.pt"

@@ -92,10 +92,13 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
Arguments:
cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
Enabling this can free up a significant amount of memory at the cost of speed since activations in
these layers need to be recomputed during backpropagation.
activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``. A single layer or a list of
layer classes for which you want to enable activation checkpointing. This is typically your transformer
block (including attention + feed-forward).
activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
:class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
cost of speed since activations in these layers need to be recomputed during backpropagation.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -117,6 +120,7 @@ def __init__(
cpu_offload: Union[bool, "CPUOffload", None] = None,
mixed_precision: Optional["MixedPrecision"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
activation_checkpointing_policy: Optional["_POLICY"] = None,
state_dict_type: Literal["full", "sharded"] = "sharded",
**kwargs: Any,
) -> None:
@@ -140,11 +144,8 @@ def __init__(
# Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
self._fsdp_kwargs.setdefault("use_orig_params", True)

if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
activation_checkpointing = activation_checkpointing or []
self._activation_checkpointing = (
[activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
)
self._state_dict_type = state_dict_type
self.cpu_offload = _init_cpu_offload(cpu_offload)
@@ -236,8 +237,8 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
_setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
if _TORCH_GREATER_EQUAL_1_13:
_setup_activation_checkpointing(wrapped_module, self._activation_checkpointing_kwargs)

return wrapped_module

@@ -594,29 +595,58 @@ def _set_world_ranks(self) -> None:
rank_zero_only.rank = self.global_rank


def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
def _activation_checkpointing_kwargs(
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
activation_checkpointing_policy: Optional["_POLICY"] = None,
) -> Dict:
if activation_checkpointing is None and activation_checkpointing_policy is None:
return {}
if activation_checkpointing is not None and activation_checkpointing_policy is not None:
raise ValueError(
"You cannot set both `activation_checkpointing` and `activation_checkpointing_policy`. Use the latter."
)
if activation_checkpointing is not None:
if not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("`activation_checkpointing` requires torch >= 1.13.0. HINT: `pip install -U torch`")
if isinstance(activation_checkpointing, list):
classes = tuple(activation_checkpointing)
else:
classes = (activation_checkpointing,)
if _TORCH_GREATER_EQUAL_2_1:
rank_zero_deprecation(
f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
"`FSDPStrategy(activation_checkpointing_policy=torch.distributed.fsdp.wrap.ModuleWrapPolicy"
f"({set(classes)}))` instead."
)
return {"check_fn": lambda submodule: isinstance(submodule, classes)}
assert activation_checkpointing_policy is not None
if not _TORCH_GREATER_EQUAL_2_1:
raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. HINT: `pip install -U torch`")
return {"auto_wrap_policy": activation_checkpointing_policy}
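
# Illustration (assuming a user-defined `Block` nn.Module class): with torch >= 2.1,
#   _activation_checkpointing_kwargs(activation_checkpointing_policy=ModuleWrapPolicy({Block}))
# returns {"auto_wrap_policy": <the policy>}, while the deprecated form
#   _activation_checkpointing_kwargs(activation_checkpointing=Block)
# returns {"check_fn": <an isinstance check against Block>}.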


def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None:
if not activation_checkpointing_kwargs:
return

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper

if any(isinstance(mod, CheckpointWrapper) for mod in module.modules()):
rank_zero_warn(
"FSDP checkpointing is configured, but the model already contains checkpointed layers."
" Checkpointing will be ignored."
)
return

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
CheckpointWrapper,
)

if any(isinstance(mod, CheckpointWrapper) for mod in module.modules()):
if layers:
rank_zero_warn(
f"FSDP checkpointing for the layers {[layer.__name__ for layer in layers]} is configured, but the model"
" already contains checkpointed layers. Checkpointing will be ignored."
)
# the module is already wrapped with activation checkpointing, avoid wrapping again
return

check_fn = lambda submodule: isinstance(submodule, tuple(layers))
wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, **activation_checkpointing_kwargs)


class _FSDPBackwardSyncControl(_BackwardSyncControl):
@@ -710,7 +740,10 @@ def _apply_optimizers_during_fsdp_backward(
By moving optimizer step invocation into the backward call we can free
gradients earlier and reduce peak memory.
"""
assert _SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles
from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle

apply_lock = threading.Lock()

param_handles = _get_fsdp_handles(module)
@@ -791,6 +824,11 @@ def fsdp_overlap_step_with_backward(
optimizers: Union[Optimizer, Iterable[Optimizer]],
fabric_module: "_FabricModule",
) -> _GeneratorContextManager:
if not _TORCH_GREATER_EQUAL_2_0:
raise NotImplementedError(
"`fsdp_overlap_step_with_backward` requires torch >= 2.0.0. HINT: `pip install -U torch`"
)

from lightning.fabric.wrappers import _FabricModule

assert isinstance(fabric_module, _FabricModule)
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the process group timeout argument `FSDPStrategy(timeout=...)` for the FSDP strategy ([#17274](https://github.com/Lightning-AI/lightning/pull/17274))


- Added `FSDPStrategy(activation_checkpointing_policy=...)` to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) ([#18045](https://github.com/Lightning-AI/lightning/pull/18045))


- Added CLI option `--map-to-cpu` to the checkpoint upgrade script to enable converting GPU checkpoints on a CPU-only machine ([#17527](https://github.com/Lightning-AI/lightning/pull/17527))

