diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index 7a04128e8fb27..56886acdfa61a 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -115,6 +115,7 @@ precision
     DeepSpeedPrecisionPlugin
     DoublePrecisionPlugin
     FSDPMixedPrecisionPlugin
+    HalfPrecisionPlugin
     MixedPrecisionPlugin
     PrecisionPlugin
     XLABf16PrecisionPlugin
diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst
index 4b4237563ffb4..3a2691880a20c 100644
--- a/docs/source-pytorch/common/lightning_module.rst
+++ b/docs/source-pytorch/common/lightning_module.rst
@@ -940,7 +940,7 @@ The type of precision used:
 .. code-block:: python

     def training_step(self, batch, batch_idx):
-        if self.precision == 16:
+        if self.precision == "16-true":
             ...

 trainer
diff --git a/docs/source-pytorch/common/precision_basic.rst b/docs/source-pytorch/common/precision_basic.rst
index 033f58034f298..9ec3acc64e047 100644
--- a/docs/source-pytorch/common/precision_basic.rst
+++ b/docs/source-pytorch/common/precision_basic.rst
@@ -20,14 +20,25 @@ Higher precision, such as the 64-bit floating-point, can be used for highly sens
 16-bit Precision
 ****************

-Use 16-bit mixed precision to lower your memory consumption by up to half so that you can train and deploy larger models. If your GPUs are [`Tensor Core `_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training.
+Use 16-bit mixed precision to speed up training and inference.
+If your GPUs are [`Tensor Core `_] GPUs, you can expect a ~3x speed improvement.

-.. code::
+.. code-block:: python
+
+    Trainer(precision="16-mixed")
+
+
+With true 16-bit precision you can additionally lower your memory consumption by up to half so that you can train and deploy larger models.
+However, this setting can sometimes lead to unstable training.
+
+.. code-block:: python
+
+    Trainer(precision="16-true")

-    Trainer(precision='16-mixed')

 ----

+
 ****************
 32-bit Precision
 ****************
diff --git a/docs/source-pytorch/common/precision_intermediate.rst b/docs/source-pytorch/common/precision_intermediate.rst
index e0590df72ffd9..0ac5e917b3361 100644
--- a/docs/source-pytorch/common/precision_intermediate.rst
+++ b/docs/source-pytorch/common/precision_intermediate.rst
@@ -53,6 +53,7 @@ delivers all of these benefits while ensuring that no task-specific accuracy is

 ----

+
 ********************
 FP16 Mixed Precision
 ********************
@@ -68,7 +69,11 @@ Since computation happens in FP16, there is a chance of numerical instability du
 .. testcode::
     :skipif: not torch.cuda.is_available()

-    Trainer(accelerator="gpu", devices=1, precision=16)
+    Trainer(accelerator="gpu", devices=1, precision="16-mixed")
+
+
+----
+

 ************************
 BFloat16 Mixed Precision
 ************************
@@ -86,16 +91,51 @@ Under the hood, we use `torch.autocast ` for more details.
+
+.. note:: When running on TPUs, torch.bfloat16 will be used but tensor printing will still show torch.float32.
+
 profiler
 ^^^^^^^^
@@ -841,7 +850,7 @@ profiler

 To profile individual steps during training and assist in identifying bottlenecks.

-See the :doc:`profiler documentation <../tuning/profiler>`. for more details.
+See the :doc:`profiler documentation <../tuning/profiler>` for more details.

 .. testcode::
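Note on the precision docs above: ``16-mixed`` keeps the weights in FP32 and only runs the forward/backward under ``torch.autocast``, while ``16-true`` casts the weights themselves. For readers less familiar with AMP, here is a minimal plain-PyTorch sketch of the mixed-precision pattern, for illustration only; it assumes a CUDA device and is not Lightning's internal implementation.

.. code-block:: python

    import torch

    model = torch.nn.Linear(32, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid FP16 gradient underflow

    for _ in range(3):
        batch = torch.randn(8, 32, device="cuda")
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(batch).sum()  # forward runs in FP16 where it is safe to do so
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # unscales the gradients, then steps in FP32
        scaler.update()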
diff --git a/docs/source-pytorch/levels/advanced_level_19.rst b/docs/source-pytorch/levels/advanced_level_19.rst
index 4265b039b51f1..28edaa90826ed 100644
--- a/docs/source-pytorch/levels/advanced_level_19.rst
+++ b/docs/source-pytorch/levels/advanced_level_19.rst
@@ -33,7 +33,7 @@ Explore Intelligence Processing Unit (IPU) for model scaling.

 .. displayitem::
    :header: Optimize models training on IPUs
-   :description: Tune model performance with mix-precision and the performance analyser.
+   :description: Tune model performance with mixed precision and the performance analyser.
    :col_css: col-md-4
    :button_link: ../accelerators/ipu_intermediate.html
    :height: 150
diff --git a/docs/source-pytorch/levels/advanced_level_20.rst b/docs/source-pytorch/levels/advanced_level_20.rst
index f17ebdfd1fd5e..8aaa159cc62e3 100644
--- a/docs/source-pytorch/levels/advanced_level_20.rst
+++ b/docs/source-pytorch/levels/advanced_level_20.rst
@@ -25,7 +25,7 @@ Explore Intel Habana Processing Unit (HPU) for model scaling.

 .. displayitem::
    :header: Optimize models training on HPUs
-   :description: Enable state-of-the-art scaling with advanced mix-precision settings.
+   :description: Enable state-of-the-art scaling with advanced mixed-precision settings.
    :col_css: col-md-6
    :button_link: ../integrations/hpu/intermediate.html
    :height: 150
diff --git a/docs/source-pytorch/model/build_model_intermediate.rst b/docs/source-pytorch/model/build_model_intermediate.rst
index 18da34c2e2dd4..b8187f5982fe9 100644
--- a/docs/source-pytorch/model/build_model_intermediate.rst
+++ b/docs/source-pytorch/model/build_model_intermediate.rst
@@ -16,7 +16,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
         devices=4,
         accelerator="gpu",
         strategy="deepspeed_stage_2",
-        precision=16
+        precision="16-mixed",
     )

     # 20+ helpful arguments for rapid idea iteration
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 5e3df6ccea018..ccae1089cb6ac 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))

+- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193))
+
+
 ### Changed

 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
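A short usage sketch for the feature announced in the changelog entry above. ``BoringModel`` is Lightning's demo module and the rest is the public ``Trainer`` API; the sketch assumes hardware with half-precision kernel support.

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel

    # True half precision: parameters and inputs are cast to the half dtype,
    # which lowers memory use but can be less numerically stable than "16-mixed".
    trainer = Trainer(precision="16-true", max_epochs=1)
    trainer.fit(BoringModel())

    # bfloat16 variant: same memory savings, wider dynamic range.
    trainer = Trainer(precision="bf16-true", max_epochs=1)
    trainer.fit(BoringModel())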
diff --git a/src/lightning/pytorch/plugins/__init__.py b/src/lightning/pytorch/plugins/__init__.py
index 6669df200fe8c..9d0699422e84e 100644
--- a/src/lightning/pytorch/plugins/__init__.py
+++ b/src/lightning/pytorch/plugins/__init__.py
@@ -7,6 +7,7 @@
 from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
 from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
 from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
+from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
 from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
@@ -21,6 +22,7 @@
     "XLACheckpointIO",
     "DeepSpeedPrecisionPlugin",
     "DoublePrecisionPlugin",
+    "HalfPrecisionPlugin",
     "MixedPrecisionPlugin",
     "PrecisionPlugin",
     "FSDPMixedPrecisionPlugin",
diff --git a/src/lightning/pytorch/plugins/precision/__init__.py b/src/lightning/pytorch/plugins/precision/__init__.py
index 45bad4f687406..f8e35bdfd02d3 100644
--- a/src/lightning/pytorch/plugins/precision/__init__.py
+++ b/src/lightning/pytorch/plugins/precision/__init__.py
@@ -15,6 +15,7 @@
 from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
 from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
 from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
+from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
 from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
 from lightning.pytorch.plugins.precision.xlabf16 import XLABf16PrecisionPlugin
@@ -23,6 +24,7 @@
     "DeepSpeedPrecisionPlugin",
     "DoublePrecisionPlugin",
     "FSDPMixedPrecisionPlugin",
+    "HalfPrecisionPlugin",
     "MixedPrecisionPlugin",
     "PrecisionPlugin",
     "XLAPrecisionPlugin",
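Because ``HalfPrecisionPlugin`` is exported from ``lightning.pytorch.plugins``, it can also be constructed directly and handed to the ``Trainer`` instead of using the precision string. A sketch, assuming the ``plugins`` argument accepts a precision plugin instance as it does for the existing precision plugins:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import HalfPrecisionPlugin

    # Equivalent to Trainer(precision="bf16-true"), with the plugin passed explicitly.
    trainer = Trainer(plugins=[HalfPrecisionPlugin("bf16-true")])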
diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py
new file mode 100644
index 0000000000000..dcafa3b33fd53
--- /dev/null
+++ b/src/lightning/pytorch/plugins/precision/half.py
@@ -0,0 +1,66 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import contextmanager
+from typing import Any, Generator, Literal
+
+import torch
+from lightning_utilities import apply_to_collection
+from torch import Tensor
+from torch.nn import Module
+
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
+
+
+class HalfPrecisionPlugin(PrecisionPlugin):
+    """Plugin for training with half precision.
+
+    Args:
+        precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``).
+    """
+
+    precision: Literal["bf16-true", "16-true"] = "16-true"
+
+    def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
+        self.precision = precision
+        self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16
+
+    def convert_module(self, module: Module) -> Module:
+        return module.to(dtype=self._desired_input_dtype)
+
+    @contextmanager
+    def init_context(self) -> Generator[None, None, None]:
+        """A context manager to change the default tensor type when initializing module parameters or tensors.
+
+        See: :meth:`torch.set_default_dtype`
+        """
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(self._desired_input_dtype)
+        yield
+        torch.set_default_dtype(default_dtype)
+
+    @contextmanager
+    def forward_context(self) -> Generator[None, None, None]:
+        """A context manager to change the default tensor type when tensors get created during the module's
+        forward.
+
+        See: :meth:`torch.set_default_tensor_type`
+        """
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(self._desired_input_dtype)
+        yield
+        torch.set_default_dtype(default_dtype)
+
+    def convert_input(self, data: Any) -> Any:
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
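A small standalone sketch of how the plugin above behaves: inside ``init_context`` (and ``forward_context``) the default floating-point dtype is switched to the requested half dtype and restored on exit, and ``convert_input`` casts only the floating-point tensors inside a collection. The dictionary keys here are arbitrary illustration names.

.. code-block:: python

    import torch

    from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin

    plugin = HalfPrecisionPlugin("bf16-true")

    with plugin.init_context():
        layer = torch.nn.Linear(4, 4)  # parameters are created directly in bfloat16
    assert layer.weight.dtype == torch.bfloat16
    assert torch.get_default_dtype() == torch.float32  # previous default restored on exit

    batch = {"inputs": torch.randn(2, 4), "labels": torch.tensor([0, 1])}
    converted = plugin.convert_input(batch)
    assert converted["inputs"].dtype == torch.bfloat16
    assert converted["labels"].dtype == torch.int64  # non-floating tensors are left untouched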
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 508a6adae1adb..739e72161099c 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -39,6 +39,7 @@
     CheckpointIO,
     DeepSpeedPrecisionPlugin,
     DoublePrecisionPlugin,
+    HalfPrecisionPlugin,
     MixedPrecisionPlugin,
     PLUGIN_INPUT,
     PrecisionPlugin,
@@ -524,6 +525,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecisionPlugin(self._precision_flag)  # type: ignore[arg-type]

+        if self._precision_flag in ("16-true", "bf16-true"):
+            return HalfPrecisionPlugin(self._precision_flag)  # type: ignore
         if self._precision_flag == "32-true":
             return PrecisionPlugin()
         if self._precision_flag == "64-true":
diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py
new file mode 100644
index 0000000000000..1ac91bc0da44f
--- /dev/null
+++ b/tests/tests_pytorch/plugins/precision/test_half.py
@@ -0,0 +1,76 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+
+from lightning.pytorch.plugins import HalfPrecisionPlugin
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+def test_selected_dtype(precision, expected_dtype):
+    plugin = HalfPrecisionPlugin(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_input_dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+def test_init_context(precision, expected_dtype):
+    plugin = HalfPrecisionPlugin(precision=precision)
+    with plugin.init_context():
+        model = torch.nn.Linear(2, 2)
+        assert torch.get_default_dtype() == expected_dtype
+    assert model.weight.dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+def test_forward_context(precision, expected_dtype):
+    precision = HalfPrecisionPlugin(precision=precision)
+    assert torch.get_default_dtype() == torch.float32
+    with precision.forward_context():
+        assert torch.get_default_dtype() == expected_dtype
+    assert torch.get_default_dtype() == torch.float32
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.half),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = HalfPrecisionPlugin(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py
index 3b2b27cf9996e..c090b2c8bea30 100644
--- a/tests/tests_pytorch/strategies/test_common.py
+++ b/tests/tests_pytorch/strategies/test_common.py
@@ -18,7 +18,7 @@
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch import Trainer
-from lightning.pytorch.plugins import DoublePrecisionPlugin, PrecisionPlugin
+from lightning.pytorch.plugins import DoublePrecisionPlugin, HalfPrecisionPlugin, PrecisionPlugin
 from lightning.pytorch.strategies import SingleDeviceStrategy
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -68,6 +68,8 @@ def test_evaluate(tmpdir, trainer_kwargs):
     [
         (PrecisionPlugin(), torch.float32),
         pytest.param(DoublePrecisionPlugin(), torch.float64, marks=RunIf(mps=False)),
+        (HalfPrecisionPlugin("16-true"), torch.float16),
+        pytest.param(HalfPrecisionPlugin("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)),
     ],
 )
 @pytest.mark.parametrize("empty_init", [None, True, False])
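With the connector change above, the new precision strings resolve to ``HalfPrecisionPlugin``, and the same mapping is visible through the public ``Trainer`` API. A quick check, assuming ``Trainer.precision_plugin`` exposes the selected plugin as in current Lightning 2.x releases:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import HalfPrecisionPlugin

    assert isinstance(Trainer(precision="16-true").precision_plugin, HalfPrecisionPlugin)
    assert isinstance(Trainer(precision="bf16-true").precision_plugin, HalfPrecisionPlugin)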
diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index 167f36ea8f457..1a06a622d4646 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -24,7 +24,7 @@
 from lightning.pytorch import seed_everything, Trainer
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.plugins import DoublePrecisionPlugin, PrecisionPlugin
+from lightning.pytorch.plugins import DoublePrecisionPlugin, HalfPrecisionPlugin, PrecisionPlugin
 from lightning.pytorch.strategies import DDPStrategy
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -169,6 +169,8 @@ def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs, mps_count_
     [
         (PrecisionPlugin(), torch.float32),
         (DoublePrecisionPlugin(), torch.float64),
+        (HalfPrecisionPlugin("16-true"), torch.float16),
+        pytest.param(HalfPrecisionPlugin("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)),
     ],
 )
 @mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 35e2fce9d6602..41d2c29e23b2e 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -33,8 +33,14 @@
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators import Accelerator, CPUAccelerator, CUDAAccelerator, MPSAccelerator, XLAAccelerator
-from lightning.pytorch.plugins import DoublePrecisionPlugin, LayerSync, PrecisionPlugin, TorchSyncBatchNorm
 from lightning.pytorch.plugins.io import TorchCheckpointIO
+from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
+from lightning.pytorch.plugins.precision import (
+    DoublePrecisionPlugin,
+    HalfPrecisionPlugin,
+    MixedPrecisionPlugin,
+    PrecisionPlugin,
+)
 from lightning.pytorch.strategies import (
     DDPStrategy,
     DeepSpeedStrategy,
@@ -1006,3 +1012,19 @@ def _mock_tpu_available(value):
 def test_connector_sets_num_nodes(strategy, cuda_count_2):
     trainer = Trainer(accelerator="cuda", strategy=strategy, devices=2, num_nodes=2)
     assert trainer.strategy.num_nodes == 2
+
+
+@pytest.mark.parametrize(
+    ("precision_str", "precision_cls"),
+    [
+        ("64-true", DoublePrecisionPlugin),
+        ("32-true", PrecisionPlugin),
+        ("16-true", HalfPrecisionPlugin),
+        ("bf16-true", HalfPrecisionPlugin),
+        ("16-mixed", MixedPrecisionPlugin),
+        ("bf16-mixed", MixedPrecisionPlugin),
+    ],
+)
+def test_precision_selection(precision_str, precision_cls):
+    connector = _AcceleratorConnector(precision=precision_str)
+    assert isinstance(connector.precision_plugin, precision_cls)