Commit 0f4f809

Deprecate the FairScale integration (#16353)
1 parent fce54a4 commit 0f4f809

20 files changed: +251 / -387 lines changed

docs/source-pytorch/advanced/model_parallel.rst

Lines changed: 8 additions & 164 deletions
@@ -6,8 +6,6 @@ Train 1 trillion+ parameter models
 
 When training large models, fitting larger batch sizes, or trying to increase throughput using multi-GPU compute, Lightning provides advanced optimized distributed training strategies to support these cases and offer substantial improvements in memory usage.
 
-In many cases these strategies are some flavour of model parallelism however we only introduce concepts at a high level to get you started. Refer to the `FairScale documentation <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html>`_ for more information about model parallelism.
-
 Note that some of the extreme memory saving configurations will affect the speed of training. This Speed/Memory trade-off in most cases can be adjusted.
 
 Some of these memory-efficient strategies rely on offloading onto other forms of memory, such as CPU RAM or NVMe. This means you can even see memory benefits on a **single GPU**, using a strategy such as :ref:`deepspeed-zero-stage-3-offload`.
@@ -40,7 +38,7 @@ Overall:
 
 * When **fine-tuning** a model, use advanced memory efficient strategies such as :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute
 * When **pre-training** a model, use simpler optimizations such :ref:`sharded-training`, :ref:`deepspeed-zero-stage-2` or :ref:`fully-sharded-training`, scaling the number of GPUs to reach larger parameter sizes
-* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` or :ref:`fairscale-activation-checkpointing` as the throughput degradation is not significant
+* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` as the throughput degradation is not significant
 
 For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-2` without having to take a performance hit with more advanced optimized multi-gpu strategy.
 
@@ -153,11 +151,10 @@ Here's an example of changing the placement policy to "cpu".
 
 .. _sharded-training:
 
-**************************
-FairScale Sharded Training
-**************************
+****************
+Sharded Training
+****************
 
-Lightning integration of optimizer sharded training provided by `FairScale <https://github.com/facebookresearch/fairscale>`_.
 The technique can be found within `DeepSpeed ZeRO <https://arxiv.org/abs/1910.02054>`_ and
 `ZeRO-2 <https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/>`_,
 however the implementation is built from the ground up to be PyTorch compatible and standalone.
@@ -171,178 +168,25 @@ these benefits in multi-GPU setups are almost free and throughput scales well wi
 
 It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models are beneficial (500M+ parameter models).
 A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful.
-Use :ref:`fairscale-activation-checkpointing` to see even more benefit at the cost of some throughput.
-
-To use Sharded Training, you need to first install FairScale using the command below.
-
-.. code-block:: bash
-
-    pip install fairscale
-
 
 .. code-block:: python
 
     # train using Sharded DDP
     trainer = Trainer(strategy="ddp_sharded")
 
-Sharded Training can work across all DDP variants by adding the additional ``--strategy ddp_sharded`` flag via command line using a PyTorch Lightning script.
-
 Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
 
 ----
 
 .. _fully-sharded-training:
 
-FairScale Fully Sharded Training
-================================
-
-.. warning::
-    FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue <https://github.com/Lightning-AI/lightning/issues>`_ if you run into any problems.
-
-`Fully Sharded <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`_ shards optimizer state, gradients, and parameters across data parallel workers. This allows you to fit much larger models onto multiple GPUs into memory.
-
-Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort.
-
-Shard Parameters to Reach 10+ Billion Parameters
-------------------------------------------------
-
-To reach larger parameter sizes and to be memory efficient, we have to shard parameters. There are various ways to enable this.
-
-.. note::
-    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
-    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
-    This is a limitation of Fully Sharded Training that will be resolved in the future.
-
-Enabling Module Sharding for Maximum Memory Efficiency
-------------------------------------------------------
-
-Auto Wrapping
-^^^^^^^^^^^^^
-
-Model layers should be wrapped in FSDP in a nested way to save peak memory and enable communication and computation overlapping. The
-simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
-have to ``wrap`` layers manually as in the case of manual wrapping.
-
-.. note::
-    While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else
-    PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
-    ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future.
-
-.. code-block:: python
-
-    class MyModel(BoringModel):
-        def configure_optimizers(self):
-            return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-2)
-
-
-    model = MyModel()
-    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
-    trainer.fit(model)
-
-
-Manual Wrapping
-^^^^^^^^^^^^^^^
-
-Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
-parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
-
-When not using Fully Sharded Training these wrap functions are a no-op. That means once the changes have been made, there is no need to remove the changes for other strategies.
-
-``auto_wrap`` recursively wraps :class:`~torch.nn.Module` within the ``LightningModule`` with nested Fully Sharded Wrappers,
-signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (information :class:`here <fairscale.nn.fsdp>`).
-
-``auto_wrap`` can have varying levels of success based on the complexity of your model. **Auto Wrap does not support models with shared parameters**.
-
-``wrap`` simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.
-
-Here's an example using both ``wrap`` and ``auto_wrap`` to create your model:
-
-.. code-block:: python
-
-    import torch
-    import torch.nn as nn
-    import pytorch_lightning as pl
-    from pytorch_lightning import Trainer
-    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap
-
-
-    class MyModel(pl.LightningModule):
-        def __init__(self):
-            super().__init__()
-            self.linear_layer = nn.Linear(32, 32)
-            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
-            self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
-
-        def configure_sharded_model(self):
-            # modules are sharded across processes
-            # as soon as they are wrapped with `wrap` or `auto_wrap`.
-            # During the forward/backward passes, weights get synced across processes
-            # and de-allocated once computation is complete, saving memory.
-
-            # Wraps the layer in a Fully Sharded Wrapper automatically
-            linear_layer = wrap(self.linear_layer)
-
-            # Wraps the module recursively
-            # based on a minimum number of parameters (default 100M parameters)
-            block = auto_wrap(self.block)
-
-            # For best memory efficiency,
-            # add FairScale activation checkpointing
-            final_block = auto_wrap(checkpoint_wrapper(self.final_block))
-            self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)
-
-        def configure_optimizers(self):
-            return torch.optim.AdamW(self.model.parameters(), lr=1e-2)
-
-
-    model = MyModel()
-    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)
-    trainer.fit(model)
-
-    trainer.test()
-    trainer.predict()
-
-----
-
-.. _fairscale-activation-checkpointing:
-
-Activation Checkpointing
-------------------------
-
-Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations.
-
-FairScale's checkpointing wrapper also handles batch norm layers correctly, unlike the PyTorch implementation, ensuring stats are tracked correctly due to the multiple forward passes.
-
-This saves memory when training larger models, however it requires wrapping modules you'd like to use activation checkpointing on. See :class:`here <fairscale.nn.checkpoint.checkpoint_wrapper>` for more information.
-
-.. warning::
-
-    Do not wrap the entire model with activation checkpointing. This is not the intended use of activation checkpointing, and will lead to failures as seen in `this discussion <https://github.com/Lightning-AI/lightning/discussions/9144>`_.
-
-.. code-block:: python
-
-    from pytorch_lightning import Trainer
-    from fairscale.nn import checkpoint_wrapper
-
-
-    class MyModel(pl.LightningModule):
-        def __init__(self):
-            super().__init__()
-            # Wrap layers using checkpoint_wrapper
-            self.block_1 = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
-            self.block_2 = nn.Linear(32, 2)
-
-----
-
-.. _fully-sharded-native-training:
-
-******************************
-PyTorch Fully Sharded Training
-******************************
+**********************
+Fully Sharded Training
+**********************
 
 PyTorch has it's own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_ which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
 It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ but it is recommended to use it with PyTorch v1.12 or more and that's what
-Lightning supports. The API is pretty similar to that of FairScale.
+Lightning supports.
 
 
 Auto Wrapping
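
For users on the deprecated FairScale-backed strategies, the deprecation messages added in this commit point to `Trainer(strategy='fsdp_native')` as the replacement. A minimal migration sketch (not part of the commit; it assumes a multi-GPU machine, PyTorch >= 1.12, and the v1.9 Trainer API, with a toy dataset standing in for real data):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning import Trainer


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.block(x), y)

    def configure_optimizers(self):
        # With auto-wrapped FSDP, take parameters from the wrapped trainer.model,
        # not from self.parameters() (see the note retained in the docs above).
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-2)


train_data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)

# Deprecated, FairScale-backed: Trainer(strategy="ddp_sharded") or Trainer(strategy="fsdp")
# Replacement suggested by the new deprecation message:
trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)
trainer.fit(LitModel(), train_data)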

docs/source-pytorch/extensions/strategy.rst

Lines changed: 1 addition & 10 deletions
@@ -80,16 +80,7 @@ The below table lists all relevant strategies available in Lightning with their
      - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. <https://www.colossalai.org/>`__
    * - fsdp_native
      - :class:`~pytorch_lightning.strategies.DDPFullyShardedNativeStrategy`
-     - Strategy for Fully Sharded Data Parallel provided by PyTorch. :ref:`Learn more. <advanced/model_parallel:PyTorch Fully Sharded Training>`
-   * - fsdp
-     - :class:`~pytorch_lightning.strategies.DDPFullyShardedStrategy`
-     - Strategy for Fully Sharded Data Parallel provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Fully Sharded Training>`
-   * - ddp_sharded
-     - :class:`~pytorch_lightning.strategies.DDPShardedStrategy`
-     - Optimizer and gradient sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
-   * - ddp_sharded_spawn
-     - :class:`~pytorch_lightning.strategies.DDPSpawnShardedStrategy`
-     - Optimizer sharded training provided by FairScale. :ref:`Learn more. <advanced/model_parallel:FairScale Sharded Training>`
+     - Strategy for Fully Sharded Data Parallel. :ref:`Learn more. <advanced/model_parallel:Fully Sharded Training>`
    * - ddp_spawn
      - :class:`~pytorch_lightning.strategies.DDPSpawnStrategy`
      - Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
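
As the updated table shows, only the native FSDP entry remains; it can be selected either by its string alias or by passing the strategy class directly. A small illustrative sketch (assuming a 2-GPU machine and the v1.9 API; not taken from the docs in this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# String alias listed in the table above
trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp_native")

# Equivalent: pass the strategy object, which is where strategy-specific arguments go
trainer = Trainer(accelerator="gpu", devices=2, strategy=DDPFullyShardedNativeStrategy())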

src/pytorch_lightning/CHANGELOG.md

Lines changed: 8 additions & 1 deletion
@@ -51,12 +51,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Deprecated the `Trainer.amp_backend` property
     * Deprecated the `Trainer(amp_level=...)` argument
     * Deprecated the `pytorch_lightning.plugins.ApexMixedPrecisionPlugin` class
-    * Deprecates the `pytorch_lightning.utilities.enum.sAMPType` enum
+    * Deprecates the `pytorch_lightning.utilities.enums.AMPType` enum
     * Deprecates the `DeepSpeedPrecisionPlugin(amp_type=..., amp_level=...)` arguments
 - `horovod` deprecation ([#16141](https://github.com/PyTorchLightning/pytorch-lightning/pull/16141))
     * Deprecated `Trainer(strategy="horovod")`
     * Deprecated the `HorovodStrategy` class
 - Deprecated `pytorch_lightning.lite.LightningLite` in favor of `lightning.fabric.Fabric` ([#16314](https://github.com/Lightning-AI/lightning/pull/16314))
+- `FairScale` deprecation (in favor of PyTorch's FSDP implementation) ([#16353](https://github.com/PyTorchLightning/pytorch-lightning/pull/16353))
+    * Deprecated the `pytorch_lightning.overrides.fairscale.LightningShardedDataParallel` class
+    * Deprecated the `pytorch_lightning.plugins.precision.fully_sharded_native_amp.FullyShardedNativeMixedPrecisionPlugin` class
+    * Deprecated the `pytorch_lightning.plugins.precision.sharded_native_amp.ShardedNativeMixedPrecisionPlugin` class
+    * Deprecated the `pytorch_lightning.strategies.fully_sharded.DDPFullyShardedStrategy` class
+    * Deprecated the `pytorch_lightning.strategies.sharded.DDPShardedStrategy` class
+    * Deprecated the `pytorch_lightning.strategies.sharded_spawn.DDPSpawnShardedStrategy` class
 
 
 ### Removed

src/pytorch_lightning/overrides/fairscale.py

Lines changed: 7 additions & 0 deletions
@@ -41,6 +41,13 @@ def __init__(
         forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
         pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
     ) -> None:
+        rank_zero_deprecation(
+            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
+            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
+            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
+            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
+            " the native version by default."
+        )
         self._validate_init_arguments(pl_module, forward_module)
         super().__init__(forward_module=(pl_module or forward_module))
 
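
The same message is added to each deprecated entry point in this commit. Users who cannot migrate right away could temporarily silence it with a standard warnings filter while moving to `fsdp_native` (a stop-gap sketch, not something this commit provides):

import warnings

# Matches the start of the deprecation message emitted by the FairScale-backed classes.
warnings.filterwarnings(
    "ignore",
    message="PyTorch Lightning's sharded implementation using FairScale has been deprecated",
)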

src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py

Lines changed: 11 additions & 0 deletions
@@ -15,11 +15,22 @@
 
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
     """Native AMP for Fully Sharded Training."""
 
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        rank_zero_deprecation(
+            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
+            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
+            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
+            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
+            " the native version by default."
+        )
+        super().__init__(*args, **kwargs)
+
     def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html
         # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect

src/pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -32,6 +33,13 @@ class ShardedNativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     def __init__(
         self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None
     ) -> None:
+        rank_zero_deprecation(
+            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
+            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
+            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
+            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
+            " the native version by default."
+        )
         if not _FAIRSCALE_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for sharded AMP but you have not installed it."

src/pytorch_lightning/strategies/fully_sharded.py

Lines changed: 8 additions & 1 deletion
@@ -28,6 +28,7 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _FAIRSCALE_AVAILABLE:
@@ -117,7 +118,13 @@ def __init__(
                 If ``False``, this will default to ``compute_device``.
                 (Default: True).
         """
-
+        rank_zero_deprecation(
+            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
+            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
+            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
+            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
+            " the native version by default."
+        )
         super().__init__(
             accelerator=accelerator,
             parallel_devices=parallel_devices,
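
The tests for these deprecations are not shown in this excerpt, but an assertion along the following lines would exercise the new warning (a sketch assuming `fairscale` is installed and that the strategy can be constructed with default arguments):

import pytest

from pytorch_lightning.strategies import DDPFullyShardedStrategy


def test_ddp_fully_sharded_strategy_is_deprecated():
    # rank_zero_deprecation emits a DeprecationWarning subclass,
    # which pytest.deprecated_call catches.
    with pytest.deprecated_call(match="deprecated in v1.9.0"):
        DDPFullyShardedStrategy()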
