2 changes: 1 addition & 1 deletion requirements/_integrations/accelerators.txt
@@ -1,3 +1,3 @@
# validation HPU connectors
lightning-habana >=0.1.0
lightning-graphcore >=0.1.0.rc3
lightning-graphcore >=0.1.0.rc4
6 changes: 5 additions & 1 deletion src/lightning/fabric/plugins/precision/deepspeed.py
@@ -70,8 +70,12 @@ def convert_module(self, module: Module) -> Module:

@contextmanager
def init_context(self) -> Generator[None, None, None]:
if "true" not in self.precision:
yield
return

default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_dtype if "true" in self.precision else default_dtype)
torch.set_default_dtype(self._desired_dtype)
yield
torch.set_default_dtype(default_dtype)

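
A minimal standalone sketch (not part of this diff) of the dtype-scoping pattern the hunk above adds: mixed precisions leave torch's default dtype untouched, while "-true" precisions set it for the duration of the context and restore it on exit. The helper name is illustrative.

from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def true_precision_init_context(precision: str, desired_dtype: torch.dtype) -> Generator[None, None, None]:
    if "true" not in precision:
        # "16-mixed" / "bf16-mixed": nothing to do, parameters stay in float32
        yield
        return

    # "-true" precisions: modules created inside the context pick up the desired dtype
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(desired_dtype)
    yield
    torch.set_default_dtype(default_dtype)


# Example: parameters materialize directly in bfloat16, and the default is restored afterwards.
with true_precision_init_context("bf16-true", torch.bfloat16):
    layer = torch.nn.Linear(2, 2)
assert layer.weight.dtype is torch.bfloat16
assert torch.get_default_dtype() is torch.float32
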
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))


- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193))
- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217))


- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
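
A hedged usage sketch of the flags the half-precision changelog entry refers to (accelerator/devices values are illustrative and assume a CUDA machine with deepspeed installed; `model` stands for any LightningModule):

import lightning.pytorch as pl

trainer = pl.Trainer(
    strategy="deepspeed",
    accelerator="cuda",
    devices=1,
    precision="bf16-true",  # or "16-true" for float16 weights and inputs
)
# trainer.fit(model)
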
43 changes: 38 additions & 5 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -11,13 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, Union
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional, TYPE_CHECKING, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import get_args

import lightning.pytorch as pl
from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.types import Steppable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
@@ -31,16 +37,16 @@

warning_cache = WarningCache()

_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class DeepSpeedPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for DeepSpeed integration.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

Args:
precision: Full precision (32), half precision (16) or bfloat16 precision (bf16).
precision: Full precision (32-true), half precision (16-true, bf16-true) or
mixed precision (16-mixed, bf16-mixed).

Raises:
ValueError:
If unsupported ``precision`` is provided.
@@ -53,7 +59,34 @@ def __init__(self, precision: _PRECISION_INPUT) -> None:
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT, str(precision))
self.precision = precision
precision_to_type = {
"bf16-mixed": torch.bfloat16,
"16-mixed": torch.float16,
"bf16-true": torch.bfloat16,
"16-true": torch.float16,
"32-true": torch.float32,
}
self._desired_dtype = precision_to_type[self.precision]

def convert_module(self, module: Module) -> Module:
if "true" in self.precision:
return module.to(dtype=self._desired_dtype)
return module

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

@contextmanager
def init_context(self) -> Generator[None, None, None]:
if "true" not in self.precision:
yield
return

default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_dtype)
yield
torch.set_default_dtype(default_dtype)

def backward( # type: ignore[override]
self,
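
A small sketch of how the three new plugin methods behave for a "-true" precision, based on the hunks above and the tests later in this PR (constructing the plugin is assumed not to require a deepspeed install):

import torch
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

plugin = DeepSpeedPrecisionPlugin(precision="16-true")
assert plugin._desired_dtype is torch.float16

module = torch.nn.Linear(2, 2)            # created in float32
module = plugin.convert_module(module)    # "-true" precisions cast the module itself
assert module.weight.dtype is torch.float16

batch = {"x": torch.randn(4, 2), "idx": torch.arange(4)}
batch = plugin.convert_input(batch)       # only floating-point tensors are cast (per `_convert_fp_tensor`)
assert batch["x"].dtype is torch.float16
assert batch["idx"].dtype is torch.int64
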
8 changes: 4 additions & 4 deletions src/lightning/pytorch/strategies/deepspeed.py
Expand Up @@ -22,8 +22,6 @@
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

Expand All @@ -43,7 +41,6 @@
from lightning.pytorch.core.optimizer import _init_optimizers_and_lr_schedulers
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.utils import _fp_to_half
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -894,5 +891,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
)

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
# The strategy casts the input before moving to the device
# In all other strategies, the input gets converted in the `Strategy.*_step` methods
# TODO: standardize this for all strategies
batch = self.precision_plugin.convert_input(batch)
return super().batch_to_device(batch, device, dataloader_idx)
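
Illustrative only (the device move itself is done by the parent `Strategy.batch_to_device`): the cast now happens through the precision plugin before the batch is moved, which is what the existing `16-mixed` test further below exercises end to end.

import torch
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

plugin = DeepSpeedPrecisionPlugin(precision="16-mixed")
batch = torch.zeros(1, dtype=torch.float32)
batch = plugin.convert_input(batch)  # step 1: dtype cast (float32 -> float16 here)
# batch = batch.to(device)           # step 2: device move, handled by super().batch_to_device
assert batch.dtype is torch.float16
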
21 changes: 0 additions & 21 deletions src/lightning/pytorch/strategies/utils.py
@@ -13,12 +13,7 @@
# limitations under the License.
import importlib
from inspect import getmembers, isclass
from typing import Literal

import torch
from torch import Tensor

from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.utilities.registry import _is_register_method_overridden
from lightning.pytorch.strategies.strategy import Strategy
@@ -30,19 +25,3 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
for _, mod in getmembers(module, isclass):
if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
mod.register_strategies(registry)


def _fp_to_half(
tensor: Tensor,
precision: Literal[
"64-true",
"32-true",
"16-mixed",
"bf16-mixed",
],
) -> Tensor:
if str(precision) == "16-mixed":
return _convert_fp_tensor(tensor, torch.half)
if precision == "bf16-mixed":
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor
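
For context on the deletion above, a hedged sketch of the equivalent behaviour after this PR: instead of branching on the precision string per call, the plugin resolves the target dtype once in `__init__` and reuses `_convert_fp_tensor`. The mapping below is an illustrative subset.

import torch
from torch import Tensor

from lightning.fabric.plugins.precision.utils import _convert_fp_tensor

_MIXED_PRECISION_TO_DTYPE = {  # mirrors part of the dict added to DeepSpeedPrecisionPlugin.__init__
    "16-mixed": torch.float16,
    "bf16-mixed": torch.bfloat16,
}


def fp_to_half_equivalent(tensor: Tensor, precision: str) -> Tensor:
    # Same outcome as the removed `_fp_to_half`: cast floating-point tensors for
    # the mixed 16-bit precisions, pass everything else through unchanged.
    dst_type = _MIXED_PRECISION_TO_DTYPE.get(precision)
    return _convert_fp_tensor(tensor, dst_type) if dst_type is not None else tensor
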
@@ -13,10 +13,63 @@
# limitations under the License.

import pytest
import torch

from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin


def test_invalid_precision_with_deepspeed_precision():
with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
DeepSpeedPrecisionPlugin(precision="64-true")


@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("32-true", torch.float32),
("bf16-mixed", torch.bfloat16),
("16-mixed", torch.float16),
("bf16-true", torch.bfloat16),
("16-true", torch.float16),
],
)
def test_selected_dtype(precision, expected_dtype):
plugin = DeepSpeedPrecisionPlugin(precision=precision)
assert plugin.precision == precision
assert plugin._desired_dtype == expected_dtype


@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("32-true", torch.float32),
("bf16-mixed", torch.float32),
("16-mixed", torch.float32),
("bf16-true", torch.bfloat16),
("16-true", torch.float16),
],
)
def test_module_init_context(precision, expected_dtype):
plugin = DeepSpeedPrecisionPlugin(precision=precision)
with plugin.init_context():
model = torch.nn.Linear(2, 2)
assert torch.get_default_dtype() == expected_dtype
assert model.weight.dtype == expected_dtype


@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("32-true", torch.float32),
("bf16-mixed", torch.float32),
("16-mixed", torch.float32),
("bf16-true", torch.bfloat16),
("16-true", torch.float16),
],
)
def test_convert_module(precision, expected_dtype):
precision = DeepSpeedPrecisionPlugin(precision=precision)
module = torch.nn.Linear(2, 2)
assert module.weight.dtype == module.bias.dtype == torch.float32
module = precision.convert_module(module)
assert module.weight.dtype == module.bias.dtype == expected_dtype
@@ -1266,7 +1266,7 @@ def transfer_batch_to_device(self, batch, *args, **kwargs):
model = CustomBoringModel()
trainer = Trainer(strategy="deepspeed", devices=1, accelerator="cuda", precision="16-mixed")
trainer.strategy.connect(model)
batch = torch.zeros((1), dtype=torch.float32)
batch = torch.zeros(1, dtype=torch.float32)
batch = trainer.strategy.batch_to_device(batch)
assert batch.is_cuda
assert batch.dtype is torch.float16
@@ -36,6 +36,7 @@
from lightning.pytorch.plugins.io import TorchCheckpointIO
from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
from lightning.pytorch.plugins.precision import (
DeepSpeedPrecisionPlugin,
DoublePrecisionPlugin,
HalfPrecisionPlugin,
MixedPrecisionPlugin,
@@ -1015,16 +1016,21 @@ def test_connector_sets_num_nodes(strategy, cuda_count_2):


@pytest.mark.parametrize(
("precision_str", "precision_cls"),
("precision_str", "strategy_str", "expected_precision_cls"),
[
("64-true", DoublePrecisionPlugin),
("32-true", PrecisionPlugin),
("16-true", HalfPrecisionPlugin),
("bf16-true", HalfPrecisionPlugin),
("16-mixed", MixedPrecisionPlugin),
("bf16-mixed", MixedPrecisionPlugin),
("64-true", "auto", DoublePrecisionPlugin),
("32-true", "auto", PrecisionPlugin),
("16-true", "auto", HalfPrecisionPlugin),
("bf16-true", "auto", HalfPrecisionPlugin),
("16-mixed", "auto", MixedPrecisionPlugin),
("bf16-mixed", "auto", MixedPrecisionPlugin),
pytest.param("32-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
pytest.param("16-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
pytest.param("bf16-true", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
pytest.param("16-mixed", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
pytest.param("bf16-mixed", "deepspeed", DeepSpeedPrecisionPlugin, marks=RunIf(deepspeed=True, mps=False)),
],
)
def test_precision_selection(precision_str, precision_cls):
connector = _AcceleratorConnector(precision=precision_str)
assert isinstance(connector.precision_plugin, precision_cls)
def test_precision_selection(precision_str, strategy_str, expected_precision_cls):
connector = _AcceleratorConnector(precision=precision_str, strategy=strategy_str)
assert isinstance(connector.precision_plugin, expected_precision_cls)