Support custom policies for activation checkpointing with FSDP #18007

@carmocca

Description & Motivation

The current mechanism for activation checkpointing is limited to checking for module types:

FSDPStrategy(activation_checkpointing=torch.nn.Linear)
FSDPStrategy(activation_checkpointing=[torch.nn.Linear, TransformerBlock])

https://github.com/Lightning-AI/lightning/blob/f4240ca42c75ad67b2655351b38830fa0ba82cba/src/lightning/fabric/strategies/fsdp.py#L604-L609
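
For context, the linked code turns that list of types into a check_fn for apply_activation_checkpointing, roughly as below (a simplified sketch, not a verbatim copy; activation_checkpointing and module stand in for the strategy's state):

import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# checkpoint a submodule iff it is an instance of one of the configured types
check_fn = lambda submodule: isinstance(submodule, tuple(activation_checkpointing))
wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)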

PyTorch 2.1 has added support (pytorch/pytorch#102672) for

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing

apply_activation_checkpointing(model, auto_wrap_policy=...)

For example, a policy equivalent to the examples above would be:

from torch.distributed.fsdp.wrap import ModuleWrapPolicy

my_policy = ModuleWrapPolicy({torch.nn.Linear, TransformerBlock})

And the stock FSDP policies (always_wrap_policy, lambda_auto_wrap_policy, size_based_auto_wrap_policy) would work too; see the lambda_auto_wrap_policy sketch below.
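
For example, lambda_auto_wrap_policy can express an arbitrary predicate (a sketch; the size threshold is purely illustrative):

import functools

import torch
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

# checkpoint only the Linear layers that are large enough to be worth recomputing
my_policy = functools.partial(
    lambda_auto_wrap_policy,
    lambda_fn=lambda m: isinstance(m, torch.nn.Linear) and m.weight.numel() > 1_000_000,
)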

This pattern is convenient not only because of the added expressiveness, but also because the policy can then be exactly the same object as the auto_wrap_policy, which is often set together with activation_checkpointing (example in lit-gpt: https://github.com/Lightning-AI/lit-gpt/blob/main/pretrain/redpajama.py#L80-L83)

Pitch

Add an activation_checkpointing_policy argument to FSDPStrategy that takes a valid policy. Only available with torch>=2.1. Example use:

policy = ...

FSDPStrategy(
	auto_wrap_policy=policy,
	activation_checkpointing_policy=policy
)
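
Internally, the new argument could be forwarded to PyTorch unchanged (a minimal sketch of the wiring, assuming torch>=2.1; _setup_activation_checkpointing is a hypothetical helper name):

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

def _setup_activation_checkpointing(module: torch.nn.Module, policy) -> None:
    # with torch>=2.1, apply_activation_checkpointing accepts the same policy
    # objects as FSDP's auto_wrap_policy, so the value can be passed through as-is
    apply_activation_checkpointing(module, auto_wrap_policy=policy)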

Deprecate activation_checkpointing (with no removal version); it would remain supported when run with torch<2.1.

Alternatives

  • Overload the activation_checkpointing argument, allowing it to additionally accept a policy
  • Do nothing: apply_activation_checkpointing(check_fn=...) hasn't been deprecated upstream.

Additional context

A related issue is #16991. That issue proposes overloading auto_wrap_policy based on its value's type.
If activation_checkpointing is deprecated, then that issue should be closed.
On the other hand, if passing a list of types is considered worth keeping and both the policy and list options are supported, that would be in line with #16991.

cc @Borda @tchaton @carmocca @justusschock @awaelchli

Labels: deprecation, fabric, feature, pl, strategy: fsdp