
Commit 05dbf48

Activation checkpointing in FSDP without boilerplate (#15826)
* initial
* input type
* checkpointing
* fsdp in pl
* all_close

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2992002 commit 05dbf48

File tree

7 files changed: +185 −17 lines changed

docs/source-pytorch/advanced/model_parallel.rst
24 additions, 1 deletion

@@ -428,13 +428,36 @@ You can customize the strategy configuration by adjusting the arguments of :clas
     native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
-    trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", device=4)
+    trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

 Check out `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ to learn more about the native support.

 ----

+
+Activation Checkpointing
+========================
+
+Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in
+selected layers. The tradeoff is that computation cost for the backpropagation increases, as the dropped activations
+need to be recomputed.
+
+Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
+
+.. code-block:: python
+
+    from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
+
+    fsdp = DDPFullyShardedNativeStrategy(
+        activation_checkpointing=MyTransformerBlock,  # or pass a list with multiple types
+    )
+    trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
+
+
+----
+
 .. _deepspeed_advanced:

 *********
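The snippet above uses ``MyTransformerBlock`` purely as a placeholder. As a rough, self-contained sketch (the block definition below is illustrative and not part of this commit), such a layer and the resulting strategy setup might look like:

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy


class MyTransformerBlock(nn.Module):
    # A toy transformer block: self-attention followed by a feed-forward network.
    def __init__(self, dim: int = 128, heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))


# Activations inside every MyTransformerBlock instance are recomputed during backward
# instead of being stored, trading compute for memory.
fsdp = DDPFullyShardedNativeStrategy(activation_checkpointing=MyTransformerBlock)
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)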

src/lightning_lite/strategies/fsdp.py
39 additions, 6 deletions

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union

 import torch
 from torch import Tensor

@@ -35,7 +36,7 @@
 )
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import ReduceOp
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
 from lightning_lite.utilities.rank_zero import rank_zero_only
 from lightning_lite.utilities.seed import reset_seed

@@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
         mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
             if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
+        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
+            checkpointing. This is typically your transformer block (including attention + feed-forward).
+            Enabling this can free up a significant amount of memory at the cost of speed since activations in
+            these layers need to be recomputed during backpropagation.
         \**kwargs: Optional keywoard arguments passed to the FSDP context manager which will configure the FSDP class
             when wrapping modules.
     """

@@ -94,6 +99,7 @@ def __init__(
         cpu_offload: Optional["CPUOffload"] = None,
         backward_prefetch: Optional["BackwardPrefetch"] = None,
         mixed_precision: Optional["MixedPrecision"] = None,
+        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:

@@ -112,6 +118,13 @@ def __init__(
         self._backward_sync_control = _FSDPBackwardSyncControl()
         self._ddp_kwargs = kwargs

+        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
+            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
+        activation_checkpointing = activation_checkpointing or []
+        self._activation_checkpointing = (
+            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
+        )
+
         self.cpu_offload = cpu_offload
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision

@@ -175,13 +188,12 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

-        if (
-            any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
-            and "auto_wrap_policy" in self._ddp_kwargs
+        if "auto_wrap_policy" in self._ddp_kwargs and any(
+            isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
         ):
             # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
             del self._ddp_kwargs["auto_wrap_policy"]
-        return FullyShardedDataParallel(
+        wrapped_module = FullyShardedDataParallel(
             module=module,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,

@@ -190,6 +202,12 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
             **self._ddp_kwargs,
         )

+        # activation checkpointing needs to be set up after wrapping the model
+        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
+            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
+
+        return wrapped_module
+
     def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         """Set up an optimizer for a model wrapped with FSDP.

@@ -291,6 +309,21 @@ def _set_world_ranks(self) -> None:
         rank_zero_only.rank = self.cluster_environment.global_rank()


+def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+        apply_activation_checkpointing,
+        checkpoint_wrapper,
+        CheckpointImpl,
+    )
+
+    check_fn = lambda submodule: isinstance(submodule, tuple(layers))
+    wrapper = functools.partial(
+        checkpoint_wrapper,
+        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+    )
+    apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
+
+
 class _FSDPBackwardSyncControl(_BackwardSyncControl):
     @contextmanager
     def no_backward_sync(self, module: Module) -> Generator:
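For context, the following standalone sketch shows the same wrapping pattern that ``_setup_activation_checkpointing`` applies above, run here on a plain module rather than an FSDP-wrapped one (assumes torch >= 1.13; the ``Block`` class and tensor sizes are made up for illustration):

import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = nn.Sequential(Block(), Block(), nn.Linear(8, 2))

# Wrap only Block submodules in a non-reentrant checkpoint wrapper, modifying the model in place.
wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=wrapper,
    check_fn=lambda submodule: isinstance(submodule, Block),
)

out = model(torch.randn(4, 8, requires_grad=True))
out.sum().backward()  # activations inside each Block are recomputed here instead of being stored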

src/pytorch_lightning/CHANGELOG.md
4 additions, 0 deletions

@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))

+
+- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
+
+
 ### Changed

 - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

src/pytorch_lightning/strategies/fully_sharded_native.py
24 additions, 7 deletions

@@ -13,14 +13,15 @@
 # limitations under the License.
 import contextlib
 import logging
-from typing import Any, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Type, Union

 import torch
 from torch import Tensor
+from torch.nn import Module

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.strategies.fsdp import _optimizer_has_flat_params
+from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
 from lightning_lite.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,

@@ -38,7 +39,7 @@
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -100,6 +101,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
             or BF16 if ``precision=bf16`` unless a config is passed in.
             This is only available in PyTorch 1.12 and later.
+        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
+            checkpointing. This is typically your transformer block (including attention + feed-forward).
+            Enabling this can free up a significant amount of memory at the cost of speed since activations in
+            these layers need to be recomputed during backpropagation.
         \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.

     """

@@ -118,6 +123,7 @@ def __init__(
         cpu_offload: Optional[CPUOffload] = None,
         backward_prefetch: Optional[BackwardPrefetch] = None,
         mixed_precision: Optional[MixedPrecision] = None,
+        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:

@@ -139,6 +145,12 @@ def __init__(
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision
         self._rank_0_will_call_children_scripts: bool = False
+        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
+            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
+        activation_checkpointing = activation_checkpointing or []
+        self._activation_checkpointing = (
+            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
+        )
         self.kwargs = kwargs

     @property

@@ -209,15 +221,14 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
         assert self.lightning_module is not None
-        if (
-            any(isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules())
-            and "auto_wrap_policy" in self.kwargs
+        if "auto_wrap_policy" in self.kwargs and any(
+            isinstance(mod, FullyShardedDataParallel) for mod in self.lightning_module.modules()
         ):
             del self.kwargs["auto_wrap_policy"]

         log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

-        return FullyShardedDataParallel(
+        wrapped_module = FullyShardedDataParallel(
             module=model,
             process_group=self.process_group,
             cpu_offload=self.cpu_offload,

@@ -227,6 +238,12 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
             **self.kwargs,
         )

+        # activation checkpointing needs to be set up after wrapping the model
+        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
+            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
+
+        return wrapped_module
+
     def setup(self, trainer: "pl.Trainer") -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
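As a short sketch of how the new constructor argument behaves (assuming torch >= 1.13 is installed; ``Block`` is an arbitrary example layer class):

import torch.nn as nn
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy


class Block(nn.Linear):
    pass


# A single layer class is normalized to a one-element list internally.
strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block)
assert strategy._activation_checkpointing == [Block]

# Multiple layer classes can be passed as a list.
strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block, nn.Linear])
assert strategy._activation_checkpointing == [Block, nn.Linear]

# On torch < 1.13, passing the argument raises:
# ValueError: Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`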

tests/tests_lite/strategies/test_fsdp.py
42 additions, 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.

 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import ANY, MagicMock, Mock

 import pytest
 import torch

@@ -77,3 +77,44 @@ def test_fsdp_no_backward_sync():
         pass

     module.no_sync.assert_called_once()
+
+
+@RunIf(min_torch="1.12")
+@mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
+def test_fsdp_activation_checkpointing_support():
+    """Test that we error out if activation checkpointing requires a newer PyTorch version."""
+    with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
+        FSDPStrategy(activation_checkpointing=Mock())
+
+
+@RunIf(min_torch="1.13")
+def test_fsdp_activation_checkpointing():
+    """Test that the FSDP strategy can apply activation checkpointing to the given layers."""
+
+    class Block1(nn.Linear):
+        pass
+
+    class Block2(nn.Linear):
+        pass
+
+    class Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
+            self.layer1 = Block2(2, 2)
+            self.layer2 = nn.Linear(3, 3)
+
+    strategy = FSDPStrategy(activation_checkpointing=Block1)
+    assert strategy._activation_checkpointing == [Block1]
+
+    strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
+    assert strategy._activation_checkpointing == [Block1, Block2]
+
+    strategy._parallel_devices = [torch.device("cuda", 0)]
+    with mock.patch(
+        "torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
+    ) as fsdp_mock, mock.patch(
+        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+    ) as ckpt_mock:
+        strategy.setup_module(Model())
+        ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)

tests/tests_lite/strategies/test_fsdp_integration.py
5 additions, 2 deletions

@@ -72,7 +72,7 @@ def _assert_save_equality(lite, model, ckpt_path):

     # model parameters are identical after loading
     for current_param, loaded_param in zip(current_state_dict.values(), loaded_model.state_dict().values()):
-        assert torch.equal(current_param.float().cpu(), loaded_param.cpu())
+        assert torch.allclose(current_param.float().cpu(), loaded_param.cpu())


 def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = int(1e8)) -> bool:

@@ -84,7 +84,10 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par
 @pytest.mark.parametrize("manual_wrapping", [True, False])
 def test_fsdp_train_save_load(manual_wrapping, precision):
     """Test FSDP training, saving and loading with different wrapping and precision settings."""
-    strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
+    strategy = FSDPStrategy(
+        auto_wrap_policy=_custom_auto_wrap_policy,
+        activation_checkpointing=[torch.nn.Linear],
+    )
     lite = LightningLite(accelerator="cuda", strategy=strategy, devices=2, precision=precision)
     lite.launch()

tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
47 additions, 0 deletions

@@ -1,8 +1,11 @@
 import os
 from typing import Any, Dict, Optional
+from unittest import mock
+from unittest.mock import ANY, Mock

 import pytest
 import torch
+import torch.nn as nn

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint

@@ -259,3 +262,47 @@ def configure_optimizers(self):
     model = NoFlatParametersModel()
     with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
         trainer.fit(model)
+
+
+@RunIf(min_torch="1.12")
+@mock.patch("pytorch_lightning.strategies.fully_sharded_native._TORCH_GREATER_EQUAL_1_13", False)
+def test_fully_sharded_native_activation_checkpointing_support():
+    """Test that we error out if activation checkpointing requires a newer PyTorch version."""
+    with pytest.raises(ValueError, match="Activation checkpointing requires torch >= 1.13.0"):
+        DDPFullyShardedNativeStrategy(activation_checkpointing=Mock())
+
+
+@RunIf(min_torch="1.13")
+def test_fully_sharded_native_activation_checkpointing():
+    """Test that the FSDP strategy can apply activation checkpointing to the given layers."""
+
+    class Block1(nn.Linear):
+        pass
+
+    class Block2(nn.Linear):
+        pass
+
+    class Model(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
+            self.layer1 = Block2(2, 2)
+            self.layer2 = nn.Linear(3, 3)
+
+    strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=Block1)
+    assert strategy._activation_checkpointing == [Block1]
+
+    strategy = DDPFullyShardedNativeStrategy(activation_checkpointing=[Block1, Block2])
+    assert strategy._activation_checkpointing == [Block1, Block2]
+
+    model = Model()
+    strategy._parallel_devices = [torch.device("cuda", 0)]
+    strategy._lightning_module = model
+    strategy._process_group = Mock()
+    with mock.patch(
+        "pytorch_lightning.strategies.fully_sharded_native.FullyShardedDataParallel"
+    ) as fsdp_mock, mock.patch(
+        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
+    ) as ckpt_mock:
+        strategy._setup_model(model)
+        ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
