Commit 8f81daf

Support true 16-bit precision with deepspeed in Trainer (#18217)

Authored by awaelchli and Borda
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 774ea1e commit 8f81daf

File tree

9 files changed: +119 -44 lines changed
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # validation HPU connectors
 lightning-habana >=0.1.0
-lightning-graphcore >=0.1.0.rc3
+lightning-graphcore >=0.1.0.rc4

src/lightning/fabric/plugins/precision/deepspeed.py

Lines changed: 5 additions & 1 deletion
@@ -70,8 +70,12 @@ def convert_module(self, module: Module) -> Module:

     @contextmanager
     def init_context(self) -> Generator[None, None, None]:
+        if "true" not in self.precision:
+            yield
+            return
+
         default_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(self._desired_dtype if "true" in self.precision else default_dtype)
+        torch.set_default_dtype(self._desired_dtype)
         yield
         torch.set_default_dtype(default_dtype)
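To illustrate the change (a minimal sketch, not part of the diff; it assumes the Fabric plugin class exported by this module is `DeepSpeedPrecision`): a "-true" precision swaps the default dtype inside the context, while a mixed setting now yields immediately and leaves it untouched.

import torch
from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision

# "bf16-true": modules created inside the context pick up the desired dtype
precision = DeepSpeedPrecision("bf16-true")
with precision.init_context():
    layer = torch.nn.Linear(4, 4)  # parameters are created in bfloat16
assert layer.weight.dtype == torch.bfloat16
assert torch.get_default_dtype() == torch.float32  # restored on exit

# "16-mixed": the early return keeps the default dtype at float32
with DeepSpeedPrecision("16-mixed").init_context():
    assert torch.get_default_dtype() == torch.float32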

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))


-- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193))
+- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217))


 - Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
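From the user's side, the entry above corresponds to passing one of the new precision values to the Trainer together with the DeepSpeed strategy. A minimal usage sketch (not part of this commit; it assumes a CUDA machine with the `deepspeed` package installed and uses Lightning's demo `BoringModel`):

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# True bfloat16 weights under DeepSpeed; use "16-true" for float16 weights.
trainer = Trainer(
    accelerator="cuda",
    devices=1,
    strategy="deepspeed",
    precision="bf16-true",
    max_steps=10,
)
trainer.fit(BoringModel())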

src/lightning/pytorch/plugins/precision/deepspeed.py

Lines changed: 38 additions & 5 deletions
@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, Union
+from contextlib import contextmanager
+from typing import Any, Callable, Generator, Optional, TYPE_CHECKING, Union

+import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 from typing_extensions import get_args

 import lightning.pytorch as pl
+from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
+from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
 from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from lightning.fabric.utilities.types import Steppable
 from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
@@ -31,16 +37,16 @@

 warning_cache = WarningCache()

-_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]
-

 class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     """Precision plugin for DeepSpeed integration.

     .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

     Args:
-        precision: Full precision (32), half precision (16) or bfloat16 precision (bf16).
+        precision: Full precision (32-true), half precision (16-true, bf16-true) or
+            mixed precision (16-mixed, bf16-mixed).
+
     Raises:
         ValueError:
             If unsupported ``precision`` is provided.
@@ -53,7 +59,34 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
             f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
             f" `precision` must be one of: {supported_precision}."
         )
-        self.precision = cast(_PRECISION_INPUT, str(precision))
+        self.precision = precision
+        precision_to_type = {
+            "bf16-mixed": torch.bfloat16,
+            "16-mixed": torch.float16,
+            "bf16-true": torch.bfloat16,
+            "16-true": torch.float16,
+            "32-true": torch.float32,
+        }
+        self._desired_dtype = precision_to_type[self.precision]
+
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_dtype)
+        return module
+
+    def convert_input(self, data: Any) -> Any:
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)
+
+    @contextmanager
+    def init_context(self) -> Generator[None, None, None]:
+        if "true" not in self.precision:
+            yield
+            return
+
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(self._desired_dtype)
+        yield
+        torch.set_default_dtype(default_dtype)

     def backward(  # type: ignore[override]
         self,
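A minimal sketch of the two new conversion hooks (illustration only, not part of the diff): `convert_module` casts parameters only for the "-true" settings, while `convert_input` casts floating-point tensors to the selected dtype in either mode.

import torch
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

plugin = DeepSpeedPrecisionPlugin(precision="16-true")

# convert_module casts parameters because "16-true" is a true-precision setting
module = plugin.convert_module(torch.nn.Linear(2, 2))
assert module.weight.dtype == torch.float16

# convert_input walks collections and casts only floating-point tensors
batch = {"x": torch.rand(2, 2), "y": torch.tensor([1, 2])}
batch = plugin.convert_input(batch)
assert batch["x"].dtype == torch.float16
assert batch["y"].dtype == torch.int64  # integer tensors are left untouched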

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 4 additions & 4 deletions
@@ -22,8 +22,6 @@
 from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

 import torch
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

@@ -43,7 +41,6 @@
 from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
 from lightning.pytorch.plugins.precision import PrecisionPlugin
 from lightning.pytorch.strategies.ddp import DDPStrategy
-from lightning.pytorch.strategies.utils import _fp_to_half
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -894,5 +891,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )

     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
-        batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
+        # The strategy casts the input before moving to the device
+        # In all other strategies, the input gets converted in the `Strategy.*_step` methods
+        # TODO: standardize this for all strategies
+        batch = self.precision_plugin.convert_input(batch)
         return super().batch_to_device(batch, device, dataloader_idx)
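The behaviour of the removed `_fp_to_half` helper is preserved: under a mixed setting, floating-point inputs are still cast to the matching half dtype; the cast now simply goes through the precision plugin before the device move. A minimal sketch (illustration only):

import torch
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

# "16-mixed" still casts float32 inputs to float16, as _fp_to_half did before.
plugin = DeepSpeedPrecisionPlugin(precision="16-mixed")
x = plugin.convert_input(torch.zeros(1, dtype=torch.float32))
assert x.dtype == torch.float16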

src/lightning/pytorch/strategies/utils.py

Lines changed: 0 additions & 21 deletions
@@ -13,12 +13,7 @@
 # limitations under the License.
 import importlib
 from inspect import getmembers, isclass
-from typing import Literal

-import torch
-from torch import Tensor
-
-from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.registry import _is_register_method_overridden
 from lightning.pytorch.strategies.strategy import Strategy
@@ -30,19 +25,3 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
     for _, mod in getmembers(module, isclass):
         if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
             mod.register_strategies(registry)
-
-
-def _fp_to_half(
-    tensor: Tensor,
-    precision: Literal[
-        "64-true",
-        "32-true",
-        "16-mixed",
-        "bf16-mixed",
-    ],
-) -> Tensor:
-    if str(precision) == "16-mixed":
-        return _convert_fp_tensor(tensor, torch.half)
-    if precision == "bf16-mixed":
-        return _convert_fp_tensor(tensor, torch.bfloat16)
-    return tensor

tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py

Lines changed: 53 additions & 0 deletions
@@ -13,10 +13,63 @@
 # limitations under the License.

 import pytest
+import torch

 from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin


 def test_invalid_precision_with_deepspeed_precision():
     with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
         DeepSpeedPrecisionPlugin(precision="64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.bfloat16),
+        ("16-mixed", torch.float16),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_selected_dtype(precision, expected_dtype):
+    plugin = DeepSpeedPrecisionPlugin(precision=precision)
+    assert plugin.precision == precision
+    assert plugin._desired_dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_module_init_context(precision, expected_dtype):
+    plugin = DeepSpeedPrecisionPlugin(precision=precision)
+    with plugin.init_context():
+        model = torch.nn.Linear(2, 2)
+        assert torch.get_default_dtype() == expected_dtype
+    assert model.weight.dtype == expected_dtype
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = DeepSpeedPrecisionPlugin(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype

tests/tests_pytorch/strategies/test_deepspeed_strategy.py

Lines changed: 1 addition & 1 deletion
@@ -1266,7 +1266,7 @@ def transfer_batch_to_device(self, batch, *args, **kwargs):
     model = CustomBoringModel()
     trainer = Trainer(strategy="deepspeed", devices=1, accelerator="cuda", precision="16-mixed")
     trainer.strategy.connect(model)
-    batch = torch.zeros((1), dtype=torch.float32)
+    batch = torch.zeros(1, dtype=torch.float32)
     batch = trainer.strategy.batch_to_device(batch)
     assert batch.is_cuda
     assert batch.dtype is torch.float16

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 16 additions & 10 deletions
@@ -36,6 +36,7 @@
 from lightning.pytorch.plugins.io import TorchCheckpointIO
 from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
 from lightning.pytorch.plugins.precision import (
+    DeepSpeedPrecisionPlugin,
     DoublePrecisionPlugin,
     HalfPrecisionPlugin,
     MixedPrecisionPlugin,
@@ -1015,16 +1016,21 @@ def test_connector_sets_num_nodes(strategy, cuda_count_2):


 @pytest.mark.parametrize(
-    ("precision_str", "precision_cls"),
+    ("precision_str", "strategy_str", "expected_precision_cls"),
     [
-        ("64-true", DoublePrecisionPlugin),
-        ("32-true", PrecisionPlugin),
-        ("16-true", HalfPrecisionPlugin),
-        ("bf16-true", HalfPrecisionPlugin),
-        ("16-mixed", MixedPrecisionPlugin),
-        ("bf16-mixed", MixedPrecisionPlugin),
+        ("64-true", "auto", DoublePrecisionPlugin),
+        ("32-true", "auto", PrecisionPlugin),
+        ("16-true", "auto", HalfPrecisionPlugin),
+        ("bf16-true", "auto", HalfPrecisionPlugin),
+        ("16-mixed", "auto", MixedPrecisionPlugin),
+        ("bf16-mixed", "auto", MixedPrecisionPlugin),
+        pytest.param("32-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
+        pytest.param("16-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
+        pytest.param("bf16-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
+        pytest.param("16-mixed", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
+        pytest.param("bf16-mixed", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
     ],
 )
-def test_precision_selection(precision_str, precision_cls):
-    connector = _AcceleratorConnector(precision=precision_str)
-    assert isinstance(connector.precision_plugin, precision_cls)
+def test_precision_selection(precision_str, strategy_str, expected_precision_cls):
+    connector = _AcceleratorConnector(precision=precision_str, strategy=strategy_str)
+    assert isinstance(connector.precision_plugin, expected_precision_cls)
