@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TYPE_CHECKING, Union
 
 import torch
 from torch import Tensor
@@ -35,7 +36,7 @@
 )
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import ReduceOp
-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
 from lightning_lite.utilities.rank_zero import rank_zero_only
 from lightning_lite.utilities.seed import reset_seed
 
@@ -78,6 +79,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             computation overlapping. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
         mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
             if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
+        activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
+            checkpointing. This is typically your transformer block (including attention + feed-forward).
+            Enabling this can free up a significant amount of memory at the cost of speed since activations in
+            these layers need to be recomputed during backpropagation.
         \**kwargs: Optional keyword arguments passed to the FSDP context manager which will configure the FSDP class
            when wrapping modules.
     """
@@ -94,6 +99,7 @@ def __init__(
         cpu_offload: Optional["CPUOffload"] = None,
         backward_prefetch: Optional["BackwardPrefetch"] = None,
         mixed_precision: Optional["MixedPrecision"] = None,
+        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
@@ -112,6 +118,13 @@ def __init__(
         self._backward_sync_control = _FSDPBackwardSyncControl()
         self._ddp_kwargs = kwargs
 
+        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
+            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
+        activation_checkpointing = activation_checkpointing or []
+        self._activation_checkpointing = (
+            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
+        )
+
         self.cpu_offload = cpu_offload
         self.backward_prefetch = backward_prefetch
         self.mixed_precision = mixed_precision
@@ -175,13 +188,12 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 
-        if (
-            any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
-            and "auto_wrap_policy" in self._ddp_kwargs
+        if "auto_wrap_policy" in self._ddp_kwargs and any(
+            isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
         ):
             # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
             del self._ddp_kwargs["auto_wrap_policy"]
-        return FullyShardedDataParallel(
+        wrapped_module = FullyShardedDataParallel(
             module=module,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
@@ -190,6 +202,12 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
             **self._ddp_kwargs,
         )
 
+        # activation checkpointing needs to be set up after wrapping the model
+        if _TORCH_GREATER_EQUAL_1_13 and self._activation_checkpointing:
+            _setup_activation_checkpointing(module=wrapped_module, layers=self._activation_checkpointing)
+
+        return wrapped_module
+
     def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         """Set up an optimizer for a model wrapped with FSDP.
 
@@ -291,6 +309,21 @@ def _set_world_ranks(self) -> None:
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
 
+def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+        apply_activation_checkpointing,
+        checkpoint_wrapper,
+        CheckpointImpl,
+    )
+
+    check_fn = lambda submodule: isinstance(submodule, tuple(layers))
+    wrapper = functools.partial(
+        checkpoint_wrapper,
+        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+    )
+    apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
+
+
 class _FSDPBackwardSyncControl(_BackwardSyncControl):
     @contextmanager
     def no_backward_sync(self, module: Module) -> Generator:
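To make the new helper's effect concrete, here is a self-contained sketch (not part of the diff) that applies the same `apply_activation_checkpointing` / `checkpoint_wrapper` combination to a plain model so it runs on a single process. `Block` and the toy model are made up, and the private `torch.distributed.algorithms._checkpoint.checkpoint_wrapper` module it imports is the one used by the diff; its contents may shift between torch versions (>= 1.13 assumed).

```python
# Sketch of the mechanism behind `_setup_activation_checkpointing`, applied to a plain model
# (no FSDP wrapping) so it can run standalone. `Block` is a placeholder layer class.
import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)


class Block(nn.Module):
    """Placeholder for the layer class passed via `activation_checkpointing`."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = nn.Sequential(Block(), Block(), nn.Linear(32, 4))

# Wrap every submodule that is an instance of `Block` in a non-reentrant checkpoint wrapper,
# so its activations are recomputed during the backward pass instead of being stored.
wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=wrapper,
    check_fn=lambda submodule: isinstance(submodule, Block),
)

loss = model(torch.randn(8, 32)).sum()
loss.backward()  # activations inside each Block are recomputed here
```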