Skip to content

Commit 7603dd0

Browse files
awaelchlicarmocca
andauthored
Fabric checkpointing 2/n: DeepSpeed implementation (#16452)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 4a802e0 commit 7603dd0

File tree

4 files changed

+409
-18
lines changed

4 files changed

+409
-18
lines changed

src/lightning_fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12-
-
12+
- Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))
1313

1414

1515
### Changed

src/lightning_fabric/strategies/deepspeed.py

Lines changed: 150 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import os
1818
import platform
1919
from contextlib import contextmanager
20+
from itertools import chain
2021
from pathlib import Path
21-
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
22+
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
2223

2324
import torch
2425
from lightning_utilities.core.imports import RequirementCache
@@ -31,7 +32,7 @@
3132
from lightning_fabric.strategies.ddp import DDPStrategy
3233
from lightning_fabric.strategies.strategy import _Sharded
3334
from lightning_fabric.utilities.distributed import log
34-
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
35+
from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
3536
from lightning_fabric.utilities.seed import reset_seed
3637
from lightning_fabric.utilities.types import _PATH
3738

@@ -365,24 +366,124 @@ def module_sharded_context(self) -> Generator[None, None, None]:
365366
def save_checkpoint(
366367
self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None
367368
) -> None:
368-
raise NotImplementedError
369+
"""Save model, optimizer, and other state in a checkpoint directory.
370+
371+
Args:
372+
path: A path to where the files should be saved
373+
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
374+
state-dict will be retrieved and converted automatically.
375+
storage_options: Unused by this strategy, since it doesn't use a ``CheckpointIO`` plugin.
376+
377+
Raises:
378+
TypeError:
379+
If the unused ``storage_options`` gets passed.
380+
ValueError:
381+
When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple
382+
:class:`deepspeed.DeepSpeedEngine` objects were found.
383+
"""
384+
if storage_options is not None:
385+
raise TypeError(
386+
"`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
387+
" `DeepSpeedStrategy` does not use the `CheckpointIO`."
388+
)
389+
390+
engines = _get_deepspeed_engines_from_state(state)
391+
if len(engines) == 0:
392+
raise ValueError(
393+
"Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
394+
" part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
395+
" you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
396+
)
397+
elif len(engines) > 1:
398+
raise ValueError(
399+
"Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
400+
" currently limited to a single model per checkpoint. To save multiple models, call the"
401+
" save method for each model separately with a different path."
402+
)
403+
engine = engines[0]
404+
405+
# broadcast the path from rank 0 to ensure all the states are saved in a common path
406+
path = self.broadcast(path)
407+
408+
# split the checkpoint into two parts:
409+
# 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
410+
# 2) the rest of the user's state, which in deepspeed is called `client state`
411+
excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
412+
state = {k: v for k, v in state.items() if v not in excluded_objects}
413+
_validate_state_keys(state)
414+
# there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
415+
state = self._convert_stateful_objects_in_state(state)
416+
# use deepspeed's internal checkpointing function to handle partitioned weights across processes
417+
engine.save_checkpoint(path, client_state=state, tag="checkpoint")
369418

370419
def load_checkpoint(
371420
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
372421
) -> Dict[str, Any]:
373-
raise NotImplementedError
374-
375-
def load_optimizer_state_dict(
376-
self, optimizers: Union[Optimizer, Iterable[Optimizer]], checkpoint: Mapping[str, Any]
377-
) -> None:
378-
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
379-
pass
422+
"""Load the contents from a checkpoint and restore the state of the given objects.
380423
381-
def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any]) -> None:
382-
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
424+
Args:
425+
path: A path to where the file is located
426+
state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
427+
This should contain exactly one model, and the model must already be set up by DeepSpeed.
428+
429+
Returns:
430+
Dictionary with the state inside DeepSpeed's engine
431+
432+
Raises:
433+
ValueError:
434+
If no state is provided, when no :class:`deepspeed.DeepSpeedEngine` objects were found in the
435+
state, or when multiple :class:`deepspeed.DeepSpeedEngine` objects were found.
436+
RuntimeError:
437+
If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is
438+
not in the expected DeepSpeed format.
439+
"""
383440
if self.load_full_weights and self.zero_stage_3:
384-
self.module_to_device(module)
385-
self._restore_zero_state(module, checkpoint)
441+
# This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
442+
# a consolidated checkpoint
443+
path = self.broadcast(path)
444+
return super().load_checkpoint(path=path, state=state)
445+
446+
if not state:
447+
raise ValueError(
448+
f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
449+
f" a model instance to reload is required. Pass it in like so:"
450+
" DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
451+
)
452+
453+
engines = _get_deepspeed_engines_from_state(state)
454+
if len(engines) == 0:
455+
raise ValueError(
456+
"Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
457+
" part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
458+
" you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
459+
)
460+
elif len(engines) > 1:
461+
raise ValueError(
462+
"Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
463+
" with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
464+
" states, call the load method for each model checkpoint separately."
465+
)
466+
engine = engines[0]
467+
optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)]))
468+
469+
torch.cuda.empty_cache()
470+
_, client_state = engine.load_checkpoint(
471+
path,
472+
tag="checkpoint",
473+
load_optimizer_states=optimzer_state_requested,
474+
load_lr_scheduler_states=False,
475+
load_module_strict=True, # TODO(fabric): make strict loading configurable
476+
)
477+
if client_state is None:
478+
raise RuntimeError(
479+
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
480+
" or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
481+
)
482+
for k, v in client_state.copy().items():
483+
if k not in state:
484+
continue
485+
state[k] = client_state.pop(k)
486+
return client_state
386487

