Commit 2debd1c

awaelchli and Borda authored
Simplify enabling CPU offload in FSDP (#15832)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 852089e commit 2debd1c

File tree: 6 files changed (+58, -25 lines)

docs/source-pytorch/advanced/model_parallel.rst

Lines changed: 1 addition & 2 deletions
@@ -424,10 +424,9 @@ You can customize the strategy configuration by adjusting the arguments of :clas

     from pytorch_lightning import Trainer
     from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
-    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload


-    native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))
+    native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True)
     trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

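For reference, a minimal usage sketch of the simplified argument (the explicit `CPUOffload` config remains valid, per the docstrings changed below):

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

# New: a plain bool enables parameter/gradient offloading to CPU.
native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=True)
trainer = pl.Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)

# Still supported: pass the torch CPUOffload config explicitly.
native_fsdp = DDPFullyShardedNativeStrategy(cpu_offload=CPUOffload(offload_params=True))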

src/lightning_lite/strategies/fsdp.py

Lines changed: 12 additions & 8 deletions
@@ -69,11 +69,10 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
        `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

    Arguments:
-        cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
-            can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+        cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
+            You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
            implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
-            to work with the optimizer. This API is subject to change. Default is ``None`` in which case there
-            will be no offloading.
+            to work with the optimizer. This API is subject to change. Default: no offloading.
        backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
            users to enable two different backward prefetching algorithms to help backward communication and
            computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
@@ -96,7 +95,7 @@ def __init__(
        precision: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
-        cpu_offload: Optional["CPUOffload"] = None,
+        cpu_offload: Union[bool, "CPUOffload", None] = None,
        backward_prefetch: Optional["BackwardPrefetch"] = None,
        mixed_precision: Optional["MixedPrecision"] = None,
        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
@@ -125,7 +124,7 @@ def __init__(
            [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
        )

-        self.cpu_offload = cpu_offload
+        self.cpu_offload = _init_cpu_offload(cpu_offload)
        self.backward_prefetch = backward_prefetch
        self.mixed_precision = mixed_precision

@@ -276,7 +275,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
    def register_strategies(cls, strategy_registry: Dict) -> None:
        if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
            return
-        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

        strategy_registry.register(
            "fsdp",
@@ -287,7 +285,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
            "fsdp_full_shard_offload",
            cls,
            description="Native FSDP with Full Sharding and CPU Offloading",
-            cpu_offload=CPUOffload(offload_params=True),
+            cpu_offload=True,
        )

    def _setup_distributed(self) -> None:
@@ -341,6 +339,12 @@ def no_backward_sync(self, module: Module) -> Generator:
            yield


+def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
+    from torch.distributed.fsdp import CPUOffload
+
+    return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload))
+
+
 def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
     from torch.distributed.fsdp import FlatParameter

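A quick sketch of what the new `_init_cpu_offload` helper normalizes to, assuming `CPUOffload` keeps its dataclass (field-wise) equality:

from torch.distributed.fsdp import CPUOffload
from lightning_lite.strategies.fsdp import _init_cpu_offload

# Bools and None are converted to a CPUOffload config.
assert _init_cpu_offload(True) == CPUOffload(offload_params=True)
assert _init_cpu_offload(False) == CPUOffload(offload_params=False)
assert _init_cpu_offload(None) == CPUOffload(offload_params=False)  # default: no offloading

# An existing CPUOffload config is passed through untouched.
config = CPUOffload(offload_params=True)
assert _init_cpu_offload(config) is config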

src/pytorch_lightning/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))


+
+- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))
+
+
 ### Changed

 - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

src/pytorch_lightning/strategies/fully_sharded_native.py

Lines changed: 12 additions & 12 deletions
@@ -21,7 +21,11 @@

 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.strategies.fsdp import _optimizer_has_flat_params, _setup_activation_checkpointing
+from lightning_lite.strategies.fsdp import (
+    _init_cpu_offload,
+    _optimizer_has_flat_params,
+    _setup_activation_checkpointing,
+)
 from lightning_lite.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
@@ -84,14 +88,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
        `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.

    Arguments:
-        cpu_offload:
-            CPU offloading config. Currently, only parameter and gradient CPU
-            offload is supported. It can be enabled via passing in
-            ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
-            currently implicitly enables gradient offloading to CPU in order for
-            params and grads to be on same device to work with optimizer. This
-            API is subject to change. Default is ``None`` in which case there
-            will be no offloading.
+        cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed.
+            You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+            implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device
+            to work with the optimizer. This API is subject to change. Default: no offloading.
        backward_prefetch:
            This is an experimental feature that is subject to change in the
            the near future. It allows users to enable two different backward_prefetch
@@ -120,7 +120,7 @@ def __init__(
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
-        cpu_offload: Optional[CPUOffload] = None,
+        cpu_offload: Union[bool, "CPUOffload", None] = None,
        backward_prefetch: Optional[BackwardPrefetch] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
@@ -141,7 +141,7 @@ def __init__(
        self._process_group = None
        self.num_nodes = 1
        self._process_group_backend = process_group_backend
-        self.cpu_offload = cpu_offload
+        self.cpu_offload = _init_cpu_offload(cpu_offload)
        self.backward_prefetch = backward_prefetch
        self.mixed_precision = mixed_precision
        self._rank_0_will_call_children_scripts: bool = False
@@ -403,6 +403,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
            "fsdp_native_full_shard_offload",
            cls,
            description="Native FSDP with Full Sharding and CPU Offloading",
-            cpu_offload=CPUOffload(offload_params=True),
+            cpu_offload=True,
        )
        cls._registered_strategies.append("fsdp_native_full_shard_offload")
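The registered `fsdp_native_full_shard_offload` alias now uses the same bool flag under the hood; a sketch of selecting it by name, assuming the usual string-based strategy lookup in the Trainer:

import pytorch_lightning as pl

# Resolved through the strategy registry; equivalent to
# DDPFullyShardedNativeStrategy(cpu_offload=True) with full sharding.
trainer = pl.Trainer(strategy="fsdp_native_full_shard_offload", accelerator="gpu", devices=4)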

tests/tests_lite/strategies/test_fsdp.py

Lines changed: 15 additions & 2 deletions
@@ -26,7 +26,7 @@
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12

 if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
+    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision


 @mock.patch("lightning_lite.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)
@@ -36,13 +36,26 @@ def test_fsdp_support(*_):


 @RunIf(min_torch="1.12")
-def test_fsdp_custom_mixed_precision(*_):
+def test_fsdp_custom_mixed_precision():
     """Test that passing a custom mixed precision config works."""
     config = MixedPrecision()
     strategy = FSDPStrategy(mixed_precision=config)
     assert strategy.mixed_precision_config == config


+@RunIf(min_torch="1.12")
+def test_fsdp_cpu_offload():
+    """Test the different ways cpu offloading can be enabled."""
+    # bool
+    strategy = FSDPStrategy(cpu_offload=True)
+    assert strategy.cpu_offload == CPUOffload(offload_params=True)
+
+    # dataclass
+    config = CPUOffload()
+    strategy = FSDPStrategy(cpu_offload=config)
+    assert strategy.cpu_offload == config
+
+
 @RunIf(min_torch="1.12")
 def test_fsdp_setup_optimizer_validation():
     """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""

tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py

Lines changed: 14 additions & 1 deletion
@@ -17,7 +17,7 @@
 from tests_pytorch.helpers.runif import RunIf

 if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
+    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
     from torch.distributed.fsdp.wrap import wrap


@@ -306,3 +306,16 @@ def __init__(self):
        ) as ckpt_mock:
            strategy._setup_model(model)
        ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
+
+
+@RunIf(min_torch="1.12")
+def test_fully_sharded_native_strategy_cpu_offload():
+    """Test the different ways cpu offloading can be enabled."""
+    # bool
+    strategy = DDPFullyShardedNativeStrategy(cpu_offload=True)
+    assert strategy.cpu_offload == CPUOffload(offload_params=True)
+
+    # dataclass
+    config = CPUOffload()
+    strategy = DDPFullyShardedNativeStrategy(cpu_offload=config)
+    assert strategy.cpu_offload == config
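The equality assertions in both new tests rely on `CPUOffload` being a dataclass, so configs with the same fields compare equal; a small sketch of that assumption:

from torch.distributed.fsdp import CPUOffload

# Dataclass field-wise equality: offload_params defaults to False.
assert CPUOffload() == CPUOffload(offload_params=False)
assert CPUOffload(offload_params=True) == CPUOffload(offload_params=True)
assert CPUOffload(offload_params=True) != CPUOffload()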
