Skip to content

Commit a3c97c9

Browse files
committed
squash all commits
1 parent f4a1c2d commit a3c97c9

File tree

3 files changed

+145
-22
lines changed

3 files changed

+145
-22
lines changed

src/lightning_fabric/fabric.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ def save(self, path: Union[str, Path], state: Dict[str, Union[nn.Module, Optimiz
519519
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
520520
state-dict will be retrieved and converted automatically.
521521
"""
522+
# TODO: validate deepspeed model with self._models_setup > 1
522523
return self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state))
523524

524525
def load(
@@ -537,6 +538,9 @@ def load(
537538
The remaining items that were not restored into the given state dictionary. If no state dictionary is
538539
given, the full checkpoint will be returned.
539540
"""
541+
# TODO: validate deepspeed model with self._models_setup > 1
542+
# if isinstance(self._strategy, DeepSpeedStrategy):
543+
540544
return self._strategy.load_checkpoint(path=path, state=state)
541545

542546
def launch(self, function: Optional[Callable[["Fabric"], Any]] = None, *args: Any, **kwargs: Any) -> Any:

src/lightning_fabric/strategies/deepspeed.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import os
1818
import platform
1919
from contextlib import contextmanager
20+
from itertools import chain
2021
from pathlib import Path
21-
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
22+
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
2223

2324
import torch
2425
from lightning_utilities.core.imports import RequirementCache
@@ -376,22 +377,37 @@ def save_checkpoint(
376377
Raises:
377378
TypeError if the unused ``storage_options`` gets passed.
378379
"""
379-
# broadcast the path from rank 0 to ensure all the states are saved in a common path
380-
path = self.broadcast(path)
381-
382380
if storage_options is not None:
383381
raise TypeError(
384382
f"`{self.__class__.__name__}.save_checkpoint(..., storage_options=...)` is not supported because"
385383
f" {self.__class__.__name__} does not use the `CheckpointIO`."
386384
)
385+
# validate that the deepspeed engine recorded in this strategy corresponds with the model the user
386+
# is handling
387+
# TODO: we support multiple models with deepspeed, redo this error
388+
if self._deepspeed_engine not in state.values():
389+
raise ValueError(
390+
"Could not find a deepspeed model in the provided checkpoint state. Please provide the model as"
391+
" part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
392+
" you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
393+
)
387394

395+
# broadcast the path from rank 0 to ensure all the states are saved in a common path
396+
path = self.broadcast(path)
397+
398+
# split the checkpoint into two parts:
399+
# 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
400+
# 2) the rest of the user's state, which in deepspeed is called `client state`
388401
excluded_objects = (self._deepspeed_engine, self._deepspeed_engine.optimizer)
389402
state = {k: v for k, v in state.items() if v not in excluded_objects}
403+
# there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
390404
state = self._convert_stateful_objects_in_state(state)
391-
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
405+
# use deepspeed's internal checkpointing function to handle partitioned weights across processes
392406
self._deepspeed_engine.save_checkpoint(path, client_state=state, tag="checkpoint")
393407

394-
def load_checkpoint(self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None) -> Dict[str, Any]:
408+
def load_checkpoint(
409+
self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
410+
) -> Dict[str, Any]:
395411
"""Load the contents from a checkpoint and restore the state of the given objects.
396412
397413
Args:
@@ -404,25 +420,43 @@ def load_checkpoint(self, path: _PATH, state: Optional[Dict[str, Union[Module, O
404420
given, the full checkpoint will be returned.
405421
"""
406422
if self.load_full_weights and self.zero_stage_3:
407-
# Broadcast to ensure we load from the rank 0 checkpoint
408-
# This doesn't have to be the case when using deepspeed sharded checkpointing
423+
# This code path enables loading a checkpoint from a non-deepspeed checkpoint or from
424+
# a consolidated checkpoint
409425
path = self.broadcast(path)
410426
return super().load_checkpoint(path=path, state=state)
411427

412-
if self._deepspeed_engine not in state.values():
413-
# TODO
414-
raise ValueError()
415-
optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)]))
416-
417428
torch.cuda.empty_cache()
418-
_, client_state = self._deepspeed_engine.load_checkpoint(
419-
path, load_optimizer_states=optimzer_state_requested, load_lr_scheduler_states=False
429+
430+
from deepspeed import DeepSpeedEngine
431+
432+
modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
433+
engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
434+
if len(engines) == 0:
435+
raise ValueError(
436+
"Could not find a deepspeed model in the provided checkpoint state. Please provide the model as"
437+
" part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
438+
" you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
439+
)
440+
elif len(engines) > 1:
441+
raise ValueError(
442+
"Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
443+
" currently limited to a single model per checkpoint. To save multiple model checkpoints, call the"
444+
" save method for each model separately with a different path."
445+
)
446+
engine = engines[0]
447+
448+
optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)]))
449+
_, client_state = engine.load_checkpoint(
450+
path,
451+
tag="checkpoint",
452+
load_optimizer_states=optimzer_state_requested,
453+
load_lr_scheduler_states=False,
454+
load_module_strict=True, # TODO: make strict loading configurable
420455
)
421456
if client_state is None:
422-
# TODO: fix message
423457
raise ValueError(
424-
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
425-
"or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`."
458+
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
459+
" or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
426460
)
427461
for k, v in client_state.copy().items():
428462
if k not in state:

tests/tests_fabric/strategies/test_deepspeed_integration.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _make_block(self):
241241

242242

243243
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
244-
def test_deepspeed_multigpu_stage_3(tmpdir):
244+
def test_deepspeed_multigpu_stage_3():
245245
"""Test to ensure ZeRO Stage 3 works with a parallel model."""
246246
fabric = ModelParallelClassification(
247247
strategy=DeepSpeedStrategy(stage=3),
@@ -255,7 +255,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir):
255255
@RunIf(deepspeed=True)
256256
@mock.patch("deepspeed.init_distributed", autospec=True)
257257
@pytest.mark.parametrize("platform", ["Linux", "Windows"])
258-
def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, tmpdir, platform):
258+
def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, platform):
259259
"""Test to ensure that we set up distributed communication correctly.
260260
261261
When using Windows, ranks environment variables should not be set, and DeepSpeed should handle this.
@@ -279,7 +279,7 @@ def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, tmpdir, platf
279279

280280

281281
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
282-
def test_deepspeed_specific_gpu_device_index(tmpdir):
282+
def test_deepspeed_specific_gpu_device_index():
283283
"""Test that the DeepSpeed strategy can run on specific device indices."""
284284

285285
class RunFabric(BoringFabric):
@@ -295,7 +295,7 @@ def step(self, model, batch):
295295

296296

297297
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
298-
def test_deepspeed_with_bfloat16_precision(tmpdir):
298+
def test_deepspeed_with_bfloat16_precision():
299299
"""Test that the DeepSpeed strategy works with bfloat16 precision."""
300300

301301
class Model(nn.Module):
@@ -322,3 +322,88 @@ def step(self, model, batch):
322322
assert fabric._strategy.precision.precision == "bf16"
323323
assert fabric._strategy.config["zero_optimization"]["stage"] == 3
324324
fabric.run()
325+
326+
327+
def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
328+
"""Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the full
329+
weights in float32 precision."""
330+
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
331+
332+
assert isinstance(fabric.strategy, DeepSpeedStrategy)
333+
334+
# carry out the check only on rank 0
335+
if fabric.is_global_zero:
336+
if fabric.strategy.config["zero_optimization"]["stage"] in (2, 3):
337+
single_ckpt_path = checkpoint_path / "single_model.pt"
338+
# the tag is hardcoded in DeepSpeedStrategy
339+
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
340+
state_dict = torch.load(single_ckpt_path)
341+
else:
342+
# 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
343+
single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
344+
state_dict = torch.load(single_ckpt_path)["module"]
345+
346+
model = model.cpu()
347+
348+
# assert model parameters are identical after loading
349+
for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
350+
# perform the equality check in the same precision
351+
saved_model_param = saved_model_param.cpu().to(orig_param.dtype)
352+
assert torch.equal(orig_param, saved_model_param)
353+
354+
fabric.barrier()
355+
356+
357+
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
358+
@pytest.mark.parametrize("stage", [1, 2, 3])
359+
def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
360+
"""Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully."""
361+
from deepspeed import DeepSpeedEngine
362+
363+
fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
364+
fabric.launch()
365+
366+
checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")
367+
368+
with fabric.sharded_model():
369+
model = BoringModel()
370+
371+
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
372+
model, optimizer = fabric.setup(model, optimizer)
373+
assert isinstance(model._forward_module, DeepSpeedEngine)
374+
375+
# TODO(fabric): The dtype on the model is not correct, should be torch.bfloat16
376+
assert model.dtype == torch.float32
377+
assert next(model.parameters()).dtype == torch.bfloat16
378+
379+
# dummy training step
380+
output = model(torch.randn(1, 32).to(fabric.device))
381+
loss = output.sum()
382+
fabric.backward(loss)
383+
optimizer.step()
384+
optimizer.zero_grad()
385+
386+
state = {"model": model, "optimizer": optimizer, "steps": 1}
387+
fabric.save(checkpoint_path, state)
388+
389+
fabric.barrier()
390+
391+
# re-init all objects and resume
392+
fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
393+
fabric.launch()
394+
with fabric.sharded_model():
395+
model = BoringModel()
396+
397+
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
398+
model, optimizer = fabric.setup(model, optimizer)
399+
state = {"model": model, "optimizer": optimizer, "steps": 0}
400+
401+
metadata = fabric.load(checkpoint_path, state)
402+
fabric.barrier()
403+
404+
# check user data in state reloaded
405+
assert state["steps"] == 1
406+
# the remainder of the deepspeed checkpoint contains metadata
407+
assert "ds_version" in metadata
408+
409+
_assert_saved_model_is_equal(fabric, model, checkpoint_path)

0 commit comments

Comments
 (0)