387488
@classmethod
388489
def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -645,3 +746,38 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
645746
config = json.load(f)
646747
assert isinstance(config, dict) or config is None
647748
return config
749+
750+
751+
def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["deepspeed.DeepSpeedEngine"]:
752+
from deepspeed import DeepSpeedEngine
753+
754+
modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
755+
engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
756+
return engines
757+
758+
759+
def _validate_state_keys(state: Dict[str, Any]) -> None:
760+
# DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
761+
# colliding keys from the user. We explicitly check it here:
762+
deepspeed_internal_keys = {
763+
"module",
764+
"buffer_names",
765+
"optimizer",
766+
"param_shapes",
767+
"lr_scheduler",
768+
"sparse_tensor_module_names",
769+
"skipped_steps",
770+
"global_steps",
771+
"global_samples",
772+
"dp_world_size",
773+
"mp_world_size",
774+
"ds_config",
775+
"ds_version",
776+
}
777+
colliding_keys = deepspeed_internal_keys.intersection(state.keys())
778+
if colliding_keys:
779+
rank_zero_warn(
780+
"Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
781+
" values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
782+
+ ", ".join(colliding_keys)
783+
)

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121
import torch
2222
from tests_fabric.helpers.runif import RunIf
23+
from torch.optim import Optimizer
2324

