Description & Motivation
The pull request #18045 is about to change the way activation checkpointing gets configured in FSDP. Previously, #15825 enabled a convenient way to configure activation checkpointing without importing boilerplate code from PyTorch FSDP:
fabric = Fabric(strategy=FSDPStrategy(..., activation_checkpointing=MyTransformerBlock))
However, once #18045 lands, it will again be necessary to configure it like so:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({MyTransformerBlock}))
or via a functional policy. This is now more boilerplate code than necessary, since Lightning could easily accept the class or list of classes directly and wrap it internally, which is what this issue proposes.
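For reference, a functional policy under that scheme would look roughly like the following. This is only an illustration of the boilerplate involved: the callable signature mirrors the custom auto-wrap policies in torch.distributed.fsdp.wrap, and whether #18045 accepts exactly this form is an assumption on my part.

import torch.nn as nn

def my_activation_checkpointing_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    # Return True for submodules whose activations should be checkpointed
    return isinstance(module, MyTransformerBlock)

strategy = FSDPStrategy(activation_checkpointing_policy=my_activation_checkpointing_policy)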
Pitch
Part 1
Allow passing the module classes to wrap directly to the argument:
strategy = FSDPStrategy(activation_checkpointing_policy=MyTransformerBlock)
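Under the hood, Lightning could normalize whatever the user passes into a ModuleWrapPolicy itself. A minimal sketch of that conversion (the helper name _convert_to_wrap_policy is hypothetical, not an existing Lightning function):

from typing import Set, Type, Union
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

def _convert_to_wrap_policy(
    value: Union[Type[nn.Module], Set[Type[nn.Module]], ModuleWrapPolicy]
) -> ModuleWrapPolicy:
    # Pass an already-constructed policy through untouched
    if isinstance(value, ModuleWrapPolicy):
        return value
    # A single module class gets promoted to a one-element set
    if isinstance(value, type):
        value = {value}
    return ModuleWrapPolicy(value)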
Part 2
In addition, #18045 also proposes to rename activation_checkpointing to activation_checkpointing_policy, on the grounds that the newly provided functionality is a policy and the name is consistent with the similarly functioning auto_wrap_policy argument. Given my pitch above, and the fact that explicitly constructing a policy would then be optional rather than required, we could go for a simplified and convenient API like this:
strategy = FSDPStrategy(
    auto_wrap={MyTransformerBlock, ...},  # automatically converted to a policy under the hood
    activation_checkpointing_policy={MyTransformerBlock, ...},  # automatically converted to a policy under the hood
)
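For illustration, end-to-end usage with Fabric under this pitch could look like the sketch below. Note that auto_wrap and the plain-set form of activation_checkpointing_policy are the proposal, not the current API, and MyTransformerBlock stands in for a user-defined module class.

import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

class MyTransformerBlock(nn.Module):
    ...

strategy = FSDPStrategy(
    auto_wrap={MyTransformerBlock},                        # wrap each block in its own FSDP unit
    activation_checkpointing_policy={MyTransformerBlock},  # checkpoint activations of each block
)
fabric = Fabric(strategy=strategy, devices=4)
fabric.launch()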
Alternatives
No response
Additional context
I strongly vote for this pitch because the purpose of Lightning is to make these tools accessible to as many users as possible, whether they are distributed training experts or not. Making complex features in PyTorch easily configurable is at the core of Lightning's philosophy :)