2 changes: 1 addition & 1 deletion docs/source-pytorch/api_references.rst
@@ -114,7 +114,7 @@ precision

DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FSDPMixedPrecisionPlugin
FSDPPrecisionPlugin
HalfPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
2 changes: 1 addition & 1 deletion docs/source-pytorch/extensions/plugins.rst
@@ -54,7 +54,7 @@ The full list of built-in precision plugins is listed below.

DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FSDPMixedPrecisionPlugin
FSDPPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
XLABf16PrecisionPlugin
2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
@@ -559,7 +559,7 @@ def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISIO
if precision in get_args(_PRECISION_INPUT_STR_ALIAS):
if str(precision)[:2] not in ("32", "64"):
rank_zero_warn(
f"{precision} is supported for historical reasons but its usage is discouraged. "
f"`precision={precision}` is supported for historical reasons but its usage is discouraged. "
f"Please set your precision to {_PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]} instead!"
)
precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]
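For context, a minimal sketch of the user-facing behavior behind the reworded warning above (assuming the public `Fabric` constructor; the alias conversion table itself is unchanged):

from lightning.fabric import Fabric

# Legacy aliases still work but are discouraged; the warning now names the flag explicitly:
#   `precision=bf16` is supported for historical reasons but its usage is discouraged.
#   Please set your precision to bf16-mixed instead!
Fabric(precision="bf16")        # warns and is converted internally to "bf16-mixed"
Fabric(precision="bf16-mixed")  # preferred spelling, no warning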
16 changes: 8 additions & 8 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -35,7 +35,7 @@


class FSDPPrecision(Precision):
"""Precision plugin training with Fully Sharded Data Parallel (FSDP).
"""Precision plugin for training with Fully Sharded Data Parallel (FSDP).

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

@@ -150,6 +150,13 @@ def optimizer_step(
self.scaler.update()
return step_output

def unscale_gradients(self, optimizer: Optimizer) -> None:
scaler = self.scaler
if scaler is not None:
if _optimizer_handles_unscaling(optimizer):
raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
scaler.unscale_(optimizer) # type: ignore[arg-type] # ShardedGradScaler has wrong type annotation

def state_dict(self) -> Dict[str, Any]:
if self.scaler is not None:
return self.scaler.state_dict()
@@ -163,10 +163,3 @@ def _autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast("cuda", dtype=self._desired_input_dtype)

def unscale_gradients(self, optimizer: Optimizer) -> None:
scaler = self.scaler
if scaler is not None:
if _optimizer_handles_unscaling(optimizer):
raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
scaler.unscale_(optimizer) # type: ignore[arg-type] # ShardedGradScaler has wrong type annotation
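The relocated `unscale_gradients` follows the usual AMP ordering: gradients must be unscaled before they can be clipped, and optimizers that unscale inside their own `step()` are rejected. A minimal single-device sketch of that ordering with the plain `torch.cuda.amp.GradScaler` (the sharded variant behaves the same way but needs an initialized FSDP setup, so it is not shown here; requires a CUDA device):

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(torch.randn(2, 4, device="cuda")).sum()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale first, mirroring `unscale_gradients`
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)      # the step is skipped automatically if non-finite gradients are found
scaler.update()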
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` ([#18194](https://github.com/Lightning-AI/lightning/pull/18194))


- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217))
- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` ([#18193](https://github.com/Lightning-AI/lightning/pull/18193), [#18217](https://github.com/Lightning-AI/lightning/pull/18217), [#18219](https://github.com/Lightning-AI/lightning/pull/18219))


- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised ([#18218](https://github.com/Lightning-AI/lightning/pull/18218))
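As a rough illustration of the true half-precision entry above (a sketch only; the exact flag handling is defined by the connector changes further down), "bf16-true" keeps the model weights in bfloat16 instead of autocasting around float32 weights:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# "bf16-true" runs weights and computation in bfloat16 end to end,
# whereas "bf16-mixed" keeps float32 weights and relies on autocast.
trainer = Trainer(precision="bf16-true", max_epochs=1, limit_train_batches=2,
                  logger=False, enable_checkpointing=False)
trainer.fit(BoringModel())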
3 changes: 2 additions & 1 deletion src/lightning/pytorch/plugins/__init__.py
@@ -6,7 +6,7 @@
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin
from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
@@ -26,6 +26,7 @@
"MixedPrecisionPlugin",
"PrecisionPlugin",
"FSDPMixedPrecisionPlugin",
"FSDPPrecisionPlugin",
"XLAPrecisionPlugin",
"XLABf16PrecisionPlugin",
"LayerSync",
3 changes: 2 additions & 1 deletion src/lightning/pytorch/plugins/precision/__init__.py
@@ -14,7 +14,7 @@
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin
from lightning.pytorch.plugins.precision.half import HalfPrecisionPlugin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.plugins.precision.xla import XLAPrecisionPlugin
@@ -24,6 +24,7 @@
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"FSDPMixedPrecisionPlugin",
"FSDPPrecisionPlugin",
"HalfPrecisionPlugin",
"MixedPrecisionPlugin",
"PrecisionPlugin",
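With the export lists above updated in both `lightning.pytorch.plugins` and `lightning.pytorch.plugins.precision`, the new plugin is importable from either namespace; a quick sanity check:

from lightning.pytorch.plugins import FSDPPrecisionPlugin as FromPlugins
from lightning.pytorch.plugins.precision import FSDPPrecisionPlugin as FromPrecision

assert FromPlugins is FromPrecision  # both exports point at the same class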
4 changes: 2 additions & 2 deletions src/lightning/pytorch/plugins/precision/amp.py
@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, cast, Dict, Generator, Literal, Optional, Union
from typing import Any, Callable, Dict, Generator, Literal, Optional, Union

import torch
from torch import Tensor
@@ -46,7 +46,7 @@ def __init__(
f" Precision must be '16-mixed' or 'bf16-mixed'."
)

self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
self.precision = precision
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
147 changes: 130 additions & 17 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -12,36 +12,71 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, Literal, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, Generator, Literal, Optional, TYPE_CHECKING

import torch

from lightning_utilities import apply_to_collection
from torch import Tensor
from typing_extensions import get_args

import lightning.pytorch as pl
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
"""AMP for Fully Sharded Data Parallel (FSDP) Training.
class FSDPPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for training with Fully Sharded Data Parallel (FSDP).

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

Args:
precision: Full precision (32-true), half precision (16-true, bf16-true) or
mixed precision (16-mixed, bf16-mixed).
scaler: An optional :class:`torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler` to use.

Raises:
ValueError:
If unsupported ``precision`` is provided.

"""

def __init__(
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
) -> None:
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")
raise NotImplementedError("`FSDPPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")

supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r}` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

super().__init__(
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None)
)
if scaler is not None and precision != "16-mixed":
raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")

self.scaler = ShardedGradScaler() if scaler is None and precision == "16-mixed" else scaler
self.precision = precision

precision_to_type = {
"bf16-mixed": torch.bfloat16,
"16-mixed": torch.float16,
"bf16-true": torch.bfloat16,
"16-true": torch.float16,
"32-true": torch.float32,
}
self._desired_input_dtype = precision_to_type[self.precision]

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
@@ -66,6 +101,8 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "32-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float32
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

@@ -88,10 +125,86 @@ def init_context(self) -> Generator[None, None, None]:
torch.set_default_dtype(default_dtype)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""For FSDP, this context manager is a no-op since conversion is already handled internally.
def forward_context(self) -> Generator:
if "mixed" in self.precision:
with self._autocast_context_manager():
yield
else:
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
yield
torch.set_default_dtype(default_dtype)

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())

def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override]
if self.scaler is not None:
tensor = self.scaler.scale(tensor) # type: ignore[assignment]
return super().pre_backward(tensor, module)

def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
model: "pl.LightningModule",
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
if self.scaler is None:
# skip the scaler logic, as this precision setting does not use a gradient scaler
return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
closure_result = closure()

if not _optimizer_handles_unscaling(optimizer):
# Unscaling needs to be performed here in case we are going to apply gradient clipping.
# Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
# Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
self.scaler.unscale_(optimizer) # type: ignore[arg-type]

self._after_closure(model, optimizer)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
step_output = self.scaler.step(optimizer, **kwargs) # type: ignore[arg-type]
self.scaler.update()
return step_output
return closure_result

def state_dict(self) -> Dict[str, Any]:
if self.scaler is not None:
return self.scaler.state_dict()
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)

def _autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast("cuda", dtype=self._desired_input_dtype)


class FSDPMixedPrecisionPlugin(FSDPPrecisionPlugin):
"""AMP for Fully Sharded Data Parallel (FSDP) Training.

.. deprecated:: Use :class:`FSDPPrecisionPlugin` instead.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

"""
See: https://pytorch.org/docs/stable/fsdp.html for more details on mixed precision.

"""
yield

def __init__(
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
) -> None:
rank_zero_deprecation(
f"The `{type(self).__name__}` is deprecated."
" Use `lightning.pytorch.plugins.precision.FSDPPrecisionPlugin` instead."
)
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")
super().__init__(precision=precision, scaler=scaler)
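A short sketch of how the two classes above now relate (requires PyTorch >= 1.12; the deprecation warning is the one emitted by the `__init__` shown above):

from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin, FSDPPrecisionPlugin

# Preferred: a single plugin covering true, mixed, and full precision.
plugin = FSDPPrecisionPlugin("bf16-true")

# Deprecated: still constructible, but warns and simply delegates to FSDPPrecisionPlugin.
legacy = FSDPMixedPrecisionPlugin(precision="16-mixed", device="cuda")
assert isinstance(legacy, FSDPPrecisionPlugin)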
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -51,7 +51,7 @@
from lightning.fabric.utilities.types import ProcessGroup, ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
@@ -214,7 +214,7 @@ def mixed_precision_config(self) -> Optional["MixedPrecision"]:
if self.mixed_precision:
return self.mixed_precision
plugin = self.precision_plugin
if isinstance(plugin, FSDPMixedPrecisionPlugin):
if isinstance(plugin, FSDPPrecisionPlugin):
return plugin.mixed_precision_config
return None

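For reference, the config the strategy now reads from `FSDPPrecisionPlugin` maps onto torch's `MixedPrecision` settings; a sketch of what "16-mixed" resolves to (the dtypes match the parametrized expectations in the Fabric test further down):

import torch
from torch.distributed.fsdp import MixedPrecision

# Roughly what FSDPPrecisionPlugin("16-mixed").mixed_precision_config returns:
# parameters stay float32 while compute, gradient reduction, and buffers use float16.
mp_config = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)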
@@ -39,6 +39,7 @@
CheckpointIO,
DeepSpeedPrecisionPlugin,
DoublePrecisionPlugin,
FSDPPrecisionPlugin,
HalfPrecisionPlugin,
MixedPrecisionPlugin,
PLUGIN_INPUT,
@@ -47,7 +48,6 @@
XLAPrecisionPlugin,
)
from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.strategies import (
DDPStrategy,
DeepSpeedStrategy,
@@ -526,7 +526,8 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecisionPlugin(self._precision_flag) # type: ignore[arg-type]

if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecisionPlugin(self._precision_flag) # type: ignore[arg-type]
if self._precision_flag in ("16-true", "bf16-true"):
return HalfPrecisionPlugin(self._precision_flag) # type: ignore
if self._precision_flag == "32-true":
@@ -546,9 +547,6 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

if isinstance(self.strategy, FSDPStrategy):
return FSDPMixedPrecisionPlugin(self._precision_flag, device) # type: ignore[arg-type]
return MixedPrecisionPlugin(self._precision_flag, device) # type: ignore[arg-type]

raise RuntimeError("No precision set")
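Roughly, the user-visible effect of the connector change above, sketched under the assumption of a machine with at least one CUDA device and FSDP available:

from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin

# Any precision flag combined with the FSDP strategy now resolves to FSDPPrecisionPlugin,
# not only the mixed-precision flags that the old AMP branch handled.
trainer = Trainer(accelerator="cuda", devices=1, strategy="fsdp", precision="16-true")
assert isinstance(trainer.strategy.precision_plugin, FSDPPrecisionPlugin)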
4 changes: 3 additions & 1 deletion tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -33,7 +33,9 @@ def test_fsdp_precision_support(*_):
[
("16-mixed", (torch.float32, torch.float16, torch.float16)),
("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
# TODO: add 16-true and bf16-true once supported
("16-true", (torch.float16, torch.float16, torch.float16)),
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
("32-true", (torch.float32, torch.float32, torch.float32)),
],
)
def test_fsdp_precision_config(precision, expected):
2 changes: 1 addition & 1 deletion tests/tests_fabric/test_connector.py
@@ -471,7 +471,7 @@ def test_precision_conversion(patch1, patch2, precision, expected_precision, sho
with warn_context(
UserWarning,
match=(
f"{precision} is supported for historical reasons but its usage is discouraged. "
f"{precision}` is supported for historical reasons but its usage is discouraged. "
f"Please set your precision to {expected_precision} instead!"
),
):
@@ -5,6 +5,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.precision.double import LightningDoublePrecisionModule
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
from tests_pytorch.helpers.runif import RunIf

@@ -49,3 +50,9 @@ def test_fsdp_activation_checkpointing(monkeypatch):
def test_double_precision_wrapper():
with pytest.deprecated_call(match=r"The `LightningDoublePrecisionModule` is deprecated and no longer needed"):
LightningDoublePrecisionModule(BoringModel())


@RunIf(min_torch="1.12")
def test_fsdp_mixed_precision_plugin():
with pytest.deprecated_call(match=r"The `FSDPMixedPrecisionPlugin` is deprecated"):
FSDPMixedPrecisionPlugin(precision="16-mixed", device="cuda")