1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
@@ -115,6 +115,7 @@ precision
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FSDPMixedPrecisionPlugin
HalfPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
XLABf16PrecisionPlugin
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/lightning_module.rst
@@ -940,7 +940,7 @@ The type of precision used:
.. code-block:: python

def training_step(self, batch, batch_idx):
if self.precision == 16:
if self.precision == "16-true":
...

trainer
17 changes: 14 additions & 3 deletions docs/source-pytorch/common/precision_basic.rst
@@ -20,14 +20,25 @@ Higher precision, such as the 64-bit floating-point, can be used for highly sens
16-bit Precision
****************

Use 16-bit mixed precision to lower your memory consumption by up to half so that you can train and deploy larger models. If your GPUs are [`Tensor Core <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training.
Use 16-bit mixed precision to speed up training and inference.
If your GPUs are [`Tensor Core <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_] GPUs, you can expect a ~3x speed improvement.

.. code::
.. code-block:: python

Trainer(precision="16-mixed")


With true 16-bit precision you can additionally lower your memory consumption by up to half so that you can train and deploy larger models.
However, this setting can sometimes lead to unstable training.

.. code-block:: python

Trainer(precision="16-true")

Trainer(precision='16-mixed')

----


****************
32-bit Precision
****************
46 changes: 43 additions & 3 deletions docs/source-pytorch/common/precision_intermediate.rst
@@ -53,6 +53,7 @@ delivers all of these benefits while ensuring that no task-specific accuracy is

----


********************
FP16 Mixed Precision
********************
@@ -68,7 +69,11 @@ Since computation happens in FP16, there is a chance of numerical instability du
.. testcode::
:skipif: not torch.cuda.is_available()

Trainer(accelerator="gpu", devices=1, precision=16)
Trainer(accelerator="gpu", devices=1, precision="16-mixed")


----


************************
BFloat16 Mixed Precision
@@ -86,16 +91,51 @@ Under the hood, we use `torch.autocast <https://pytorch.org/docs/stable/amp.html
.. testcode::
:skipif: not torch.cuda.is_available()

Trainer(accelerator="gpu", devices=1, precision="bf16")
Trainer(accelerator="gpu", devices=1, precision="bf16-mixed")

It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.

.. testcode::

Trainer(precision="bf16")
Trainer(precision="bf16-mixed")


----


*******************
True Half Precision
*******************

As mentioned before, for numerical stability mixed precision keeps the model weights in full float32 precision while casting only supported operations to lower bit precision.
However, in some cases it is indeed possible to train completely in half precision. Similarly, for inference the model weights can often be cast to half precision without a loss in accuracy (even when trained with mixed precision).

.. code-block:: python

# Select FP16 precision
trainer = Trainer(precision="16-true")
trainer.fit(model) # model gets cast to torch.float16

# Select BF16 precision
trainer = Trainer(precision="bf16-true")
trainer.fit(model) # model gets cast to torch.bfloat16

Tip: For faster initialization, you can create model parameters with the desired dtype directly on the device:

.. code-block:: python

trainer = Trainer(precision="bf16-true")

# init the model directly on the device and with parameters in half-precision
with trainer.init_module():
model = MyModel()

trainer.fit(model)
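
For inference, a minimal sketch along these lines (``MyModel`` and ``predict_loader`` are hypothetical placeholders for your own ``LightningModule`` and ``DataLoader``):

.. code-block:: python

    # sketch: run prediction with true bf16 weights (MyModel / predict_loader are placeholders)
    trainer = Trainer(precision="bf16-true")
    model = MyModel()  # weights get cast to torch.bfloat16 by the trainer
    predictions = trainer.predict(model, dataloaders=predict_loader)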


----


***************
8-bit Optimizer
***************
21 changes: 15 additions & 6 deletions docs/source-pytorch/common/trainer.rst
@@ -810,27 +810,36 @@ precision
^^^^^^^^^

Lightning supports either double (64), float (32), bfloat16 (bf16), or half (16) precision training.

Half precision, or mixed precision, is the combined use of 32 and 16 bit floating points to reduce memory footprint during model training. This can result in improved performance, achieving +3X speedups on modern GPUs.
Half precision performs all operations in 16-bit floating point, while mixed precision combines 32-bit and 16-bit floating point to reduce the memory footprint during model training. Since not all operations (like batchnorm) are numerically stable in lower bit precision, mixed precision still carries these out in fp32, whereas half precision unconditionally performs all operations in 16 bit.
This can result in improved performance, achieving +3X speedups on modern GPUs.

.. testcode::
:skipif: not torch.cuda.is_available()

# default used by the Trainer
trainer = Trainer(precision=32)

# 16-bit precision
trainer = Trainer(precision="16-mixed", accelerator="gpu", devices=1) # works only on CUDA
# 16-bit mixed precision
trainer = Trainer(precision="16-mixed")

# bfloat16 precision
# bfloat16 mixed precision
trainer = Trainer(precision="bf16-mixed")

# 16-bit true precision
trainer = Trainer(precision="16-true")

# bfloat16 true precision
trainer = Trainer(precision="bf16-true")

# 64-bit precision
trainer = Trainer(precision=64)


See the :doc:`N-bit precision guide <../common/precision>` for more details.

.. note:: When running on TPUs, torch.bfloat16 will be used but tensor printing will still show torch.float32.


profiler
^^^^^^^^

@@ -841,7 +850,7 @@ profiler

To profile individual steps during training and assist in identifying bottlenecks.

See the :doc:`profiler documentation <../tuning/profiler>`. for more details.
See the :doc:`profiler documentation <../tuning/profiler>` for more details.

.. testcode::

2 changes: 1 addition & 1 deletion docs/source-pytorch/levels/advanced_level_19.rst
@@ -33,7 +33,7 @@ Explore Intelligence Processing Unit (IPU) for model scaling.

.. displayitem::
:header: Optimize models training on IPUs
:description: Tune model performance with mix-precision and the performance analyser.
:description: Tune model performance with mixed precision and the performance analyser.
:col_css: col-md-4
:button_link: ../accelerators/ipu_intermediate.html
:height: 150
2 changes: 1 addition & 1 deletion docs/source-pytorch/levels/advanced_level_20.rst
@@ -25,7 +25,7 @@ Explore Intel Habana Processing Unit (HPU) for model scaling.

.. displayitem::
:header: Optimize models training on HPUs
:description: Enable state-of-the-art scaling with advanced mix-precision settings.
:description: Enable state-of-the-art scaling with advanced mixed-precision settings.
:col_css: col-md-6
:button_link: ../integrations/hpu/intermediate.html
:height: 150
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/build_model_intermediate.rst
@@ -16,7 +16,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
devices=4,
accelerator="gpu",
strategy="deepspeed_stage_2",
precision=16
precision="16-mixed",
)

# 20+ helpful arguments for rapid idea iteration
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))


- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
2 changes: 2 additions & 0 deletions src/lightning/pytorch/plugins/__init__.py
@@ -7,6 +7,7 @@
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
@@ -21,6 +22,7 @@
"XLACheckpointIO",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HalfPrecisionPlugin",
"MixedPrecisionPlugin",
"PrecisionPlugin",
"FSDPMixedPrecisionPlugin",
2 changes: 2 additions & 0 deletions src/lightning/pytorch/plugins/precision/__init__.py
@@ -15,6 +15,7 @@
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
@@ -23,6 +24,7 @@
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"FSDPMixedPrecisionPlugin",
"HalfPrecisionPlugin",
"MixedPrecisionPlugin",
"PrecisionPlugin",
"XLAPrecisionPlugin",
66 changes: 66 additions & 0 deletions src/lightning/pytorch/plugins/precision/half.py
@@ -0,0 +1,66 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, Literal

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module

from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin


class HalfPrecisionPlugin(PrecisionPlugin):
"""Plugin for training with half precision.

Args:
precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``).
"""

precision: Literal["bf16-true", "16-true"] = "16-true"

def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
self.precision = precision
self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16

def convert_module(self, module: Module) -> Module:
return module.to(dtype=self._desired_input_dtype)

@contextmanager
def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.

See: :meth:`torch.set_default_dtype`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
yield
torch.set_default_dtype(default_dtype)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when tensors get created during the module's
forward.

See: :meth:`torch.set_default_dtype`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
yield
torch.set_default_dtype(default_dtype)

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
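
A minimal usage sketch for this plugin, assuming ``MyModel`` is a hypothetical ``LightningModule`` subclass: it can be selected through the ``precision`` flag or passed explicitly via ``plugins``.

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import HalfPrecisionPlugin

    # sketch: both lines configure true bf16 precision (MyModel is a placeholder LightningModule)
    trainer = Trainer(precision="bf16-true")
    trainer = Trainer(plugins=[HalfPrecisionPlugin("bf16-true")])

    trainer.fit(MyModel())  # module parameters and inputs are converted to torch.bfloat16
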
@@ -39,6 +39,7 @@
CheckpointIO,
DeepSpeedPrecisionPlugin,
DoublePrecisionPlugin,
HalfPrecisionPlugin,
MixedPrecisionPlugin,
PLUGIN_INPUT,
PrecisionPlugin,
@@ -524,6 +525,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecisionPlugin(self._precision_flag) # type: ignore[arg-type]

if self._precision_flag in ("16-true", "bf16-true"):
return HalfPrecisionPlugin(self._precision_flag) # type: ignore
if self._precision_flag == "32-true":
return PrecisionPlugin()
if self._precision_flag == "64-true":
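
As a quick sanity check of this wiring (a sketch, assuming a default CPU setup), the new flags should resolve to the ``HalfPrecisionPlugin``:

.. code-block:: python

    # sketch: the "16-true" flag is expected to resolve to the new plugin
    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import HalfPrecisionPlugin

    trainer = Trainer(precision="16-true")
    assert isinstance(trainer.strategy.precision_plugin, HalfPrecisionPlugin)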