Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning_fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))


### Changed
Expand Down
164 changes: 150 additions & 14 deletions src/lightning_fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import os
import platform
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.imports import RequirementCache
Expand All @@ -31,7 +32,7 @@
from lightning_fabric.strategies.ddp import DDPStrategy
from lightning_fabric.strategies.strategy import _Sharded
from lightning_fabric.utilities.distributed import log
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import _PATH

Expand Down Expand Up @@ -365,24 +366,124 @@ def module_sharded_context(self) -> Generator[None, None, None]:
def save_checkpoint(
self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None
) -> None:
raise NotImplementedError
"""Save model, optimizer, and other state in a checkpoint directory.

Args:
path: A path to where the files should be saved
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
state-dict will be retrieved and converted automatically.
storage_options: Unused by this strategy, since it doesn't use a ``CheckpointIO`` plugin.

Raises:
TypeError:
If the unused ``storage_options`` gets passed.
ValueError:
When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple
:class:`deepspeed.DeepSpeedEngine` objects were found.
"""
if storage_options is not None:
raise TypeError(
"`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
" `DeepSpeedStrategy` does not use the `CheckpointIO`."
)

engines = _get_deepspeed_engines_from_state(state)
if len(engines) == 0:
raise ValueError(
"Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
" part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
" you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
)
elif len(engines) > 1:
raise ValueError(
"Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
" currently limited to a single model per checkpoint. To save multiple models, call the"
" save method for each model separately with a different path."
)
engine = engines[0]

# broadcast the path from rank 0 to ensure all the states are saved in a common path
path = self.broadcast(path)

# split the checkpoint into two parts:
# 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
# 2) the rest of the user's state, which in deepspeed is called `client state`
excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
state = {k: v for k, v in state.items() if v not in excluded_objects}
_validate_state_keys(state)
# there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
state = self._convert_stateful_objects_in_state(state)
# use deepspeed's internal checkpointing function to handle partitioned weights across processes
engine.save_checkpoint(path, client_state=state, tag="checkpoint")

def load_checkpoint(
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
) -> Dict[str, Any]:
raise NotImplementedError

def load_optimizer_state_dict(
self, optimizers: Union[Optimizer, Iterable[Optimizer]], checkpoint: Mapping[str, Any]
) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
pass
"""Load the contents from a checkpoint and restore the state of the given objects.

def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
Args:
path: A path to where the file is located
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
This should contain exactly one model, and the model must already be set up by DeepSpeed.

Returns:
Dictionary with the state inside DeepSpeed's engine

Raises:
ValueError:
If no state is provided, when no :class:`deepspeed.DeepSpeedEngine` objects were found in the
state, or when multiple :class:`deepspeed.DeepSpeedEngine` objects were found.
RuntimeError:
If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is
not in the expected DeepSpeed format.
"""
if self.load_full_weights and self.zero_stage_3:
self.module_to_device(module)
self._restore_zero_state(module, checkpoint)
# This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
# a consolidated checkpoint
path = self.broadcast(path)
return super().load_checkpoint(path=path, state=state)

if not state:
raise ValueError(
f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
f" a model instance to reload is required. Pass it in like so:"
" DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
)

engines = _get_deepspeed_engines_from_state(state)
if len(engines) == 0:
raise ValueError(
"Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
" part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
" you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
)
elif len(engines) > 1:
raise ValueError(
"Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
" with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
" states, call the load method for each model checkpoint separately."
)
engine = engines[0]
optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)]))

torch.cuda.empty_cache()
_, client_state = engine.load_checkpoint(
path,
tag="checkpoint",
load_optimizer_states=optimzer_state_requested,
load_lr_scheduler_states=False,
load_module_strict=True, # TODO(fabric): make strict loading configurable
)
if client_state is None:
raise RuntimeError(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
" or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
)
for k, v in client_state.copy().items():
if k not in state:
continue
state[k] = client_state.pop(k)
return client_state

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
Expand Down Expand Up @@ -645,3 +746,38 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
config = json.load(f)
assert isinstance(config, dict) or config is None
return config


def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["deepspeed.DeepSpeedEngine"]:
from deepspeed import DeepSpeedEngine

modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
return engines


