diff --git a/docs/source-fabric/api/fabric_args.rst b/docs/source-fabric/api/fabric_args.rst index e7829acaeabe7..4e140814a7657 100644 --- a/docs/source-fabric/api/fabric_args.rst +++ b/docs/source-fabric/api/fabric_args.rst @@ -123,13 +123,19 @@ This can result in improved performance, achieving significant speedups on moder # the same as: fabric = Fabric(precision="32", devices=1) - # 16-bit (mixed) precision + # 16-bit mixed precision (model weights remain in torch.float32) fabric = Fabric(precision="16-mixed", devices=1) - # 16-bit bfloat precision + # 16-bit bfloat mixed precision (model weights remain in torch.float32) fabric = Fabric(precision="bf16-mixed", devices=1) - # 64-bit (double) precision + # 16-bit precision (model weights get cast to torch.float16) + fabric = Fabric(precision="16-true", devices=1) + + # 16-bit bfloat precision (model weights get cast to torch.bfloat16) + fabric = Fabric(precision="bf16-true", devices=1) + + # 64-bit (double) precision (model weights get cast to torch.float64) fabric = Fabric(precision="64-true", devices=1) See also: :doc:`../fundamentals/precision` diff --git a/docs/source-fabric/fundamentals/precision.rst b/docs/source-fabric/fundamentals/precision.rst index 731f7d42fe423..2108676cab285 100644 --- a/docs/source-fabric/fundamentals/precision.rst +++ b/docs/source-fabric/fundamentals/precision.rst @@ -118,6 +118,41 @@ It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDN ---- +******************* +True Half Precision +******************* + +As mentioned before, for numerical stability mixed precision keeps the model weights in full float32 precision while casting only supported operations to lower bit precision. +However, in some cases it is indeed possible to train completely in half precision. Similarly, for inference the model weights can often be cast to half precision without a loss in accuracy (even when trained with mixed precision). + +.. code-block:: python + + # Select FP16 precision + fabric = Fabric(precision="16-true") + model = MyModel() + model = fabric.setup(model) # model gets cast to torch.float16 + + # Select BF16 precision + fabric = Fabric(precision="bf16-true") + model = MyModel() + model = fabric.setup(model) # model gets cast to torch.bfloat16 + +Tip: For faster initialization, you can create model parameters with the desired dtype directly on the device: + +.. code-block:: python + + fabric = Fabric(precision="bf16-true") + + # init the model directly on the device and with parameters in half-precision + with fabric.init_module(): + model = MyModel() + + model = fabric.setup(model) + + +---- + + ************************************ Control where precision gets applied ************************************ diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 62b0c6586622f..a347822e8e9ea 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334)) +- Added support for true half-precision as `L.Fabric(precision="16-true"|"bf16-true")` ([#17287](https://github.com/Lightning-AI/lightning/pull/17287)) + + ### Changed - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331)) @@ -55,7 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
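The ``True Half Precision`` docs above note that a model trained in float32 or mixed precision can often be served in half precision. A minimal sketch of that inference path follows, assuming ``MyModel`` is the same placeholder module used in the earlier snippets and that the input shape is purely illustrative:

.. code-block:: python

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(precision="bf16-true", devices=1)

    model = MyModel()  # e.g. a model restored from a float32 checkpoint
    model = fabric.setup(model)  # parameters are cast to torch.bfloat16

    model.eval()
    with torch.inference_mode():
        # float32 input; the precision plugin converts floating-point inputs
        # to bfloat16 around the wrapped forward
        x = fabric.to_device(torch.randn(4, 32))
        prediction = model(x)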
### Deprecated -- + +- Deprecated the `Fabric.sharded_model()` context manager in favor of `Fabric.init_module()` ([#17462](https://github.com/Lightning-AI/lightning/pull/17462)) ### Removed diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 4eae73e33de48..f90fd3347f7a4 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -26,6 +26,7 @@ from lightning.fabric.plugins import ( CheckpointIO, DeepSpeedPrecision, + HalfPrecision, MixedPrecision, Precision, TPUBf16Precision, @@ -446,6 +447,8 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore + if self._precision_input in ("16-true", "bf16-true"): + return HalfPrecision(self._precision_input) # type: ignore if self._precision_input == "32-true": return Precision() if self._precision_input == "64-true": @@ -467,8 +470,8 @@ def _check_and_init_precision(self) -> Precision: device = "cpu" if self._accelerator_flag == "cpu" else "cuda" if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecision(precision=self._precision_input, device=device) - return MixedPrecision(precision=self._precision_input, device=device) + return FSDPPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] + return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/fabric/plugins/__init__.py b/src/lightning/fabric/plugins/__init__.py index 31368988e2db2..c88397db89951 100644 --- a/src/lightning/fabric/plugins/__init__.py +++ b/src/lightning/fabric/plugins/__init__.py @@ -19,6 +19,7 @@ from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.fabric.plugins.precision.double import DoublePrecision from lightning.fabric.plugins.precision.fsdp import FSDPPrecision +from lightning.fabric.plugins.precision.half import HalfPrecision from lightning.fabric.plugins.precision.precision import Precision from lightning.fabric.plugins.precision.tpu import TPUPrecision from lightning.fabric.plugins.precision.tpu_bf16 import TPUBf16Precision @@ -31,6 +32,7 @@ "Precision", "DeepSpeedPrecision", "DoublePrecision", + "HalfPrecision", "MixedPrecision", "TPUPrecision", "TPUBf16Precision", diff --git a/src/lightning/fabric/plugins/precision/__init__.py b/src/lightning/fabric/plugins/precision/__init__.py index 58f45ede53b98..a31fd78865db5 100644 --- a/src/lightning/fabric/plugins/precision/__init__.py +++ b/src/lightning/fabric/plugins/precision/__init__.py @@ -15,6 +15,7 @@ from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.fabric.plugins.precision.double import DoublePrecision from lightning.fabric.plugins.precision.fsdp import FSDPPrecision +from lightning.fabric.plugins.precision.half import HalfPrecision from lightning.fabric.plugins.precision.precision import Precision from lightning.fabric.plugins.precision.tpu import TPUPrecision from lightning.fabric.plugins.precision.tpu_bf16 import TPUBf16Precision @@ -22,6 +23,7 @@ __all__ = [ "DeepSpeedPrecision", "DoublePrecision", + "HalfPrecision", "MixedPrecision", "Precision", "TPUPrecision", diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py new file mode 100644 index 0000000000000..7831ba5aa7110 --- /dev/null +++ b/src/lightning/fabric/plugins/precision/half.py @@ -0,0 +1,69 @@ +# Copyright The 
Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from typing import Any, Generator, Literal + +import torch +from lightning_utilities.core.apply_func import apply_to_collection +from torch import Tensor +from torch.nn import Module + +from lightning.fabric.plugins.precision.precision import Precision +from lightning.fabric.plugins.precision.utils import _convert_fp_tensor + + +class HalfPrecision(Precision): + """Plugin for training with half precision. + + Args: + precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``). + """ + + precision: Literal["bf16-true", "16-true"] = "16-true" + + def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None: + self.precision = precision + self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16 + + def convert_module(self, module: Module) -> Module: + return module.to(dtype=self._desired_input_dtype) + + @contextmanager + def module_init_context(self) -> Generator[None, None, None]: + """A context manager to change the default tensor type when initializing the parameters in a module. + + See: :meth:`torch.set_default_tensor_type` + """ + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(self._desired_input_dtype) + yield + torch.set_default_dtype(default_dtype) + + @contextmanager + def forward_context(self) -> Generator[None, None, None]: + """A context manager to change the default tensor type when tensors get created during the module's + forward. 
+ + See: :meth:`torch.set_default_tensor_type` + """ + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(self._desired_input_dtype) + yield + torch.set_default_dtype(default_dtype) + + def convert_input(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype) + + def convert_output(self, data: Any) -> Any: + return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()) diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index dddbfc20c8526..f42c9836b96c3 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -23,7 +23,7 @@ _PRECISION_INPUT_INT = Literal[64, 32, 16] _PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"} _PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"] -_PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"] +_PRECISION_INPUT_STR = Literal["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"] _PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS] diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 3071c3fb1c8bd..245c00559b0d5 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -544,7 +544,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.strategy, FSDPStrategy): return FSDPMixedPrecisionPlugin(self._precision_flag, device) - return MixedPrecisionPlugin(self._precision_flag, device) + return MixedPrecisionPlugin(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/tests/tests_fabric/plugins/precision/test_double.py b/tests/tests_fabric/plugins/precision/test_double.py index 699fd6376299f..4921e0f4e659b 100644 --- a/tests/tests_fabric/plugins/precision/test_double.py +++ b/tests/tests_fabric/plugins/precision/test_double.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import torch from lightning.fabric.plugins.precision.double import DoublePrecision @@ -23,3 +22,11 @@ def test_double_precision_forward_context(): with precision.forward_context(): assert torch.get_default_dtype() == torch.float64 assert torch.get_default_dtype() == torch.float32 + + +def test_convert_module(): + precision = DoublePrecision() + module = torch.nn.Linear(2, 2) + assert module.weight.dtype == module.bias.dtype == torch.float32 + module = precision.convert_module(module) + assert module.weight.dtype == module.bias.dtype == torch.float64 diff --git a/tests/tests_fabric/plugins/precision/test_half.py b/tests/tests_fabric/plugins/precision/test_half.py new file mode 100644 index 0000000000000..c39d6f8a951a7 --- /dev/null +++ b/tests/tests_fabric/plugins/precision/test_half.py @@ -0,0 +1,75 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
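# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption-based example, not taken from this
# patch): how the new HalfPrecision plugin behaves when used directly. It only
# relies on the APIs introduced in half.py above and mirrors the tests below.
import torch

from lightning.fabric.plugins.precision import HalfPrecision

plugin = HalfPrecision(precision="bf16-true")

# Parameters created inside module_init_context() use the selected dtype directly.
with plugin.module_init_context():
    model = torch.nn.Linear(2, 2)
assert model.weight.dtype == torch.bfloat16

# An existing float32 module can be cast after the fact via convert_module().
fp32_model = torch.nn.Linear(2, 2)
bf16_model = plugin.convert_module(fp32_model)
assert bf16_model.weight.dtype == torch.bfloat16

# Inside forward_context(), newly created floating-point tensors default to
# bfloat16; the previous default dtype is restored on exit.
with plugin.forward_context():
    assert torch.get_default_dtype() == torch.bfloat16
assert torch.get_default_dtype() == torch.float32
# ---------------------------------------------------------------------------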
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from lightning.fabric.plugins.precision import HalfPrecision + + +@pytest.mark.parametrize( + "precision, expected_dtype", + [ + ("bf16-true", torch.bfloat16), + ("16-true", torch.half), + ], +) +def test_selected_dtype(precision, expected_dtype): + plugin = HalfPrecision(precision=precision) + assert plugin.precision == precision + assert plugin._desired_input_dtype == expected_dtype + + +@pytest.mark.parametrize( + "precision, expected_dtype", + [ + ("bf16-true", torch.bfloat16), + ("16-true", torch.half), + ], +) +def test_module_init_context(precision, expected_dtype): + plugin = HalfPrecision(precision=precision) + with plugin.module_init_context(): + model = torch.nn.Linear(2, 2) + assert torch.get_default_dtype() == expected_dtype + assert model.weight.dtype == expected_dtype + + +@pytest.mark.parametrize( + "precision, expected_dtype", + [ + ("bf16-true", torch.bfloat16), + ("16-true", torch.half), + ], +) +def test_forward_context(precision, expected_dtype): + precision = HalfPrecision(precision=precision) + assert torch.get_default_dtype() == torch.float32 + with precision.forward_context(): + assert torch.get_default_dtype() == expected_dtype + assert torch.get_default_dtype() == torch.float32 + + +@pytest.mark.parametrize( + "precision, expected_dtype", + [ + ("bf16-true", torch.bfloat16), + ("16-true", torch.half), + ], +) +def test_convert_module(precision, expected_dtype): + precision = HalfPrecision(precision=precision) + module = torch.nn.Linear(2, 2) + assert module.weight.dtype == module.bias.dtype == torch.float32 + module = precision.convert_module(module) + assert module.weight.dtype == module.bias.dtype == expected_dtype diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index 7e0f0ee8e23c4..963ff607815cb 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -19,7 +19,7 @@ import torch from torch.nn.parallel import DistributedDataParallel -from lightning.fabric.plugins import DoublePrecision, Precision +from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl @@ -133,6 +133,8 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision): "precision,expected_dtype", [ (Precision(), torch.float32), + (HalfPrecision("16-true"), torch.float16), + pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)), (DoublePrecision(), torch.float64), ], ) diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 024475940f6ba..4ce58c53f9395 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -198,6 +198,8 @@ def test_compile(compile_after_setup): "precision,expected_dtype", [ ("32-true", torch.float32), + ("16-true", torch.float16), + 
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), ("64-true", torch.float64), ], ) diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py index bab2744464f49..05c11f703d301 100644 --- a/tests/tests_fabric/strategies/test_single_device.py +++ b/tests/tests_fabric/strategies/test_single_device.py @@ -16,7 +16,7 @@ import pytest import torch -from lightning.fabric.plugins import DoublePrecision, Precision +from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.strategies import SingleDeviceStrategy from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer @@ -163,6 +163,8 @@ def test_single_device_grad_clipping(clip_type, precision): "precision,dtype", [ (Precision(), torch.float32), + (HalfPrecision("16-true"), torch.float16), + pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)), pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)), ], ) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index b6b9da4947df7..05feb6350f701 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -30,7 +30,7 @@ from lightning.fabric.accelerators.cuda import CUDAAccelerator from lightning.fabric.accelerators.mps import MPSAccelerator from lightning.fabric.connector import _Connector -from lightning.fabric.plugins import DoublePrecision, MixedPrecision, Precision, TPUPrecision +from lightning.fabric.plugins import DoublePrecision, HalfPrecision, MixedPrecision, Precision, TPUPrecision from lightning.fabric.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, @@ -765,6 +765,22 @@ def test_ddp_fork_on_unsupported_platform(_, __, strategy): _Connector(strategy=strategy) +@pytest.mark.parametrize( + "precision_str,precision_cls", + [ + ("64-true", DoublePrecision), + ("32-true", Precision), + ("16-true", HalfPrecision), + ("bf16-true", HalfPrecision), + ("16-mixed", MixedPrecision), + ("bf16-mixed", MixedPrecision), + ], +) +def test_precision_selection(precision_str, precision_cls): + connector = _Connector(precision=precision_str) + assert isinstance(connector.precision, precision_cls) + + def test_precision_selection_16_on_cpu_warns(): with pytest.warns( UserWarning,