
Commit ac5fa03

Introduce new precision layout in fabric (#16767)
1 parent 3a354ac commit ac5fa03

33 files changed (+214 / -135 lines)

docs/source-pytorch/fabric/api/fabric_args.rst

Lines changed: 9 additions & 5 deletions
@@ -112,23 +112,27 @@ Learn more about :ref:`distributed multi-node training on clusters <Fabric Clust
 precision
 =========
 
-Fabric supports double precision (64), full precision (32), or half-precision (16) operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
+Fabric supports double precision (64 bit), full precision (32 bit), or half-precision (16 bit) floating point operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
 Half precision, or mixed precision, combines 32 and 16-bit floating points to reduce the memory footprint during model training.
+Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while settings that only work in the specified precision have a ``"-true"`` suffix.
 This can result in improved performance, achieving significant speedups on modern GPUs.
 
 .. code-block:: python
 
     # Default used by the Fabric
-    fabric = Fabric(precision=32, devices=1)
+    fabric = Fabric(precision="32-true", devices=1)
+
+    # the same as:
+    fabric = Fabric(precision="32", devices=1)
 
     # 16-bit (mixed) precision
-    fabric = Fabric(precision=16, devices=1)
+    fabric = Fabric(precision="16-mixed", devices=1)
 
     # 16-bit bfloat precision
-    fabric = Fabric(precision="bf16", devices=1)
+    fabric = Fabric(precision="bf16-mixed", devices=1)
 
     # 64-bit (double) precision
-    fabric = Fabric(precision=64, devices=1)
+    fabric = Fabric(precision="64-true", devices=1)
 
 See also: :doc:`../fundamentals/precision`
 
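
The ``"-true"`` / ``"-mixed"`` naming scheme added above is purely mechanical, so a setting's behaviour can be read off its suffix. A minimal sketch of that rule (the helper below is illustrative only, not a Lightning API):

    # "-mixed" -> automatic mixed precision (torch.autocast under the hood)
    # "-true"  -> every operation runs in the stated dtype
    def uses_autocast(precision: str) -> bool:
        return precision.endswith("-mixed")

    assert uses_autocast("16-mixed") and uses_autocast("bf16-mixed")
    assert not uses_autocast("32-true") and not uses_autocast("64-true")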

docs/source-pytorch/fabric/fundamentals/launch.rst

Lines changed: 7 additions & 3 deletions
@@ -68,9 +68,13 @@ This is essentially the same as running ``python path/to/your/script.py``, but i
   --main-port, --main_port INTEGER
                                   The main port to connect to the main
                                   machine.
-  --precision [64|32|16|bf16]     Double precision (``64``), full precision
-                                  (``32``), half precision (``16``) or
-                                  bfloat16 precision (``'bf16'``)
+  --precision [16-mixed|bf16-mixed|32-true|64-true|64|32|16|bf16]
+                                  Double precision (``64-true`` or ``64``),
+                                  full precision (``32-true`` or ``32``), half
+                                  precision (``16-mixed`` or ``16``) or
+                                  bfloat16 precision (``bf16-mixed`` or
+                                  ``bf16``)
+
   --help                          Show this message and exit.
 

docs/source-pytorch/fabric/fundamentals/precision.rst

Lines changed: 19 additions & 13 deletions
@@ -24,26 +24,35 @@ This is how you select the precision in Fabric:
     from lightning.fabric import Fabric
 
     # This is the default
+    fabric = Fabric(precision="32-true")
+
+    # Also FP32
     fabric = Fabric(precision=32)
 
-    # FP16 mixed precision
-    fabric = Fabric(precision=16)
+    # FP32 as well
+    fabric = Fabric(precision="32")
 
-    # Precision values can also be set as a string
-    fabric = Fabric(precision="16")
+    # FP16 mixed precision
+    fabric = Fabric(precision="16-mixed")
 
     # BFloat16 precision (Volta GPUs and later)
-    fabric = Fabric(precision="bf16")
+    fabric = Fabric(precision="bf16-mixed")
 
     # Double precision
+    fabric = Fabric(precision="64-true")
+
+    # Or
+    fabric = Fabric(precision="64")
+
+    # Or
     fabric = Fabric(precision=64)
 
 
 The same values can also be set through the :doc:`command line interface <launch>`:
 
 .. code-block:: bash
 
-    lightning run model train.py --precision=bf16
+    lightning run model train.py --precision=bf16-mixed
 
 
 .. note::
@@ -70,14 +79,11 @@ This is how you enable FP16 in Fabric:
 .. code-block:: python
 
     # Select FP16 mixed precision
-    fabric = Fabric(precision=16)
-
-    # Or as a string
-    fabric = Fabric(precision="16")
+    fabric = Fabric(precision="16-mixed")
 
 .. note::
 
-    When using TPUs, setting ``precision=16`` will enable bfloat16, the only supported half-precision type on TPUs.
+    When using TPUs, setting ``precision="16-mixed"`` will enable bfloat16 based mixed precision, the only supported half-precision type on TPUs.
 
 
 ----
@@ -94,7 +100,7 @@ For more information, see `this TPU performance blog post <https://cloud.google.
 .. code-block:: python
 
     # Select BF16 precision
-    fabric = Fabric(precision="bf16")
+    fabric = Fabric(precision="bf16-mixed")
 
 
 Under the hood, we use `torch.autocast <https://pytorch.org/docs/stable/amp.html>`__ with the dtype set to ``bfloat16``, with no gradient scaling.
@@ -117,7 +123,7 @@ Fabric automatically casts the data type and operations in the ``forward`` of yo
 
 .. code-block:: python
 
-    fabric = Fabric(precision="bf16")
+    fabric = Fabric(precision="bf16-mixed")
 
     model = ...
     optimizer = ...
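
For readers wondering what ``"16-mixed"`` corresponds to in plain PyTorch: Fabric wraps the forward pass in ``torch.autocast`` and, for float16 only, adds gradient scaling. The loop below is a rough standalone sketch of that behaviour (it assumes a CUDA device and is not Lightning's internal code):

    import torch

    model = torch.nn.Linear(32, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()  # fp16 needs loss scaling; "bf16-mixed" skips this

    for _ in range(3):
        inputs = torch.randn(8, 32, device="cuda")
        with torch.autocast("cuda", dtype=torch.float16):  # torch.bfloat16 for "bf16-mixed"
            loss = model(inputs).sum()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()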

docs/source-pytorch/fabric/guide/multi_node/cloud.rst

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ Launch multi-node training in the cloud
     def run(self):
         # Set up Fabric
         # The `devices` and `num_nodes` gets set by Lightning automatically
-        fabric = L.Fabric(strategy="ddp", precision=16)
+        fabric = L.Fabric(strategy="ddp", precision="16-mixed")
 
         # Your training code
         model = ...

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))
 
 
+- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767))
+
 ### Deprecated
 
 -

src/lightning/fabric/cli.py

Lines changed: 6 additions & 5 deletions
@@ -18,8 +18,10 @@
 from typing import Any, List, Optional
 
 from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import get_args
 
 from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
+from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
 from lightning.fabric.strategies import STRATEGY_REGISTRY
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 
@@ -28,7 +30,6 @@
 _CLICK_AVAILABLE = RequirementCache("click")
 
 _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
-_SUPPORTED_PRECISION = ("64", "32", "16", "bf16")
 
 
 def _get_supported_strategies() -> List[str]:
@@ -106,11 +107,11 @@ def _get_supported_strategies() -> List[str]:
 )
 @click.option(
     "--precision",
-    type=click.Choice(_SUPPORTED_PRECISION),
-    default="32",
+    type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
+    default="32-true",
     help=(
-        "Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
-        " (``'bf16'``)"
+        "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), "
+        "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
     ),
 )
 @click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
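
The pattern above, deriving the ``click.Choice`` values from the same ``Literal`` aliases that type the connector, keeps the CLI and the Python API from drifting apart. A self-contained sketch of that pattern (the ``Literal`` members mirror the values used in this diff; the real aliases live in ``lightning.fabric.plugins.precision.precision`` and are not shown here):

    import click
    from typing_extensions import Literal, get_args

    # Illustrative aliases, not the real definitions
    _PRECISION_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"]
    _PRECISION_ALIAS = Literal["64", "32", "16", "bf16"]


    @click.command()
    @click.option(
        "--precision",
        type=click.Choice(get_args(_PRECISION_STR) + get_args(_PRECISION_ALIAS)),
        default="32-true",
    )
    def main(precision: str) -> None:
        click.echo(f"precision={precision}")


    if __name__ == "__main__":
        main()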

src/lightning/fabric/connector.py

Lines changed: 45 additions & 23 deletions
@@ -42,7 +42,13 @@
 )
 from lightning.fabric.plugins.precision.double import DoublePrecision
 from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
-from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
+from lightning.fabric.plugins.precision.precision import (
+    _PRECISION_INPUT,
+    _PRECISION_INPUT_INT,
+    _PRECISION_INPUT_STR,
+    _PRECISION_INPUT_STR_ALIAS,
+    _PRECISION_INPUT_STR_ALIAS_CONVERSION,
+)
 from lightning.fabric.strategies import (
     DeepSpeedStrategy,
     ParallelStrategy,
@@ -98,7 +104,7 @@ def __init__(
         strategy: Optional[Union[str, Strategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: _PRECISION_INPUT = 32,
+        precision: _PRECISION_INPUT = "32-true",
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
     ) -> None:
 
@@ -107,7 +113,7 @@ def __init__(
         strategy = self._argument_from_env("strategy", strategy, default=None)
        devices = self._argument_from_env("devices", devices, default=None)
         num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
-        precision = self._argument_from_env("precision", precision, default=32)
+        precision = self._argument_from_env("precision", precision, default="32-true")
 
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
@@ -119,7 +125,7 @@ def __init__(
         # For devices: Assign gpus, etc. to the accelerator flag and devices flag
         self._strategy_flag: Optional[Union[Strategy, str]] = None
         self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_input: _PRECISION_INPUT_STR = "32"
+        self._precision_input: _PRECISION_INPUT_STR = "32-true"
         self._precision_instance: Optional[Precision] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -220,10 +226,7 @@ def _check_config_and_set_final_flags(
 
         self._accelerator_flag = accelerator
 
-        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
-        if precision not in supported_precision:
-            raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
-        self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))
+        self._precision_input = _convert_precision_to_unified_args(precision)
 
         if plugins:
             plugins_flags_types: Dict[str, int] = Counter()
@@ -453,34 +456,34 @@ def _check_and_init_precision(self) -> Precision:
             return self._precision_instance
 
         if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == "32":
+            if self._precision_input == "32-true":
                 return TPUPrecision()
-            elif self._precision_input in ("16", "bf16"):
-                if self._precision_input == "16":
+            elif self._precision_input in ("16-mixed", "bf16-mixed"):
+                if self._precision_input == "16-mixed":
                     rank_zero_warn(
-                        "You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
-                        " is not supported with TPUs. Using `precision='bf16'` instead."
+                        "You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
+                        " is not supported with TPUs. Using `precision='bf16-mixed'` instead."
                     )
                 return TPUBf16Precision()
         if isinstance(self.strategy, DeepSpeedStrategy):
            return DeepSpeedPrecision(self._precision_input)  # type: ignore
 
-        if self._precision_input == "32":
+        if self._precision_input == "32-true":
             return Precision()
-        if self._precision_input == "64":
+        if self._precision_input == "64-true":
             return DoublePrecision()
 
-        if self._precision_input == "16" and self._accelerator_flag == "cpu":
+        if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
             rank_zero_warn(
-                "You passed `Fabric(accelerator='cpu', precision=16)` but AMP is not supported on CPU."
-                " Using `precision='bf16'` instead."
+                "You passed `Fabric(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on "
+                "CPU. Using `precision='bf16-mixed'` instead."
             )
-            self._precision_input = "bf16"
+            self._precision_input = "bf16-mixed"
 
-        if self._precision_input in ("16", "bf16"):
+        if self._precision_input in ("16-mixed", "bf16-mixed"):
             rank_zero_info(
                 "Using 16-bit Automatic Mixed Precision (AMP)"
-                if self._precision_input == "16"
+                if self._precision_input == "16-mixed"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -494,9 +497,9 @@ def _check_and_init_precision(self) -> Precision:
     def _validate_precision_choice(self) -> None:
         """Validate the combination of choices for precision, and accelerator."""
         if isinstance(self.accelerator, TPUAccelerator):
-            if self._precision_input == "64":
+            if self._precision_input == "64-true":
                 raise NotImplementedError(
-                    "`Fabric(accelerator='tpu', precision=64)` is not implemented."
+                    "`Fabric(accelerator='tpu', precision='64-true')` is not implemented."
                     " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
                     " requesting this feature."
                 )
@@ -561,3 +564,22 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any:
         if env_value is None:
             return current
         return env_value
+
+
+def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR:
+    supported_precision = (
+        get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_ALIAS)
+    )
+    if precision not in supported_precision:
+        raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
+
+    precision = str(precision)  # convert int flags to str here to enable the legacy-conversion below
+
+    if precision in get_args(_PRECISION_INPUT_STR_ALIAS):
+        if str(precision)[:2] not in ("32", "64"):
+            rank_zero_warn(
+                f"{precision} is supported for historical reasons but its usage is discouraged. "
+                f"Please set your precision to {_PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]} instead!"
+            )
+        precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]
+    return cast(_PRECISION_INPUT_STR, precision)
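
`_convert_precision_to_unified_args` leans on type aliases and a conversion table that are imported at the top of the file but defined in `precision.py`, which is not part of this diff. A plausible shape for those definitions, inferred from the values used throughout this commit (treat it as an assumption, not the actual source):

    from typing import Dict, Union
    from typing_extensions import Literal

    _PRECISION_INPUT_INT = Literal[64, 32, 16]
    _PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"]
    _PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"]
    _PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR_ALIAS, _PRECISION_INPUT_STR]

    # legacy value -> unified value, consumed by _convert_precision_to_unified_args
    _PRECISION_INPUT_STR_ALIAS_CONVERSION: Dict[str, str] = {
        "64": "64-true",
        "32": "32-true",
        "16": "16-mixed",
        "bf16": "bf16-mixed",
    }

Note that the `[:2] not in ("32", "64")` guard means `"32"` and `"64"` are converted silently, while `"16"` and `"bf16"` additionally trigger the discouragement warning.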

src/lightning/fabric/fabric.py

Lines changed: 3 additions & 3 deletions
@@ -67,8 +67,8 @@ class Fabric:
         devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
             The value applies per node.
         num_nodes: Number of GPU nodes for distributed training.
-        precision: Double precision (``64``), full precision (``32``), half precision (``16``),
-            or bfloat16 precision (``"bf16"``).
+        precision: Double precision (``"64-true"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
+            or bfloat16 precision AMP (``"bf16-mixed"``).
         plugins: One or several custom plugins
         callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
             can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
@@ -82,7 +82,7 @@ def __init__(
         strategy: Optional[Union[str, Strategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: _PRECISION_INPUT = 32,
+        precision: _PRECISION_INPUT = "32-true",
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
         callbacks: Optional[Union[List[Any], Any]] = None,
         loggers: Optional[Union[Logger, List[Logger]]] = None,
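
With the default now being `"32-true"`, a typical Fabric script simply passes one of the unified strings; the rest of the public API is unchanged. A short usage sketch (CPU, bfloat16 mixed precision, dummy data):

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1, precision="bf16-mixed")
    fabric.launch()

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    for _ in range(3):
        batch = torch.randn(8, 32)
        optimizer.zero_grad()
        loss = model(batch).sum()  # forward runs under the configured precision context
        fabric.backward(loss)      # applies gradient scaling when the plugin uses one
        optimizer.step()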

src/lightning/fabric/plugins/precision/amp.py

Lines changed: 11 additions & 8 deletions
@@ -29,21 +29,24 @@ class MixedPrecision(Precision):
     """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.
 
     Args:
-        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
+        precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
         device: The device for ``torch.autocast``.
         scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
     """
 
     def __init__(
-        self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+        self,
+        precision: Literal["16-mixed", "bf16-mixed"],
+        device: str,
+        scaler: Optional[torch.cuda.amp.GradScaler] = None,
     ) -> None:
-        self.precision = cast(Literal["16", "bf16"], str(precision))
-        if scaler is None and self.precision == "16":
+        self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
+        if scaler is None and self.precision == "16-mixed":
             with _patch_cuda_is_available():
                 # if possible, we defer CUDA initialization to support strategies that will attempt forks
                 scaler = torch.cuda.amp.GradScaler()
-        if scaler is not None and self.precision == "bf16":
-            raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
+        if scaler is not None and self.precision == "bf16-mixed":
+            raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
         self.scaler = scaler
 
@@ -53,7 +56,7 @@ def forward_context(self) -> Generator[None, None, None]:
             yield
 
     def convert_input(self, data: Tensor) -> Tensor:
-        precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16}
+        precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16}
         dst_type = precision_to_type[self.precision]
         return _convert_fp_tensor(data, dst_type)
 
@@ -89,4 +92,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _autocast_context_manager(self) -> torch.autocast:
         # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
         # https://github.com/pytorch/pytorch/issues/67233
-        return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
+        return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)
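
The updated plugin can also be instantiated directly and handed to `Fabric` through the `plugins` argument instead of the string flag. A usage sketch based on the signature above (assumes a CUDA machine):

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins.precision.amp import MixedPrecision

    # fp16 AMP: a GradScaler is created automatically when none is passed
    amp_fp16 = MixedPrecision(precision="16-mixed", device="cuda")

    # bf16 AMP: passing a scaler would raise the ValueError shown above
    amp_bf16 = MixedPrecision(precision="bf16-mixed", device="cuda")

    # convert_input casts floating-point tensors to the plugin's dtype
    assert amp_fp16.convert_input(torch.randn(4, 4)).dtype == torch.float16

    fabric = Fabric(accelerator="cuda", devices=1, plugins=amp_fp16)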
