5 changes: 4 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -40,7 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding


- Added supports for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))


- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))


- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
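
A minimal usage sketch tying together the `empty_init` and `init_module` entries above (the layer size and device count are illustrative, not part of this changelog):

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
    fabric.launch()

    # With empty_init=True (and a recent PyTorch), parameters are created on the
    # meta device, so no memory is allocated for the full model up front.
    with fabric.init_module(empty_init=True):
        model = torch.nn.Linear(4096, 4096, bias=False)

    # Parameters are materialized, reset, and sharded during setup.
    model = fabric.setup(model)
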
45 changes: 31 additions & 14 deletions src/lightning/fabric/fabric.py
@@ -41,6 +41,7 @@
Strategy,
XLAStrategy,
)
from lightning.fabric.strategies.fsdp import _has_meta_device_parameters
from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
from lightning.fabric.utilities import move_data_to_device
@@ -224,8 +225,9 @@ def setup(

module = _FabricModule(module, self._precision, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
if not _has_meta_device_parameters(module):
# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

optimizers = [
_FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
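
The `_has_meta_device_parameters` guard above exists because meta-device parameters cannot be moved with `.to()`; a small illustration of the failure it avoids (behavior assumed for recent PyTorch releases, not part of this diff):

    import torch.nn as nn

    meta_module = nn.Linear(2, 2, device="meta")  # parameters carry no data
    try:
        meta_module.to("cpu")  # cannot copy out of a meta tensor
    except NotImplementedError as err:
        print(err)  # PyTorch points to Module.to_empty() for meta -> real-device moves
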
@@ -384,7 +386,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =
if isinstance(self._strategy, DeepSpeedStrategy):
if model is None:
if self._models_setup == 0:
raise RuntimeError("No models were set up for backward. Did you forget to call `self.setup()`?")
raise RuntimeError("No models were set up for backward. Did you forget to call `fabric.setup()`?")
if self._models_setup > 1:
raise ValueError(
"When using multiple models + deepspeed, please provide the model used to perform"
@@ -589,14 +591,14 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
Example::

# Accumulate gradient 8 batches at a time
with self.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
output = model(input)
loss = ...
self.backward(loss)
fabric.backward(loss)
...

For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
Both the model's `.forward()` and the `self.backward()` call need to run under this context.
Both the model's `.forward()` and the `fabric.backward()` call need to run under this context.

Args:
module: The module for which to control the gradient synchronization.
@@ -606,8 +608,8 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
module = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
raise TypeError(
"You need to set up the model first before you can call `self.no_backward_sync()`:"
" `model = self.setup(model, ...)`"
"You need to set up the model first before you can call `fabric.no_backward_sync()`:"
" `model = fabric.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
context = nullcontext()
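
For reference, a self-contained gradient-accumulation loop built on this context manager could look as follows (model, data, and the accumulation factor are placeholders, not part of this diff):

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(devices=1)
    fabric.launch()

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    for batch_idx in range(100):
        batch = torch.randn(8, 32, device=fabric.device)
        is_accumulating = batch_idx % 8 != 0
        # Skip gradient synchronization while accumulating (no-op on a single device)
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            loss = model(batch).sum()
            fabric.backward(loss)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
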
@@ -957,12 +959,20 @@ def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) ->
if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
raise ValueError("An optimizer should be passed only once to the `setup` method.")

if isinstance(self._strategy, FSDPStrategy) and not _TORCH_GREATER_EQUAL_2_0:
raise RuntimeError(
f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
" Create and set up the model first through `model = self.setup_module(model)`. Then create the"
" optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
)
if isinstance(self._strategy, FSDPStrategy):
if not _TORCH_GREATER_EQUAL_2_0:
raise RuntimeError(
f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
" Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
" optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
)
if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
raise RuntimeError(
"The optimizer has references to the model's meta-device parameters. Materializing them is"
" is currently not supported unless you to set up the model and optimizer(s) separately."
" Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
" optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
)

def _validate_setup_module(self, module: nn.Module) -> None:
self._validate_launched()
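
The error messages above prescribe the split setup order; sketched out under the assumption of an FSDP strategy and an illustrative optimizer choice:

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(strategy="fsdp", devices=2)
    fabric.launch()

    with fabric.init_module(empty_init=True):
        model = torch.nn.Linear(256, 256)

    model = fabric.setup_module(model)                 # 1) wrap/shard (and materialize) the model
    optimizer = torch.optim.AdamW(model.parameters())  # 2) create the optimizer from the set-up parameters
    optimizer = fabric.setup_optimizers(optimizer)     # 3) set up the optimizer
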
@@ -983,6 +993,13 @@ def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")

if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
raise RuntimeError(
"The optimizer has references to the model's meta-device parameters. Materializing them is"
" is currently not supported. Create the optimizer after setting up the model, then call"
" `fabric.setup_optimizers(optimizer)`."
)

def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None:
self._validate_launched()
if not dataloaders:
30 changes: 26 additions & 4 deletions src/lightning/fabric/strategies/fsdp.py
@@ -35,7 +35,7 @@

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from typing_extensions import TypeGuard

@@ -264,6 +264,11 @@ def setup_module(self, module: Module) -> Module:
from torch.distributed.fsdp import FullyShardedDataParallel

if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
# The user has wrapped their submodules manually, don't apply the auto wrap policy.
if _has_meta_device_parameters(module):
rank_zero_warn(
"The model is already wrapped in `FSDP` but there are still parameters on the meta device."
)
if "auto_wrap_policy" in self._fsdp_kwargs:
rank_zero_warn(
"A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
@@ -317,9 +322,16 @@ def module_to_device(self, module: Module) -> None:

@contextmanager
def module_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
# TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
# is resolved. For now, the module will get moved to the device in `setup_module`.
empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
empty_init_context = torch.device("meta")
Reviewer comment (Contributor):

Doesn't this break the loading of checkpoints for models that haven't been FSDP wrapped with setup yet? For instance: https://github.com/Lightning-AI/lit-gpt/blob/1900b80424825cb221af0b63d19dd33b027d9aff/generate/base.py#L145-L151

Wouldn't load_state_dict need assign=True now? pytorch/pytorch#96161 (comment)
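
What the reviewer alludes to, sketched out (assumes PyTorch >= 2.1, where `Module.load_state_dict` accepts `assign`; not part of this diff):

    import torch.nn as nn

    model = nn.Linear(8, 8, device="meta")  # e.g. built under fabric.init_module(empty_init=True)
    state_dict = nn.Linear(8, 8).state_dict()  # stands in for a loaded checkpoint
    # assign=True swaps the meta parameters for the loaded tensors instead of
    # copying into them, which would fail for tensors without data.
    model.load_state_dict(state_dict, assign=True)
    assert not model.weight.is_meta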

elif _TORCH_GREATER_EQUAL_1_13:
empty_init_context = _EmptyInit(enabled=bool(empty_init))
else:
empty_init_context = nullcontext()
with empty_init_context, self.precision.init_context(), self.module_sharded_context():
yield
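
The new branch relies on `torch.device("meta")` acting as a context manager for module construction; a small illustration (assumes a PyTorch version that supports the device context manager, roughly 2.0+):

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        layer = nn.Linear(10_000, 10_000)  # no parameter memory is allocated

    assert layer.weight.is_meta
    assert layer.weight.device == torch.device("meta")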

@@ -841,6 +853,16 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, strict: b
module.load_state_dict(state_dict, strict=strict)


def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
if isinstance(obj, Optimizer):
return any(
t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
)
if isinstance(obj, Module):
return any(t.is_meta for t in obj.parameters())
raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")


def _no_op() -> None:
pass

25 changes: 24 additions & 1 deletion tests/tests_fabric/strategies/test_fsdp.py
@@ -29,7 +29,11 @@
from lightning.fabric import Fabric
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.strategies.fsdp import _FSDPBackwardSyncControl, fsdp_overlap_step_with_backward
from lightning.fabric.strategies.fsdp import (
_FSDPBackwardSyncControl,
_has_meta_device_parameters,
fsdp_overlap_step_with_backward,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_1
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm
@@ -395,6 +399,25 @@ def test_set_timeout(init_process_group_mock):
)


def test_has_meta_device_parameters():
"""Test that the `_has_meta_device_parameters` function can find meta-device parameters in models and
optimizers."""
# nn.Module
module = nn.Linear(2, 2)
meta_module = nn.Linear(2, 2, device="meta")
assert not _has_meta_device_parameters(module)
assert _has_meta_device_parameters(meta_module)
assert _has_meta_device_parameters(nn.Sequential(module, meta_module, nn.ReLU()))
# optim.Optimizer
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
assert not _has_meta_device_parameters(optimizer)
assert _has_meta_device_parameters(meta_optimizer)
# unsupported objects
with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
_has_meta_device_parameters(None)


class SubBlock(nn.Sequential):
def __init__(self, feature_dim: int) -> None:
super().__init__(
48 changes: 34 additions & 14 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -23,7 +23,11 @@
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_1_12,
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.wrappers import _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
@@ -393,21 +397,27 @@ def test_module_init_context(precision, expected_dtype):
)
fabric.launch()

with fabric.init_module():
model = torch.nn.Linear(100, 100, bias=False)
def _run_setup_assertions(empty_init, expected_device):
with fabric.init_module(empty_init=empty_init):
model = torch.nn.Linear(100, 100, bias=False)

# The model is on the CPU until after `.setup()``
# TODO: Support initialization on meta device
expected_device = torch.device("cpu")
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
# The model is on the CPU/meta-device until after `.setup()`
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
model = fabric.setup(model)
# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))

# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype
if _TORCH_GREATER_EQUAL_2_1:
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@@ -460,7 +470,7 @@ def test_fsdp_manual_activation_checkpointing():


@RunIf(min_torch="1.12", min_cuda_gpus=1)
def test_rewrap_warning():
def test_rewrap_warnings():
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import wrap

@@ -473,3 +483,13 @@ def test_rewrap_warning():
model = fabric.setup(model)
assert not isinstance(model._forward_module, FullyShardedDataParallel)
assert isinstance(model._forward_module[2], FullyShardedDataParallel)

if not _TORCH_GREATER_EQUAL_2_1:
return

with fabric.init_module(empty_init=True):
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
assert model[0].weight.is_meta
with pytest.warns(match="there are still parameters on the meta device"):
fabric_model = fabric.setup(model)
assert next(fabric_model.parameters()).is_meta
16 changes: 16 additions & 0 deletions tests/tests_fabric/test_fabric.py
@@ -273,6 +273,22 @@ def test_setup_optimizers_not_supported(strategy_cls):
fabric.setup_optimizers(optimizer)


@RunIf(min_cuda_gpus=1, min_torch="2.1")
def test_setup_optimizer_on_meta_device():
"""Test that the setup-methods validate that the optimizer doesn't have references to meta-device
parameters."""
fabric = Fabric(strategy="fsdp", devices=1)
fabric._launched = True # pretend we have launched multiple processes
with fabric.init_module(empty_init=True):
model = nn.Linear(1, 2)
assert model.weight.is_meta
optimizer = torch.optim.Adam(model.parameters()) # optimizer references meta device params
with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
fabric.setup(model, optimizer)
with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
fabric.setup_optimizers(optimizer)


def test_setup_tracks_num_models():
"""Test that setup() tracks how many times it has setup a model."""
fabric = Fabric(devices=1)