2425
from lightning_fabric.accelerators import CPUAccelerator
2526
from lightning_fabric.strategies import DeepSpeedStrategy
@@ -151,3 +152,172 @@ def test_deepspeed_requires_joint_setup():
151152
NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
152153
):
153154
strategy.setup_optimizer(Mock())
155+
156+
157+
@RunIf(deepspeed=True)
158+
def test_deepspeed_save_checkpoint_storage_options(tmp_path):
159+
"""Test that the DeepSpeed strategy does not accept storage options for saving checkpoints."""
160+
strategy = DeepSpeedStrategy()
161+
with pytest.raises(TypeError, match=escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")):
162+
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
163+
164+
165+
@RunIf(deepspeed=True)
166+
def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
167+
"""Test that the DeepSpeed strategy can only save one DeepSpeedEngine per checkpoint."""
168+
from deepspeed import DeepSpeedEngine
169+
170+
strategy = DeepSpeedStrategy()
171+
172+
# missing DeepSpeedEngine
173+
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
174+
strategy.save_checkpoint(path=tmp_path, state={})
175+
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
176+
strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
177+
178+
# multiple DeepSpeedEngine
179+
model1 = Mock(spec=torch.nn.Module)
180+
model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
181+
model2 = Mock(spec=torch.nn.Module)
182+
model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
183+
with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
184+
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
185+
186+
187+
@RunIf(deepspeed=True)
188+
def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
189+
"""Test that the DeepSpeed engine and optimizer get separated from the client state."""
190+
from deepspeed import DeepSpeedEngine
191+
192+
strategy = DeepSpeedStrategy()
193+
194+
# Model only
195+
model = Mock(spec=DeepSpeedEngine, optimizer=None)
196+
model.modules.return_value = [model]
197+
strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
198+
# the client_state should not contain any deepspeed engine or deepspeed optimizer
199+
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
200+
201+
# Model and optimizer
202+
optimizer = Mock()
203+
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
204+
model.modules.return_value = [model]
205+
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
206+
# the client_state should not contain any deepspeed engine or deepspeed optimizer
207+
model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
208+
209+
210+
@RunIf(deepspeed=True)
211+
def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
212+
"""Test that the strategy warns if there are keys in the user dict that collide internally with DeepSpeed."""
213+
from deepspeed import DeepSpeedEngine
214+
215+
strategy = DeepSpeedStrategy()
216+
optimizer = Mock()
217+
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
218+
model.modules.return_value = [model]
219+
# `mp_world_size` is an internal key
220+
with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
221+
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})
222+
223+
224+
@RunIf(deepspeed=True)
225+
def test_deepspeed_load_checkpoint_no_state(tmp_path):
226+
"""Test that DeepSpeed can't load the full state without access to a model instance from the user."""
227+
strategy = DeepSpeedStrategy()
228+
with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state=None")):
229+
strategy.load_checkpoint(path=tmp_path, state=None)
230+
with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state={})")):
231+
strategy.load_checkpoint(path=tmp_path, state={})
232+
233+
234+
@RunIf(deepspeed=True)
235+
def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
236+
"""Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
237+
from deepspeed import DeepSpeedEngine
238+
239+
strategy = DeepSpeedStrategy()
240+
241+
# missing DeepSpeedEngine
242+
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
243+
strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
244+
with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
245+
strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
246+
247+
# multiple DeepSpeedEngine
248+
model1 = Mock(spec=torch.nn.Module)
249+
model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
250+
model2 = Mock(spec=torch.nn.Module)
251+
model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
252+
with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
253+
strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
254+
255+
256+
@RunIf(deepspeed=True)
257+
def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
258+
"""Test that the DeepSpeed strategy raises a custom error when client state couldn't be loaded by DeepSpeed."""
259+
from deepspeed import DeepSpeedEngine
260+
261+
strategy = DeepSpeedStrategy()
262+
optimizer = Mock()
263+
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
264+
model.modules.return_value = [model]
265+
266+
# If the DeepSpeed engine fails to load the checkpoint file (e.g., file not found), it prints a warning and
267+
# returns None from its function call
268+
model.load_checkpoint.return_value = [None, None]
269+
270+
# Check for our custom user error
271+
with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
272+
strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
273+
274+
275+
@RunIf(deepspeed=True)
276+
def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
277+
"""Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
278+
from deepspeed import DeepSpeedEngine
279+
280+
strategy = DeepSpeedStrategy()
281+
optimizer = Mock()
282+
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
283+
model.modules.return_value = [model]
284+
285+
# the client state contains the additional user data that was proveded when saving, plus some deepspeed metadata
286+
loaded_client_state = {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"}
287+
model.load_checkpoint.return_value = [None, loaded_client_state]
288+
289+
state = {"model": model, "user_data": {"iteration": 0}}
290+
metadata = strategy.load_checkpoint(path=tmp_path, state=state)
291+
292+
# the user's state gets updated with the loaded value
293+
assert state == {"model": model, "user_data": {"iteration": 5}}
294+
# additional metadata gets separated from client state
295+
assert metadata == {"deepspeed_metadata": "data"}
296+
297+
298+
@RunIf(deepspeed=True)
299+
@pytest.mark.parametrize("optimzer_state_requested", [True, False])
300+
def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_requested, tmp_path):
301+
"""Test that the DeepSpeed strategy loads the optimizer state only when requested."""
302+
from deepspeed import DeepSpeedEngine
303+
304+
strategy = DeepSpeedStrategy()
305+
optimizer = Mock(spec=Optimizer)
306+
model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
307+
model.modules.return_value = [model]
308+
309+
# required, otherwise mock cannot be unpacked
310+
model.load_checkpoint.return_value = [None, {}]
311+
312+
state = {"model": model}
313+
if optimzer_state_requested:
314+
state["optimizer"] = optimizer
315+
316+
strategy.load_checkpoint(path=tmp_path, state=state)
317+
model.load_checkpoint.assert_called_with(
318+
tmp_path,
319+
tag="checkpoint",
320+
load_optimizer_states=optimzer_state_requested,
321+
load_lr_scheduler_states=False,
322+
load_module_strict=True,
323+
)

0 commit comments

Comments
 (0)