diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst
index a6619b3be5658..0fa5fa5b21cf3 100644
--- a/docs/source-pytorch/advanced/model_parallel.rst
+++ b/docs/source-pytorch/advanced/model_parallel.rst
@@ -146,6 +146,31 @@ have to ``wrap`` layers manually as in the case of manual wrapping.
 
     trainer.fit(model)
 
 
+You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and passing the configured instance to the ``strategy`` argument of the ``Trainer``.
+
+.. code-block:: python
+
+    from lightning.pytorch import Trainer
+    from lightning.pytorch.strategies import FSDPStrategy
+
+    # equivalent to passing `"fsdp_cpu_offload"`
+    fsdp = FSDPStrategy(cpu_offload=True)
+    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)
+
+    # configure the wrapping condition
+    if torch.__version__ >= "2.1":
+        from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+        my_policy = ModuleWrapPolicy({MyTransformerBlock})
+    else:
+        from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+        import functools
+
+        my_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda module: isinstance(module, torch.nn.Linear))
+    fsdp = FSDPStrategy(auto_wrap_policy=my_policy)
+    trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)
+
+
 Read more `here `__.
 
 
@@ -198,20 +223,6 @@ Here's an example using that uses ``wrap`` to create your model:
 
     trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
     trainer.fit(model)
 
-
-You can customize the strategy configuration by adjusting the arguments of :class:`~lightning.pytorch.strategies.FSDPStrategy` and pass that to the ``strategy`` argument inside the ``Trainer``.
-
-.. code-block:: python
-
-    from lightning.pytorch import Trainer
-    from lightning.pytorch.strategies import FSDPStrategy
-
-
-    fsdp = FSDPStrategy(cpu_offload=True)
-    # equivalent to passing `"fsdp_cpu_offload"`
-    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
-
-
 Check out `this tutorial `__ to learn more about it.
 
 ----
 
@@ -224,13 +235,20 @@ Activation checkpointing reduces GPU memory usage by avoiding the storage of int
 selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations
 need to be recomputed.
 
-Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
+Enable checkpointing on large layers (like Transformers) by providing a policy:
 
 .. code-block:: python
 
     from lightning.pytorch.strategies import FSDPStrategy
 
-    fsdp = FSDPStrategy(activation_checkpointing=MyTransformerBlock)  # or pass a list with multiple types
+    if torch.__version__ >= "2.1":
+        from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+        my_policy = ModuleWrapPolicy({MyTransformerBlock})
+        fsdp = FSDPStrategy(activation_checkpointing_policy=my_policy)
+    else:
+        fsdp = FSDPStrategy(activation_checkpointing=MyTransformerBlock)  # or pass a list with multiple types
+
     trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
 
 
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index cc95ab5b8b903..ad6f68e2be496 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -85,6 +85,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for all half-precision modes in FSDP precision plugin ([#17807](https://github.com/Lightning-AI/lightning/pull/17807))
 
+- Added `FSDPStrategy(activation_checkpointing_policy=...)` to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) ([#18045](https://github.com/Lightning-AI/lightning/pull/18045))
+
+
 - Added a callback for spike-detection ([#18014](https://github.com/Lightning-AI/lightning/pull/18014))
 
 
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index fd72c4e498464..0f2e07b742741 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -49,25 +49,25 @@
     _TORCH_GREATER_EQUAL_1_12,
     _TORCH_GREATER_EQUAL_1_13,
     _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
 )
 from lightning.fabric.utilities.init import _EmptyInit
-from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
 
-_SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD = False
-if _TORCH_GREATER_EQUAL_2_0 and torch.distributed.is_available():
-    from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
-    from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles
-    from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle
-
-    _SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD = True
-
 if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 
     from lightning.fabric.wrappers import _FabricModule
 
+    if _TORCH_GREATER_EQUAL_2_0:
+        from torch.distributed.fsdp.wrap import _FSDPPolicy
+
+        _POLICY = Union[Callable[[Module, bool, int], bool], _FSDPPolicy]
+    else:
+        _POLICY = Callable[[Module, bool, int], bool]  # type: ignore[misc]
+
 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
 _METADATA_FILENAME = "meta.pt"
 
@@ -92,10 +92,13 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     Arguments:
         cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
         mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
-        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
-            checkpointing. This is typically your transformer block (including attention + feed-forward).
-            Enabling this can free up a significant amount of memory at the cost of speed since activations in
-            these layers need to be recomputed during backpropagation.
+        activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``. A single layer or a list of
+            layer classes for which you want to enable activation checkpointing. This is typically your transformer
+            block (including attention + feed-forward).
+        activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
+            :class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
+            want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
+            cost of speed since activations in these layers need to be recomputed during backpropagation.
         state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file. @@ -117,6 +120,7 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + activation_checkpointing_policy: Optional["_POLICY"] = None, state_dict_type: Literal["full", "sharded"] = "sharded", **kwargs: Any, ) -> None: @@ -140,11 +144,8 @@ def __init__( # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()` self._fsdp_kwargs.setdefault("use_orig_params", True) - if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: - raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`") - activation_checkpointing = activation_checkpointing or [] - self._activation_checkpointing = ( - [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing + self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs( + activation_checkpointing, activation_checkpointing_policy ) self._state_dict_type = state_dict_type self.cpu_offload = _init_cpu_offload(cpu_offload) @@ -236,8 +237,8 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": ) # activation checkpointing needs to be set up after wrapping the model - if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: - _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + if _TORCH_GREATER_EQUAL_1_13: + _setup_activation_checkpointing(wrapped_module, self._activation_checkpointing_kwargs) return wrapped_module @@ -594,7 +595,49 @@ def _set_world_ranks(self) -> None: rank_zero_only.rank = self.global_rank -def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None: +def _activation_checkpointing_kwargs( + activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + activation_checkpointing_policy: Optional["_POLICY"] = None, +) -> Dict: + if activation_checkpointing is None and activation_checkpointing_policy is None: + return {} + if activation_checkpointing is not None and activation_checkpointing_policy is not None: + raise ValueError( + "You cannot set both `activation_checkpointing` and `activation_checkpointing_policy`. Use the latter." + ) + if activation_checkpointing is not None: + if not _TORCH_GREATER_EQUAL_1_13: + raise ValueError("`activation_checkpointing` requires torch >= 1.13.0. HINT: `pip install -U torch`") + if isinstance(activation_checkpointing, list): + classes = tuple(activation_checkpointing) + else: + classes = (activation_checkpointing,) + if _TORCH_GREATER_EQUAL_2_1: + rank_zero_deprecation( + f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use " + "`FSDPStrategy(activation_checkpointing_policy=torch.distributed.fsdp.wrap.ModuleWrapPolicy" + f"({set(classes)}))` instead." + ) + return {"check_fn": lambda submodule: isinstance(submodule, classes)} + assert activation_checkpointing_policy is not None + if not _TORCH_GREATER_EQUAL_2_1: + raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. 
HINT: `pip install -U torch`") + return {"auto_wrap_policy": activation_checkpointing_policy} + + +def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None: + if not activation_checkpointing_kwargs: + return + + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper + + if any(isinstance(mod, CheckpointWrapper) for mod in module.modules()): + rank_zero_warn( + "FSDP checkpointing is configured, but the model already contains checkpointed layers." + " Checkpointing will be ignored." + ) + return + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, checkpoint_wrapper, @@ -602,21 +645,8 @@ def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: CheckpointWrapper, ) - if any(isinstance(mod, CheckpointWrapper) for mod in module.modules()): - if layers: - rank_zero_warn( - f"FSDP checkpointing for the layers {[layer.__name__ for layer in layers]} is configured, but the model" - " already contains checkpointed layers. Checkpointing will be ignored." - ) - # the module is already wrapped with activation checkpointing, avoid wrapping again - return - - check_fn = lambda submodule: isinstance(submodule, tuple(layers)) - wrapper = functools.partial( - checkpoint_wrapper, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) + wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) + apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, **activation_checkpointing_kwargs) class _FSDPBackwardSyncControl(_BackwardSyncControl): @@ -710,7 +740,10 @@ def _apply_optimizers_during_fsdp_backward( By moving optimizer step invocation into the backward call we can free gradients earlier and reduce peak memory. """ - assert _SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD + from torch.distributed.fsdp._common_utils import _get_module_fsdp_state + from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles + from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle + apply_lock = threading.Lock() param_handles = _get_fsdp_handles(module) @@ -791,6 +824,11 @@ def fsdp_overlap_step_with_backward( optimizers: Union[Optimizer, Iterable[Optimizer]], fabric_module: "_FabricModule", ) -> _GeneratorContextManager: + if not _TORCH_GREATER_EQUAL_2_0: + raise NotImplementedError( + "`fsdp_overlap_step_with_backward` requires torch >= 2.0.0. HINT: `pip install -U torch`" + ) + from lightning.fabric.wrappers import _FabricModule assert isinstance(fabric_module, _FabricModule) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 037cba5af19d5..0f0b220a981c7 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 
 - Added the process group timeout argument `FSDPStrategy(timeout=...)` for the FSDP strategy ([#17274](https://github.com/Lightning-AI/lightning/pull/17274))
 
+- Added `FSDPStrategy(activation_checkpointing_policy=...)` to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) ([#18045](https://github.com/Lightning-AI/lightning/pull/18045))
+
+
 - Added CLI option `--map-to-cpu` to the checkpoint upgrade script to enable converting GPU checkpoints on a CPU-only machine ([#17527](https://github.com/Lightning-AI/lightning/pull/17527))
 
 
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 3a73d8252c134..35fb8fa6c7748 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -14,7 +14,7 @@
 import logging
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Mapping, Optional, Type, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Type, TYPE_CHECKING, Union
 
 import torch
 from torch import Tensor
@@ -26,6 +26,7 @@
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.fsdp import (
+    _activation_checkpointing_kwargs,
     _get_full_state_dict_context,
     _init_cpu_offload,
     _optimizer_has_flat_params,
@@ -57,26 +58,15 @@
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
 
-_distributed_available = torch.distributed.is_available()
-_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
-if _fsdp_available:
-    from torch.distributed.fsdp import (
-        CPUOffload,
-        FullStateDictConfig,
-        FullyShardedDataParallel,
-        MixedPrecision,
-        OptimStateKeyType,
-        StateDictType,
-    )
-    from torch.distributed.fsdp.wrap import enable_wrap
-else:
-    FullyShardedDataParallel = None  # type: ignore[misc,assignment]
-    OptimStateKeyType = None  # type: ignore[misc,assignment]
-    MixedPrecision = None  # type: ignore[misc,assignment]
-    CPUOffload = None  # type: ignore[misc,assignment]
-
-if _distributed_available:
-    from torch.distributed.distributed_c10d import _get_default_group
+if TYPE_CHECKING:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
+
+    if _TORCH_GREATER_EQUAL_2_0:
+        from torch.distributed.fsdp.wrap import _FSDPPolicy
+
+        _POLICY = Union[Callable[[Module, bool, int], bool], _FSDPPolicy]
+    else:
+        _POLICY = Callable[[Module, bool, int], bool]  # type: ignore[misc]
 
 log = logging.getLogger(__name__)
 
@@ -101,10 +91,13 @@ class FSDPStrategy(ParallelStrategy):
     Arguments:
         cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
        mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
-        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
-            checkpointing. This is typically your transformer block (including attention + feed-forward).
-            Enabling this can free up a significant amount of memory at the cost of speed since activations in
-            these layers need to be recomputed during backpropagation.
+        activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``. A single layer or a list of
+            layer classes for which you want to enable activation checkpointing. This is typically your transformer
+            block (including attention + feed-forward).
+        activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
+            :class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
+            want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
+            cost of speed since activations in these layers need to be recomputed during backpropagation.
         \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
 
     """
@@ -121,8 +114,9 @@ def __init__(
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
         cpu_offload: Union[bool, "CPUOffload", None] = None,
-        mixed_precision: Optional[MixedPrecision] = None,
+        mixed_precision: Optional["MixedPrecision"] = None,
         activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
+        activation_checkpointing_policy: Optional["_POLICY"] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
@@ -141,17 +135,14 @@ def __init__(
         self._timeout: Optional[timedelta] = timeout
         self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.mixed_precision = mixed_precision
-        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
-            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
-        activation_checkpointing = activation_checkpointing or []
-        self._activation_checkpointing = (
-            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
-        )
         self.kwargs = kwargs
         if _TORCH_GREATER_EQUAL_2_0:
             # Avoids the need for user to reference params in `configure_optimizers` via
             # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
             self.kwargs.setdefault("use_orig_params", True)
+        self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
+            activation_checkpointing, activation_checkpointing_policy
+        )
 
     def lightning_module_state_dict(self) -> Dict[str, Any]:
         """Gathers the full state dict by unsharding all the parameters.
 
         To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an
         empty dict.
""" + from torch.distributed.fsdp import FullyShardedDataParallel + from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType + assert self.model is not None with FullyShardedDataParallel.state_dict_type( @@ -180,6 +174,8 @@ def num_processes(self) -> int: @property def process_group(self) -> Optional[ProcessGroup]: if self._process_group is None: + from torch.distributed.distributed_c10d import _get_default_group + # The strategy should have already initilized process group in setup_environment() self._process_group = _get_default_group() return self._process_group @@ -189,7 +185,7 @@ def process_group_backend(self) -> Optional[str]: return self._process_group_backend @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> Optional["MixedPrecision"]: if self.mixed_precision: return self.mixed_precision plugin = self.precision_plugin @@ -229,9 +225,11 @@ def _configure_launcher(self) -> None: if not self.cluster_environment.creates_processes_externally: self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - def _setup_model(self, model: Module) -> FullyShardedDataParallel: + def _setup_model(self, model: Module) -> "FullyShardedDataParallel": """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module.""" + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + # If model is already wrapped, we need to avoid sending the `auto_wrap_policy` assert self.lightning_module is not None if "auto_wrap_policy" in self.kwargs and any( @@ -251,8 +249,8 @@ def _setup_model(self, model: Module) -> FullyShardedDataParallel: ) # activation checkpointing needs to be set up after wrapping the model - if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing: - _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing) + if _TORCH_GREATER_EQUAL_1_13: + _setup_activation_checkpointing(wrapped_module, self._activation_checkpointing_kwargs) return wrapped_module @@ -318,6 +316,9 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No @contextmanager def model_sharded_context(self) -> Generator[None, None, None]: log.debug(f"{self.__class__.__name__}: entered model_sharded_context.") + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import enable_wrap + with enable_wrap( wrapper_cls=FullyShardedDataParallel, process_group=self.process_group, @@ -395,7 +396,7 @@ def get_registered_strategies(cls) -> List[str]: @classmethod def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: - if not _fsdp_available: + if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): return strategy_registry.register( "fsdp", @@ -413,6 +414,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: cls._registered_strategies.append("fsdp_cpu_offload") def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + from torch.distributed.fsdp import FullyShardedDataParallel, OptimStateKeyType + optimizer_states = checkpoint.get("optimizer_states") # If the optimizer states are not present, we don't need to do anything (backward compatibility) @@ -445,6 +448,8 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: optimizer.load_state_dict(opt_state) def optimizer_state(self, optimizer: 
Optimizer) -> Dict[str, Tensor]: + from torch.distributed.fsdp import FullyShardedDataParallel, OptimStateKeyType + if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index 4de3207e8a146..9f5421b1b7632 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -26,15 +26,12 @@ from lightning_utilities.core.imports import RequirementCache from torch.optim import Adam +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.strategies.fsdp import ( - _FSDPBackwardSyncControl, - _SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD, - fsdp_overlap_step_with_backward, -) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 +from lightning.fabric.strategies.fsdp import _FSDPBackwardSyncControl, fsdp_overlap_step_with_backward +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_1 from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _MyFabricGradNorm @@ -129,12 +126,16 @@ def test_fsdp_no_backward_sync(): @RunIf(min_torch="1.12") -@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False) -def test_fsdp_activation_checkpointing_support(): +def test_fsdp_activation_checkpointing_support(monkeypatch): """Test that we error out if activation checkpointing requires a newer PyTorch version.""" - with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"): + monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_1_13", False) + with pytest.raises(ValueError, match="activation_checkpointing` requires torch >= 1.13.0"): FSDPStrategy(activation_checkpointing=Mock()) + monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False) + with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"): + FSDPStrategy(activation_checkpointing_policy=Mock()) + @RunIf(min_torch="1.13") def test_fsdp_activation_checkpointing(): @@ -153,11 +154,20 @@ def __init__(self): self.layer1 = Block2(2, 2) self.layer2 = nn.Linear(3, 3) - strategy = FSDPStrategy(activation_checkpointing=Block1) - assert strategy._activation_checkpointing == [Block1] + if _TORCH_GREATER_EQUAL_2_1: + from torch.distributed.fsdp.wrap import ModuleWrapPolicy + + strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1})) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2]) - assert strategy._activation_checkpointing == [Block1, Block2] + strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + else: + strategy = FSDPStrategy(activation_checkpointing=Block1) + assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} + + strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2]) + assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} strategy._parallel_devices = [torch.device("cuda", 0)] with mock.patch( @@ -166,7 +176,9 @@ def __init__(self): "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" 
     ) as ckpt_mock:
         strategy.setup_module(Model())
-    ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
+    ckpt_mock.assert_called_with(
+        fsdp_mock(), checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs
+    )
 
 
 @RunIf(min_torch="1.13")
@@ -384,7 +396,6 @@ def __del__(self) -> None:
 
 @pytest.mark.skip(reason="Flaky test")  # See also: https://github.com/Lightning-AI/lightning/pull/17774
 @RunIf(min_torch="2.0.0", min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.skipif(not _SUPPORTS_OPTIMIZER_IN_FSDP_BACKWARD, reason="Not supported in this version of PyTorch")
 @pytest.mark.skipif(not RequirementCache("psutil"), reason="psutil is needed to help prevent deadlocks.")
 @pytest.mark.parametrize(
     "checkpoint",
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 39404bd392f58..51911aadf2c31 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -23,7 +23,11 @@
 from lightning.fabric import Fabric
 from lightning.fabric.plugins import FSDPPrecision
 from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+)
 from lightning.fabric.wrappers import _FabricOptimizer
 from tests_fabric.helpers.models import BoringFabric
 from tests_fabric.helpers.runif import RunIf
@@ -405,7 +409,14 @@ def test_fsdp_save_filter(tmp_path):
 @RunIf(min_torch="1.13", min_cuda_gpus=1)
 def test_fsdp_manual_activation_checkpointing():
     model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
-    strategy = FSDPStrategy(activation_checkpointing=torch.nn.Linear)
+
+    if _TORCH_GREATER_EQUAL_2_1:
+        from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+        strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({torch.nn.Linear}))
+    else:
+        strategy = FSDPStrategy(activation_checkpointing=torch.nn.Linear)
+
     fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
     fabric.launch()
 
@@ -421,7 +432,7 @@ def test_fsdp_manual_activation_checkpointing():
     assert wrappers == {"0", "1"}
 
     # let fabric set up the model, it shouldn't apply activation checkpointing again
-    with pytest.warns(match="Linear'] is configured, but the model already contains checkpointed"):
+    with pytest.warns(match="is configured, but the model already contains checkpointed"):
         model = fabric.setup(model)
 
     wrappers = {name for name, mod in model._forward_module.named_modules() if isinstance(mod, CheckpointWrapper)}
diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
index 2aab97aecd99e..89579feb1f689 100644
--- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
+++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
@@ -1,8 +1,11 @@
 import pytest
+import torch.nn
 
+import lightning.fabric
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.strategies import DDPStrategy
+from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
+from tests_pytorch.helpers.runif import RunIf
 
 
 def test_configure_sharded_model():
@@ -30,3 +33,13 @@ def test_ddp_is_distributed():
     strategy = DDPStrategy()
     with pytest.deprecated_call(match="is deprecated"):
         _ = strategy.is_distributed
+
+
+@RunIf(min_torch="1.13")
+def test_fsdp_activation_checkpointing(monkeypatch):
+    with pytest.raises(ValueError, match="cannot set both `activation_checkpointing"):
+        FSDPStrategy(activation_checkpointing=torch.nn.Linear, activation_checkpointing_policy=lambda *_: True)
+
+    monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", True)
+    with pytest.deprecated_call(match=r"use `FSDPStrategy\(activation_checkpointing_policy"):
+        FSDPStrategy(activation_checkpointing=torch.nn.Linear)
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 5daac8820d98a..778f7ff5fb266 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -196,7 +196,7 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
 
 
 @RunIf(min_torch="1.12")
-def test_invalid_on_cpu(tmpdir):
+def test_invalid_on_cpu(tmpdir, cuda_count_0):
     """Test to ensure that we raise Misconfiguration for FSDP on CPU."""
     with pytest.raises(
         MisconfigurationException,
@@ -411,10 +411,10 @@ def configure_optimizers(self):
 
 
 @RunIf(min_torch="1.12")
-@mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
+@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
 def test_fsdp_activation_checkpointing_support():
     """Test that we error out if activation checkpointing requires a newer PyTorch version."""
-    with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
+    with pytest.raises(ValueError, match="activation_checkpointing` requires torch >= 1.13.0"):
         FSDPStrategy(activation_checkpointing=Mock())
 
@@ -435,21 +435,34 @@ def __init__(self):
             self.layer1 = Block2(2, 2)
             self.layer2 = nn.Linear(3, 3)
 
-    strategy = FSDPStrategy(activation_checkpointing=Block1)
-    assert strategy._activation_checkpointing == [Block1]
+    if _TORCH_GREATER_EQUAL_2_1:
+        from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 
-    strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
-    assert strategy._activation_checkpointing == [Block1, Block2]
+        strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1}))
+        assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
+
+        strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
+        assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
+    else:
+        strategy = FSDPStrategy(activation_checkpointing=Block1)
+        assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
+
+        strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
+        assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
 
     model = Model()
     strategy._parallel_devices = [torch.device("cuda", 0)]
     strategy._lightning_module = model
     strategy._process_group = Mock()
-    with mock.patch("lightning.pytorch.strategies.fsdp.FullyShardedDataParallel") as fsdp_mock, mock.patch(
+    with mock.patch(
+        "torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
+    ) as fsdp_mock, mock.patch(
         "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
     ) as ckpt_mock:
         strategy._setup_model(model)
-    ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
+    ckpt_mock.assert_called_with(
+        fsdp_mock(), checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs
+    )
 
@RunIf(min_torch="1.12")
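
For context on how the argument introduced in this diff is meant to be used, here is a minimal usage sketch with Fabric. Only ``FSDPStrategy(activation_checkpointing_policy=...)`` and ``ModuleWrapPolicy`` come from the change above; the ``Block`` module, the device/world-size settings, and the shapes are hypothetical stand-ins. On torch < 2.1 the deprecated ``activation_checkpointing=[Block]`` argument shown in the docs portion of the diff still applies.

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # requires torch >= 2.1

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import FSDPStrategy


    class Block(nn.Module):  # hypothetical stand-in for a transformer block
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(32, 32)

        def forward(self, x):
            return self.linear(x).relu()


    # checkpoint the activations of every `Block` submodule after FSDP wraps the model
    strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block}))
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)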