def _validate_state_keys(state: Dict[str, Any]) -> None:
# DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
# colliding keys from the user. We explicitly check it here:
deepspeed_internal_keys = {
"module",
"buffer_names",
"optimizer",
"param_shapes",
"lr_scheduler",
"sparse_tensor_module_names",
"skipped_steps",
"global_steps",
"global_samples",
"dp_world_size",
"mp_world_size",
"ds_config",
"ds_version",
}
colliding_keys = deepspeed_internal_keys.intersection(state.keys())
if colliding_keys:
rank_zero_warn(
"Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
" values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
+ ", ".join(colliding_keys)
)
170 changes: 170 additions & 0 deletions tests/tests_fabric/strategies/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from torch.optim import Optimizer

from lightning_fabric.accelerators import CPUAccelerator
from lightning_fabric.strategies import DeepSpeedStrategy
Expand Down Expand Up @@ -151,3 +152,172 @@ def test_deepspeed_requires_joint_setup():
NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
):
strategy.setup_optimizer(Mock())


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_storage_options(tmp_path):
"""Test that the DeepSpeed strategy does not accept storage options for saving checkpoints."""
strategy = DeepSpeedStrategy()
with pytest.raises(TypeError, match=escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")):
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
"""Test that the DeepSpeed strategy can only save one DeepSpeedEngine per checkpoint."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()

# missing DeepSpeedEngine
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.save_checkpoint(path=tmp_path, state={})
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

# multiple DeepSpeedEngine
model1 = Mock(spec=torch.nn.Module)
model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
model2 = Mock(spec=torch.nn.Module)
model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
"""Test that the DeepSpeed engine and optimizer get separated from the client state."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()

# Model only
model = Mock(spec=DeepSpeedEngine, optimizer=None)
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")

# Model and optimizer
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")


@RunIf(deepspeed=True)
def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
"""Test that the strategy warns if there are keys in the user dict that collide internally with DeepSpeed."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]
# `mp_world_size` is an internal key
with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_no_state(tmp_path):
"""Test that DeepSpeed can't load the full state without access to a model instance from the user."""
strategy = DeepSpeedStrategy()
with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state=None")):
strategy.load_checkpoint(path=tmp_path, state=None)
with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state={})")):
strategy.load_checkpoint(path=tmp_path, state={})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
"""Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()

# missing DeepSpeedEngine
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

# multiple DeepSpeedEngine
model1 = Mock(spec=torch.nn.Module)
model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
model2 = Mock(spec=torch.nn.Module)
model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
"""Test that the DeepSpeed strategy raises a custom error when client state couldn't be loaded by DeepSpeed."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]

# If the DeepSpeed engine fails to load the checkpoint file (e.g., file not found), it prints a warning and
# returns None from its function call
model.load_checkpoint.return_value = [None, None]

# Check for our custom user error
with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})


@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
"""Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()
optimizer = Mock()
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]

# the client state contains the additional user data that was proveded when saving, plus some deepspeed metadata
loaded_client_state = {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"}
model.load_checkpoint.return_value = [None, loaded_client_state]

state = {"model": model, "user_data": {"iteration": 0}}
metadata = strategy.load_checkpoint(path=tmp_path, state=state)

# the user's state gets updated with the loaded value
assert state == {"model": model, "user_data": {"iteration": 5}}
# additional metadata gets separated from client state
assert metadata == {"deepspeed_metadata": "data"}


@RunIf(deepspeed=True)
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_requested, tmp_path):
"""Test that the DeepSpeed strategy loads the optimizer state only when requested."""
from deepspeed import DeepSpeedEngine

strategy = DeepSpeedStrategy()
optimizer = Mock(spec=Optimizer)
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
model.modules.return_value = [model]

# required, otherwise mock cannot be unpacked
model.load_checkpoint.return_value = [None, {}]

state = {"model": model}
if optimzer_state_requested:
state["optimizer"] = optimizer

strategy.load_checkpoint(path=tmp_path, state=state)
model.load_checkpoint.assert_called_with(
tmp_path,
tag="checkpoint",
load_optimizer_states=optimzer_state_requested,
load_lr_scheduler_states=False,
load_module_strict=True,
)
Loading