When training large models, fitting larger batch sizes, or trying to increase throughput on multi-GPU compute, Lightning provides advanced, optimized distributed training strategies that support these cases and offer substantial improvements in memory usage.
In many cases, these strategies are some flavour of model parallelism; however, we only introduce the concepts at a high level to get you started. Refer to the `FairScale documentation <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html>`_ for more information about model parallelism.
Note that some of the extreme memory-saving configurations will affect the speed of training. In most cases, this speed/memory trade-off can be adjusted.
Some of these memory-efficient strategies rely on offloading onto other forms of memory, such as CPU RAM or NVMe. This means you can even see memory benefits on a **single GPU**, using a strategy such as :ref:`deepspeed-zero-stage-3-offload`.
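
For instance, here is a minimal sketch of enabling ZeRO Stage 3 with offloading on a single GPU, assuming DeepSpeed is installed and that this Lightning version registers the ``deepspeed_stage_3_offload`` strategy alias:

.. code-block:: python

    from pytorch_lightning import Trainer

    # ZeRO Stage 3 with CPU offloading of parameters and optimizer state
    # (the strategy alias is assumed to be registered in this Lightning version)
    trainer = Trainer(accelerator="gpu", devices=1, strategy="deepspeed_stage_3_offload", precision=16)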
@@ -40,7 +38,7 @@ Overall:
* When **fine-tuning** a model, use advanced memory efficient strategies such as :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute
* When **pre-training** a model, use simpler optimizations such as :ref:`sharded-training`, :ref:`deepspeed-zero-stage-2` or :ref:`fully-sharded-training`, scaling the number of GPUs to reach larger parameter sizes
-* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` or :ref:`fairscale-activation-checkpointing` as the throughput degradation is not significant
+* For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` as the throughput degradation is not significant
For example, when using 128 GPUs, you can **pre-train** large 10 to 20 billion parameter models using :ref:`deepspeed-zero-stage-2` without having to take a performance hit from a more advanced optimized multi-GPU strategy.
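
A minimal sketch of such a run, assuming DeepSpeed is installed and a cluster of 16 nodes with 8 GPUs each (both counts are illustrative):

.. code-block:: python

    from pytorch_lightning import Trainer

    # 16 nodes x 8 GPUs = 128 GPUs; ZeRO Stage 2 shards optimizer state and gradients
    trainer = Trainer(accelerator="gpu", devices=8, num_nodes=16, strategy="deepspeed_stage_2", precision=16)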
@@ -153,11 +151,10 @@ Here's an example of changing the placement policy to "cpu".
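
A minimal sketch of what such an example might look like, assuming the ``ColossalAIStrategy`` from ``pytorch_lightning.strategies`` accepts a ``placement_policy`` argument:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import ColossalAIStrategy

    # "cpu" keeps sharded model data in CPU memory when it is not needed on the GPU
    # (the ``placement_policy`` argument is an assumption about this strategy's API)
    trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=ColossalAIStrategy(placement_policy="cpu"))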
.. _sharded-training:

-**************************
-FairScale Sharded Training
-**************************
+****************
+Sharded Training
+****************
Sharded Training is a Lightning integration of the optimizer sharded training provided by `FairScale <https://github.com/facebookresearch/fairscale>`_.
The technique can be found within `DeepSpeed ZeRO <https://arxiv.org/abs/1910.02054>`_; however, this implementation is built from the ground up to be PyTorch compatible and standalone.
@@ -171,178 +168,25 @@ these benefits in multi-GPU setups are almost free and throughput scales well wi
It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models is beneficial (500M+ parameter models).
A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful.
Use :ref:`fairscale-activation-checkpointing` to see even more benefit at the cost of some throughput.
To use Sharded Training, you need to first install FairScale using the command below.
.. code-block:: bash

    pip install fairscale
.. code-block:: python

    from pytorch_lightning import Trainer

    # train using Sharded DDP
    trainer = Trainer(strategy="ddp_sharded")
Sharded Training can work across all DDP variants by adding the ``--strategy ddp_sharded`` flag via the command line when using a PyTorch Lightning script.
Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.
----
.. _fully-sharded-training:

FairScale Fully Sharded Training
================================

.. warning::
    FairScale Fully Sharded Training is in BETA and the API is subject to change. Please create an `issue <https://github.com/Lightning-AI/lightning/issues>`_ if you run into any problems.
`Fully Sharded <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`_ shards optimizer state, gradients, and parameters across data parallel workers. This allows you to fit much larger models into memory across multiple GPUs.
Fully Sharded Training alleviates the need to worry about balancing layers onto specific devices using some form of pipe parallelism, and optimizes for distributed communication with minimal effort.
Shard Parameters to Reach 10+ Billion Parameters
------------------------------------------------

To reach larger parameter sizes and to be memory efficient, we have to shard parameters. There are various ways to enable this.
.. note::
    Currently Fully Sharded Training relies on the user to wrap the model with Fully Sharded within the ``LightningModule``.
    This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
    This is a limitation of Fully Sharded Training that will be resolved in the future.

Enabling Module Sharding for Maximum Memory Efficiency
--------------------------------------------------------

Manual wrapping can be useful to explore complex sharding strategies by applying ``wrap`` selectively to some parts of the model. To activate
parameter sharding with manual wrapping, you can wrap your model using the ``wrap`` function. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` parameters are passed correctly.
When not using Fully Sharded Training, these wrap functions are a no-op. That means once the changes have been made, there is no need to remove them when using other strategies.
``auto_wrap`` recursively wraps :class:`~torch.nn.Module` within the ``LightningModule`` with nested Fully Sharded Wrappers,
signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (more information :class:`here <fairscale.nn.fsdp>`).
``auto_wrap`` can have varying levels of success based on the complexity of your model. **Auto Wrap does not support models with shared parameters**.
``wrap`` simply wraps the module with a Fully Sharded Parallel class with the correct parameters from the Lightning context manager.
Here's an example using both ``wrap`` and ``auto_wrap`` to create your model:
.. code-block:: python

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap
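
Continuing from the imports above, a sketch of how the rest of the example might look: layers are created and wrapped inside ``configure_sharded_model``. The layer sizes, the optimizer choice, and the ``"fsdp"`` strategy alias are illustrative assumptions (the alias for the FairScale strategy may differ by Lightning version), and the usual ``training_step`` and dataloader definitions are omitted.

.. code-block:: python

    class MyModel(pl.LightningModule):
        def configure_sharded_model(self):
            # modules wrapped with ``wrap`` or ``auto_wrap`` are sharded as soon as
            # they are created inside this hook
            linear_layer = wrap(nn.Linear(32, 32))

            # ``auto_wrap`` recursively wraps the submodules of the block
            block = auto_wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))

            # combine activation checkpointing with auto wrapping for extra memory savings
            final_block = auto_wrap(checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU())))

            self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            return torch.optim.AdamW(self.model.parameters())


    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp", precision=16)  # strategy alias is an assumption
    # trainer.fit(model, train_dataloaders=...) would then train with sharded parameters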
Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations.
FairScale's checkpointing wrapper also handles batch norm layers correctly, unlike the PyTorch implementation, ensuring stats are tracked correctly despite the multiple forward passes.
This saves memory when training larger models; however, it requires wrapping the modules you'd like to use activation checkpointing on. See :class:`here <fairscale.nn.checkpoint.checkpoint_wrapper>` for more information.
.. warning::
    Do not wrap the entire model with activation checkpointing. This is not the intended use of activation checkpointing, and will lead to failures as seen in `this discussion <https://github.com/Lightning-AI/lightning/discussions/9144>`_.
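
As an illustration, a minimal sketch of wrapping only an activation-heavy submodule (the module and layer sizes are made up), assuming FairScale is installed:

.. code-block:: python

    import torch.nn as nn
    from fairscale.nn import checkpoint_wrapper


    class FeedForwardBlock(nn.Module):
        def __init__(self):
            super().__init__()
            # only this activation-heavy block is checkpointed, not the whole model
            self.ff = checkpoint_wrapper(nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)))

        def forward(self, x):
            return self.ff(x)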
PyTorch has its own version of `FSDP <https://pytorch.org/docs/stable/fsdp.html>`_, which is upstreamed from their `fairscale <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html>`__ project.
It was introduced in their `v1.11.0 release <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_, but it is recommended to use it with PyTorch v1.12 or later, which is what Lightning supports. The API is pretty similar to that of FairScale.
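
A minimal sketch of enabling it, assuming a PyTorch version with native FSDP and that this Lightning version registers an ``fsdp_native`` strategy alias:

.. code-block:: python

    from pytorch_lightning import Trainer

    # shard parameters, gradients and optimizer state with PyTorch-native FSDP
    # (the ``fsdp_native`` alias is an assumption and may differ by Lightning version)
    trainer = Trainer(accelerator="gpu", devices=4, strategy="fsdp_native", precision=16)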

docs/source-pytorch/extensions/strategy.rst (1 addition, 10 deletions)
@@ -80,16 +80,7 @@ The below table lists all relevant strategies available in Lightning with their
- Colossal-AI provides a collection of parallel components for you. It aims to let you write your distributed deep learning models just like you write your model on your laptop. `Learn more. <https://www.colossalai.org/>`__
- Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
- Deprecated `pytorch_lightning.lite.LightningLite` in favor of `lightning.fabric.Fabric` ([#16314](https://github.com/Lightning-AI/lightning/pull/16314))
- `FairScale` deprecation (in favor of PyTorch's FSDP implementation) ([#16353](https://github.com/PyTorchLightning/pytorch-lightning/pull/16353))
  * Deprecated the `pytorch_lightning.overrides.fairscale.LightningShardedDataParallel` class
  * Deprecated the `pytorch_lightning.plugins.precision.fully_sharded_native_amp.FullyShardedNativeMixedPrecisionPlugin` class
  * Deprecated the `pytorch_lightning.plugins.precision.sharded_native_amp.ShardedNativeMixedPrecisionPlugin` class
  * Deprecated the `pytorch_lightning.strategies.fully_sharded.DDPFullyShardedStrategy` class
  * Deprecated the `pytorch_lightning.strategies.sharded.DDPShardedStrategy` class
  * Deprecated the `pytorch_lightning.strategies.sharded_spawn.DDPSpawnShardedStrategy` class