## Description & Motivation
The current mechanism for activation checkpointing is limited to checking for module types:

```python
FSDPStrategy(activation_checkpointing=torch.nn.Linear)
FSDPStrategy(activation_checkpointing=[torch.nn.Linear, TransformerBlock])
```
PyTorch 2.1 has added support (pytorch/pytorch#102672) for:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing

apply_activation_checkpointing(model, auto_wrap_policy=...)
```
For example, an equivalent policy to the examples above would be:

```python
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

my_policy = ModuleWrapPolicy({torch.nn.Linear, TransformerBlock})
```

And the existing FSDP policies would work too: `always_wrap_policy`, `lambda_auto_wrap_policy`, `size_based_auto_wrap_policy`.
This pattern is convenient not only because of the added expressiveness, but also because the policy can be exactly the same object as the `auto_wrap_policy`, which is often set together with `activation_checkpointing` (example in lit-gpt: https://github.com/Lightning-AI/lit-gpt/blob/main/pretrain/redpajama.py#L80-L83).
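To make the semantics concrete, here is a torch-free sketch of what a `ModuleWrapPolicy`-style predicate does: wrap a module if and only if its type is in a given set. This is not PyTorch's implementation; the names `ModuleTypePolicy`, `Linear`, and `TransformerBlock` below are illustrative stand-ins, not a real API.

```python
class ModuleTypePolicy:
    """Callable policy: True for modules whose type is in `module_types`."""

    def __init__(self, module_types):
        self.module_types = tuple(module_types)

    def __call__(self, module) -> bool:
        # Subclass-aware type match, mirroring isinstance-based checks.
        return isinstance(module, self.module_types)


class Linear:  # stand-in for torch.nn.Linear
    pass


class TransformerBlock:  # stand-in for a user-defined block
    pass


policy = ModuleTypePolicy({Linear, TransformerBlock})
print(policy(Linear()))            # True
print(policy(TransformerBlock()))  # True
print(policy(object()))            # False
```

Because such a policy is just a callable, the same object can be handed to both FSDP wrapping and activation checkpointing, which is the crux of the proposal.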
## Pitch
Add an `activation_checkpointing_policy` argument to `FSDPStrategy` that takes a valid policy. Only available with `torch>=2.1`. Example use:

```python
policy = ...
FSDPStrategy(
    auto_wrap_policy=policy,
    activation_checkpointing_policy=policy,
)
```

Deprecate `activation_checkpointing` (with no removal version) when run with `torch<2.1`.
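One possible shape for the version-gated argument handling is sketched below. This is not Lightning's actual code: the helper name `resolve_checkpointing_args`, the tuple-based version comparison, and the message wording are all assumptions, and the deprecation branch reflects one reading of the pitch (the new argument requires `torch>=2.1`; the old one keeps working but warns).

```python
import warnings


def resolve_checkpointing_args(
    torch_version,
    activation_checkpointing=None,
    activation_checkpointing_policy=None,
):
    """Return the value to use, enforcing the torch>=2.1 requirement.

    `torch_version` is a (major, minor) tuple, e.g. (2, 1).
    """
    if activation_checkpointing_policy is not None:
        if torch_version < (2, 1):
            # The policy-based argument relies on the auto_wrap_policy support
            # added in PyTorch 2.1 (pytorch/pytorch#102672).
            raise ValueError(
                "`activation_checkpointing_policy` requires torch>=2.1"
            )
        return activation_checkpointing_policy
    if activation_checkpointing is not None and torch_version >= (2, 1):
        # Deprecated in favor of the policy argument, but kept working
        # with no removal version scheduled.
        warnings.warn(
            "`activation_checkpointing` is deprecated; use"
            " `activation_checkpointing_policy` instead",
            DeprecationWarning,
        )
    return activation_checkpointing
```

Keeping the old argument functional (warning only) avoids breaking users pinned to older PyTorch releases.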
## Alternatives
- Overload the `activation_checkpointing` argument, allowing it to additionally accept a policy.
- Not do it: `apply_activation_checkpointing(check_fn=...)` hasn't been deprecated upstream.
## Additional context
A related issue is #16991, which proposes overloading `auto_wrap_policy` based on its value's type. If `activation_checkpointing` is deprecated, then that issue should be closed. On the other hand, if passing a list of types is considered worth keeping and both the policy and list options are kept, then that would be in line with #16991.