Merged
Commits (46)
be48ffa
model instantiation
awaelchli Apr 5, 2023
5d68d2f
strategy implementations
awaelchli Apr 5, 2023
df5c9ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 5, 2023
1779460
tests
awaelchli Apr 10, 2023
b92d056
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2023
de21ae5
connect precision
awaelchli Apr 16, 2023
3153e0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2023
70c25df
tests
awaelchli Apr 17, 2023
3fb0c50
ddp
awaelchli Apr 18, 2023
ccc9b8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2023
d93b7c9
update
awaelchli Apr 18, 2023
cb94829
Merge remote-tracking branch 'origin/fabric/half-precision' into fabr…
awaelchli Apr 18, 2023
5f57343
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2023
2fd241a
ddp test
awaelchli Apr 18, 2023
9f80ea3
ddp test
awaelchli Apr 18, 2023
5a72dca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2023
e1e1852
reset
awaelchli Apr 18, 2023
4c07eae
notebook
awaelchli Apr 18, 2023
9b0f0de
notebook
awaelchli Apr 18, 2023
bb2321f
notebook
awaelchli Apr 18, 2023
6a14bdd
add test
awaelchli Apr 18, 2023
b6f8e1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2023
71d9308
Merge branch 'master' into fabric/half-precision
awaelchli Apr 24, 2023
66d3a20
fsdp tests
awaelchli Apr 24, 2023
3207812
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
1a852e1
comments
awaelchli Apr 24, 2023
1805a60
reset
awaelchli Apr 24, 2023
f368a66
Revert "reset"
awaelchli Apr 24, 2023
edd35dd
Merge branch 'master' into fabric/module-init
awaelchli Apr 24, 2023
edf7135
changelog
awaelchli Apr 24, 2023
65e0d22
add changelog
awaelchli Apr 24, 2023
2993281
add test
awaelchli Apr 24, 2023
e2aa4c3
add test
awaelchli Apr 24, 2023
caa469a
Merge branch 'fabric/module-init' into fabric/half-precision
awaelchli Apr 24, 2023
63faed7
document true half precision
awaelchli Apr 26, 2023
e9463a6
changelog
awaelchli Apr 26, 2023
c114f45
Merge branch 'master' into fabric/half-precision
awaelchli Apr 26, 2023
57e9ee8
fix import
awaelchli Apr 26, 2023
808f9d4
fix merge error
awaelchli Apr 26, 2023
a81152a
ignore weirdo type error
awaelchli Apr 26, 2023
043911f
Update docs/source-fabric/fundamentals/precision.rst
awaelchli Apr 27, 2023
898ee2a
Update src/lightning/fabric/plugins/precision/half.py
awaelchli Apr 27, 2023
9e581fa
Update src/lightning/fabric/plugins/precision/half.py
awaelchli Apr 27, 2023
9cbd557
update default
awaelchli Apr 27, 2023
f8b02e1
mypy
awaelchli Apr 27, 2023
3fac938
Merge branch 'master' into fabric/half-precision
Borda Apr 27, 2023
12 changes: 9 additions & 3 deletions docs/source-fabric/api/fabric_args.rst
@@ -123,13 +123,19 @@ This can result in improved performance, achieving significant speedups on moder
# the same as:
fabric = Fabric(precision="32", devices=1)

# 16-bit (mixed) precision
# 16-bit mixed precision (model weights remain in torch.float32)
fabric = Fabric(precision="16-mixed", devices=1)

# 16-bit bfloat precision
# 16-bit bfloat mixed precision (model weights remain in torch.float32)
fabric = Fabric(precision="bf16-mixed", devices=1)

# 64-bit (double) precision
# 16-bit precision (model weights get cast to torch.float16)
fabric = Fabric(precision="16-true", devices=1)

# 16-bit bfloat precision (model weights get cast to torch.bfloat16)
fabric = Fabric(precision="bf16-true", devices=1)

# 64-bit (double) precision (model weights get cast to torch.float64)
fabric = Fabric(precision="64-true", devices=1)

See also: :doc:`../fundamentals/precision`
35 changes: 35 additions & 0 deletions docs/source-fabric/fundamentals/precision.rst
@@ -118,6 +118,41 @@ It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDN
----


*******************
True Half Precision
*******************

As mentioned above, for numerical stability mixed precision keeps the model weights in full float32 precision while running only supported operations in lower-bit precision.
However, in some cases it is possible to train completely in half precision. Similarly, for inference the model weights can often be cast to half precision without a loss in accuracy (even when the model was trained with mixed precision).

.. code-block:: python

# Select FP16 precision
fabric = Fabric(precision="16-true")
model = MyModel()
model = fabric.setup(model) # model gets cast to torch.float16

# Select BF16 precision
fabric = Fabric(precision="bf16-true")
model = MyModel()
model = fabric.setup(model) # model gets cast to torch.bfloat16

Tip: For faster initialization, you can create model parameters with the desired dtype directly on the device:

.. code-block:: python

fabric = Fabric(precision="bf16-true")

# init the model directly on the device and with parameters in half-precision
with fabric.init_module():
model = MyModel()

model = fabric.setup(model)


----


************************************
Control where precision gets applied
************************************
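The true half-precision behaviour documented in this file can be sanity-checked in a few lines. A minimal sketch on CPU (the single Linear layer stands in for a real model; bfloat16 support on CPU in a recent PyTorch is assumed):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="bf16-true")

# Parameters are created directly in bfloat16 inside init_module()
with fabric.init_module():
    model = torch.nn.Linear(8, 8)
assert model.weight.dtype == torch.bfloat16

# setup() keeps (or casts) the weights in the selected half precision
model = fabric.setup(model)

# Inputs are converted to bfloat16 on the way in; the output comes back in the default dtype (float32)
out = model(torch.rand(2, 8))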
6 changes: 5 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Run the DDP wrapper in a CUDA stream ([#17334](https://github.com/Lightning-AI/lightning/pull/17334))


- Added support for true half-precision as `L.Fabric(precision="16-true"|"bf16-true")` ([#17287](https://github.com/Lightning-AI/lightning/pull/17287))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
@@ -55,7 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

-

- Deprecated the `Fabric.sharded_model()` context manager in favor of `Fabric.init_module()` ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))


### Removed
7 changes: 5 additions & 2 deletions src/lightning/fabric/connector.py
@@ -26,6 +26,7 @@
from lightning.fabric.plugins import (
CheckpointIO,
DeepSpeedPrecision,
HalfPrecision,
MixedPrecision,
Precision,
TPUBf16Precision,
@@ -446,6 +447,8 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore

if self._precision_input in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_input) # type: ignore
if self._precision_input == "32-true":
return Precision()
if self._precision_input == "64-true":
@@ -467,8 +470,8 @@ def _check_and_init_precision(self) -> Precision:
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input, device=device)
return MixedPrecision(precision=self._precision_input, device=device)
return FSDPPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]

raise RuntimeError("No precision set")

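For orientation, the dispatch above can be summarized by the parameter dtype each precision setting ends up with. This is an illustrative restatement, not Lightning code; the helper name is made up and the DeepSpeed- and FSDP-specific branches are omitted:

import torch

_TRUE_HALF = {"16-true": torch.float16, "bf16-true": torch.bfloat16}

def pick_weight_dtype(precision: str) -> torch.dtype:
    # New in this PR: the true-half settings are handled first, by HalfPrecision
    if precision in _TRUE_HALF:
        return _TRUE_HALF[precision]
    if precision == "64-true":  # DoublePrecision
        return torch.float64
    # "32-true", "16-mixed" and "bf16-mixed" all keep the weights in float32
    return torch.float32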
2 changes: 2 additions & 0 deletions src/lightning/fabric/plugins/__init__.py
@@ -19,6 +19,7 @@
from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.half import HalfPrecision
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.tpu import TPUPrecision
from lightning.fabric.plugins.precision.tpu_bf16 import TPUBf16Precision
@@ -31,6 +32,7 @@
"Precision",
"DeepSpeedPrecision",
"DoublePrecision",
"HalfPrecision",
"MixedPrecision",
"TPUPrecision",
"TPUBf16Precision",
2 changes: 2 additions & 0 deletions src/lightning/fabric/plugins/precision/__init__.py
@@ -15,13 +15,15 @@
from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.half import HalfPrecision
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.tpu import TPUPrecision
from lightning.fabric.plugins.precision.tpu_bf16 import TPUBf16Precision

__all__ = [
"DeepSpeedPrecision",
"DoublePrecision",
"HalfPrecision",
"MixedPrecision",
"Precision",
"TPUPrecision",
69 changes: 69 additions & 0 deletions src/lightning/fabric/plugins/precision/half.py
@@ -0,0 +1,69 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, Literal

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor


class HalfPrecision(Precision):
"""Plugin for training with half precision.

Args:
precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``).
"""

precision: Literal["bf16-true", "16-true"] = "16-true"

def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> None:
self.precision = precision
self._desired_input_dtype = torch.bfloat16 if precision == "bf16-true" else torch.float16

def convert_module(self, module: Module) -> Module:
return module.to(dtype=self._desired_input_dtype)

@contextmanager
def module_init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing the parameters in a module.

See: :meth:`torch.set_default_tensor_type`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
yield
torch.set_default_dtype(default_dtype)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when tensors get created during the module's
forward.

See: :meth:`torch.set_default_tensor_type`
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
yield
torch.set_default_dtype(default_dtype)

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
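The plugin above is self-contained and can be exercised directly, much like the new tests further down. A minimal sketch on CPU (the Linear module and tensor shapes are arbitrary; bfloat16 CPU support is assumed):

import torch
from lightning.fabric.plugins.precision import HalfPrecision

plugin = HalfPrecision(precision="bf16-true")

# Parameters created inside the context pick up the plugin's dtype
with plugin.module_init_context():
    model = torch.nn.Linear(4, 4)
assert model.weight.dtype == torch.bfloat16

# Inputs are converted on the way in; outputs go back to the default dtype
batch = plugin.convert_input(torch.rand(2, 4))
with plugin.forward_context():
    out = model(batch)
out = plugin.convert_output(out)
assert out.dtype == torch.float32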
2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/precision.py
@@ -23,7 +23,7 @@
_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
_PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"]
_PRECISION_INPUT_STR = Literal["16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true", "64-true"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS]


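Note that the short aliases in `_PRECISION_INPUT_STR_ALIAS_CONVERSION` still resolve to the mixed modes, so true half precision must be requested with the explicit `-true` suffix. A small illustrative sketch of that resolution (the `resolve` helper is not part of the codebase):

_ALIASES = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}

def resolve(value) -> str:
    value = str(value)  # integer inputs like 16 are also accepted
    # "16" -> "16-mixed"; "16-true" and "bf16-true" are already canonical and pass through
    return _ALIASES.get(value, value)

assert resolve(16) == "16-mixed"
assert resolve("bf16") == "bf16-mixed"
assert resolve("16-true") == "16-true"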
@@ -544,7 +544,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

if isinstance(self.strategy, FSDPStrategy):
return FSDPMixedPrecisionPlugin(self._precision_flag, device)
return MixedPrecisionPlugin(self._precision_flag, device)
return MixedPrecisionPlugin(self._precision_flag, device) # type: ignore[arg-type]

raise RuntimeError("No precision set")

9 changes: 8 additions & 1 deletion tests/tests_fabric/plugins/precision/test_double.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from lightning.fabric.plugins.precision.double import DoublePrecision
@@ -23,3 +22,11 @@ def test_double_precision_forward_context():
with precision.forward_context():
assert torch.get_default_dtype() == torch.float64
assert torch.get_default_dtype() == torch.float32


def test_convert_module():
precision = DoublePrecision()
module = torch.nn.Linear(2, 2)
assert module.weight.dtype == module.bias.dtype == torch.float32
module = precision.convert_module(module)
assert module.weight.dtype == module.bias.dtype == torch.float64
75 changes: 75 additions & 0 deletions tests/tests_fabric/plugins/precision/test_half.py
@@ -0,0 +1,75 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

from lightning.fabric.plugins.precision import HalfPrecision


@pytest.mark.parametrize(
"precision, expected_dtype",
[
("bf16-true", torch.bfloat16),
("16-true", torch.half),
],
)
def test_selected_dtype(precision, expected_dtype):
plugin = HalfPrecision(precision=precision)
assert plugin.precision == precision
assert plugin._desired_input_dtype == expected_dtype


@pytest.mark.parametrize(
"precision, expected_dtype",
[
("bf16-true", torch.bfloat16),
("16-true", torch.half),
],
)
def test_module_init_context(precision, expected_dtype):
plugin = HalfPrecision(precision=precision)
with plugin.module_init_context():
model = torch.nn.Linear(2, 2)
assert torch.get_default_dtype() == expected_dtype
assert model.weight.dtype == expected_dtype


@pytest.mark.parametrize(
"precision, expected_dtype",
[
("bf16-true", torch.bfloat16),
("16-true", torch.half),
],
)
def test_forward_context(precision, expected_dtype):
precision = HalfPrecision(precision=precision)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_default_dtype() == expected_dtype
assert torch.get_default_dtype() == torch.float32


@pytest.mark.parametrize(
"precision, expected_dtype",
[
("bf16-true", torch.bfloat16),
("16-true", torch.half),
],
)
def test_convert_module(precision, expected_dtype):
precision = HalfPrecision(precision=precision)
module = torch.nn.Linear(2, 2)
assert module.weight.dtype == module.bias.dtype == torch.float32
module = precision.convert_module(module)
assert module.weight.dtype == module.bias.dtype == expected_dtype
4 changes: 3 additions & 1 deletion tests/tests_fabric/strategies/test_ddp.py
@@ -19,7 +19,7 @@
import torch
from torch.nn.parallel import DistributedDataParallel

from lightning.fabric.plugins import DoublePrecision, Precision
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
@@ -133,6 +133,8 @@ def test_ddp_grad_clipping(clip_type, accelerator, precision):
"precision,expected_dtype",
[
(Precision(), torch.float32),
(HalfPrecision("16-true"), torch.float16),
pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(bf16_cuda=True)),
(DoublePrecision(), torch.float64),
],
)
2 changes: 2 additions & 0 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -198,6 +198,8 @@ def test_compile(compile_after_setup):
"precision,expected_dtype",
[
("32-true", torch.float32),
("16-true", torch.float16),
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
("64-true", torch.float64),
],
)
4 changes: 3 additions & 1 deletion tests/tests_fabric/strategies/test_single_device.py
@@ -16,7 +16,7 @@
import pytest
import torch

from lightning.fabric.plugins import DoublePrecision, Precision
from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
@@ -163,6 +163,8 @@ def test_single_device_grad_clipping(clip_type, precision):
"precision,dtype",
[
(Precision(), torch.float32),
(HalfPrecision("16-true"), torch.float16),
pytest.param(HalfPrecision("bf16-true"), torch.bfloat16, marks=RunIf(mps=False)),
pytest.param(DoublePrecision(), torch.float64, marks=RunIf(mps=False)),
],
)