diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md
index 4c5c406db5747..3319362b79790 100644
--- a/src/lightning_fabric/CHANGELOG.md
+++ b/src/lightning_fabric/CHANGELOG.md
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))
 
 
 ### Changed
diff --git a/src/lightning_fabric/strategies/deepspeed.py b/src/lightning_fabric/strategies/deepspeed.py
index cc7718bdc5eb8..48ce926f96050 100644
--- a/src/lightning_fabric/strategies/deepspeed.py
+++ b/src/lightning_fabric/strategies/deepspeed.py
@@ -17,8 +17,9 @@
 import os
 import platform
 from contextlib import contextmanager
+from itertools import chain
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -31,7 +32,7 @@
 from lightning_fabric.strategies.ddp import DDPStrategy
 from lightning_fabric.strategies.strategy import _Sharded
 from lightning_fabric.utilities.distributed import log
-from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only
+from lightning_fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
 from lightning_fabric.utilities.seed import reset_seed
 from lightning_fabric.utilities.types import _PATH
 
@@ -365,24 +366,124 @@ def module_sharded_context(self) -> Generator[None, None, None]:
     def save_checkpoint(
         self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None
     ) -> None:
-        raise NotImplementedError
+        """Save model, optimizer, and other state in a checkpoint directory.
+
+        Args:
+            path: A path to where the files should be saved
+            state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
+                state-dict will be retrieved and converted automatically.
+            storage_options: Unused by this strategy, since it doesn't use a ``CheckpointIO`` plugin.
+
+        Raises:
+            TypeError:
+                If the unused ``storage_options`` gets passed.
+            ValueError:
+                When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple
+                :class:`deepspeed.DeepSpeedEngine` objects were found.
+        """
+        if storage_options is not None:
+            raise TypeError(
+                "`DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not supported because"
+                " `DeepSpeedStrategy` does not use the `CheckpointIO`."
+            )
+
+        engines = _get_deepspeed_engines_from_state(state)
+        if len(engines) == 0:
+            raise ValueError(
+                "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
+                " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
+                " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
+            )
+        elif len(engines) > 1:
+            raise ValueError(
+                "Found multiple DeepSpeed engine modules in the given state. Saving checkpoints with DeepSpeed is"
+                " currently limited to a single model per checkpoint. To save multiple models, call the"
+                " save method for each model separately with a different path."
+            )
+        engine = engines[0]
+
+        # broadcast the path from rank 0 to ensure all the states are saved in a common path
+        path = self.broadcast(path)
+
+        # split the checkpoint into two parts:
+        # 1) the deepspeed engine encapsulating both the model and optionally the optimizer(s)
+        # 2) the rest of the user's state, which in deepspeed is called `client state`
+        excluded_objects = (engine, engine.optimizer) if engine.optimizer is not None else (engine,)
+        state = {k: v for k, v in state.items() if v not in excluded_objects}
+        _validate_state_keys(state)
+        # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
+        state = self._convert_stateful_objects_in_state(state)
+        # use deepspeed's internal checkpointing function to handle partitioned weights across processes
+        engine.save_checkpoint(path, client_state=state, tag="checkpoint")
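For context, a minimal usage sketch of the new save path, assuming a 2-GPU CUDA machine with `deepspeed` installed; the model, optimizer, and checkpoint path below are illustrative, not part of this diff:

```python
import torch

from lightning_fabric import Fabric
from lightning_fabric.strategies import DeepSpeedStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=3))
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# setup() wraps the module in a single DeepSpeedEngine, which save_checkpoint() requires
model, optimizer = fabric.setup(model, optimizer)

# the engine/optimizer pair is checkpointed by DeepSpeed itself; everything else
# in the dict ("steps" here) travels as DeepSpeed "client state"
state = {"model": model, "optimizer": optimizer, "steps": 1}
fabric.save("my-checkpoint", state)  # writes a sharded checkpoint directory
```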
 
     def load_checkpoint(
         self, path: _PATH, state: Optional[Dict[str, Union[Module, Optimizer, Any]]] = None
     ) -> Dict[str, Any]:
-        raise NotImplementedError
-
-    def load_optimizer_state_dict(
-        self, optimizers: Union[Optimizer, Iterable[Optimizer]], checkpoint: Mapping[str, Any]
-    ) -> None:
-        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
-        pass
+        """Load the contents from a checkpoint and restore the state of the given objects.
 
-    def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any]) -> None:
-        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
+        Args:
+            path: A path to where the checkpoint is located
+            state: A dictionary of objects whose state will be restored in-place from the checkpoint path.
+                This should contain exactly one model, and the model must already be set up by DeepSpeed.
+
+        Returns:
+            Dictionary with the leftover client state from DeepSpeed's engine, i.e., checkpoint metadata that
+            was not restored into the given ``state``
+
+        Raises:
+            ValueError:
+                If no state is provided, when no :class:`deepspeed.DeepSpeedEngine` objects were found in the
+                state, or when multiple :class:`deepspeed.DeepSpeedEngine` objects were found.
+            RuntimeError:
+                If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is
+                not in the expected DeepSpeed format.
+        """
         if self.load_full_weights and self.zero_stage_3:
-            self.module_to_device(module)
-            self._restore_zero_state(module, checkpoint)
+            # this code path enables loading a non-DeepSpeed checkpoint or a consolidated (single-file) checkpoint
+            path = self.broadcast(path)
+            return super().load_checkpoint(path=path, state=state)
+
+        if not state:
+            raise ValueError(
+                f"Got DeepSpeedStrategy.load_checkpoint(..., state={state!r}) but a state with at least"
+                " a model instance to reload is required. Pass it in like so:"
+                " DeepSpeedStrategy.load_checkpoint(..., state={'model': model, ...})"
+            )
+
+        engines = _get_deepspeed_engines_from_state(state)
+        if len(engines) == 0:
+            raise ValueError(
+                "Could not find a DeepSpeed model in the provided checkpoint state. Please provide the model as"
+                " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
+                " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
+            )
+        elif len(engines) > 1:
+            raise ValueError(
+                "Found multiple DeepSpeed engine modules in the given state. Saving and loading checkpoints"
+                " with DeepSpeed is currently limited to a single model per checkpoint. To load multiple model"
+                " states, call the load method for each model checkpoint separately."
+            )
+        engine = engines[0]
+        # only ask DeepSpeed to restore optimizer states if the user passed an optimizer to restore
+        optimizer_state_requested = any(isinstance(item, Optimizer) for item in state.values())
+
+        torch.cuda.empty_cache()
+        _, client_state = engine.load_checkpoint(
+            path,
+            tag="checkpoint",
+            load_optimizer_states=optimizer_state_requested,
+            load_lr_scheduler_states=False,
+            load_module_strict=True,  # TODO(fabric): make strict loading configurable
+        )
+        if client_state is None:
+            raise RuntimeError(
+                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint"
+                " or a single checkpoint file by setting `DeepSpeedStrategy(..., load_full_weights=True)`."
+            )
+        # restore the user's objects from the client state; whatever remains is returned as metadata
+        for k in list(client_state):
+            if k in state:
+                state[k] = client_state.pop(k)
+        return client_state
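The matching restore, sketched under the same assumptions. Optimizer states are only requested from DeepSpeed when an `Optimizer` instance is present in `state`, and whatever DeepSpeed-internal keys remain in the client state are handed back as the return value:

```python
import torch

from lightning_fabric import Fabric
from lightning_fabric.strategies import DeepSpeedStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=3))
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

# module and optimizer are restored in-place; plain entries like "steps" are
# overwritten with the values stored in the checkpoint's client state
state = {"model": model, "optimizer": optimizer, "steps": 0}
metadata = fabric.load("my-checkpoint", state)
assert state["steps"] == 1
print(metadata.keys())  # leftover DeepSpeed metadata, e.g. "ds_version"
```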
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -645,3 +746,38 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
             config = json.load(f)
         assert isinstance(config, dict) or config is None
         return config
+
+
+def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["deepspeed.DeepSpeedEngine"]:
+    from deepspeed import DeepSpeedEngine
+
+    modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
+    engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
+    return engines
+
+
+def _validate_state_keys(state: Dict[str, Any]) -> None:
+    # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for
+    # colliding keys from the user. We explicitly check it here:
+    deepspeed_internal_keys = {
+        "module",
+        "buffer_names",
+        "optimizer",
+        "param_shapes",
+        "lr_scheduler",
+        "sparse_tensor_module_names",
+        "skipped_steps",
+        "global_steps",
+        "global_samples",
+        "dp_world_size",
+        "mp_world_size",
+        "ds_config",
+        "ds_version",
+    }
+    colliding_keys = deepspeed_internal_keys.intersection(state.keys())
+    if colliding_keys:
+        rank_zero_warn(
+            "Your state has keys that collide with DeepSpeed's internal engine state. This could result in your"
+            " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
+            + ", ".join(sorted(colliding_keys))
+        )
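The helper `_get_deepspeed_engines_from_state` is a flat walk over every `nn.Module` in the state dict, relying on `Module.modules()` to also yield wrapped submodules. A self-contained illustration of the traversal, with a hypothetical `FakeEngine` class standing in for `deepspeed.DeepSpeedEngine`:

```python
from itertools import chain

import torch.nn as nn


class FakeEngine(nn.Module):  # stand-in for deepspeed.DeepSpeedEngine, illustration only
    pass


state = {"model": nn.Sequential(FakeEngine(), nn.Linear(2, 2)), "steps": 3}

# modules() yields the module itself plus all submodules, so engines nested
# inside a wrapper are discovered; non-module values like "steps" are skipped
modules = chain(*(m.modules() for m in state.values() if isinstance(m, nn.Module)))
engines = [m for m in modules if isinstance(m, FakeEngine)]
assert len(engines) == 1
```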
diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py
index 6828e1a636d3c..cae7ef19f442b 100644
--- a/tests/tests_fabric/strategies/test_deepspeed.py
+++ b/tests/tests_fabric/strategies/test_deepspeed.py
@@ -20,6 +20,7 @@
 import pytest
 import torch
 from tests_fabric.helpers.runif import RunIf
+from torch.optim import Optimizer
 
 from lightning_fabric.accelerators import CPUAccelerator
 from lightning_fabric.strategies import DeepSpeedStrategy
@@ -151,3 +152,172 @@ def test_deepspeed_requires_joint_setup():
         NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
     ):
         strategy.setup_optimizer(Mock())
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_save_checkpoint_storage_options(tmp_path):
+    """Test that the DeepSpeed strategy does not accept storage options for saving checkpoints."""
+    strategy = DeepSpeedStrategy()
+    with pytest.raises(TypeError, match=escape("DeepSpeedStrategy.save_checkpoint(..., storage_options=...)` is not")):
+        strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_save_checkpoint_one_deepspeed_engine_required(tmp_path):
+    """Test that the DeepSpeed strategy can only save one DeepSpeedEngine per checkpoint."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+
+    # missing DeepSpeedEngine
+    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
+        strategy.save_checkpoint(path=tmp_path, state={})
+    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
+        strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
+
+    # multiple DeepSpeedEngine
+    model1 = Mock(spec=torch.nn.Module)
+    model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
+    model2 = Mock(spec=torch.nn.Module)
+    model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
+    with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
+        strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
+    """Test that the DeepSpeed engine and optimizer get separated from the client state."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+
+    # Model only
+    model = Mock(spec=DeepSpeedEngine, optimizer=None)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
+    # the client_state should not contain any deepspeed engine or deepspeed optimizer
+    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+
+    # Model and optimizer
+    optimizer = Mock()
+    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
+    # the client_state should not contain any deepspeed engine or deepspeed optimizer
+    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
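These unit tests sidestep building real engines with `unittest.mock.Mock(spec=...)`: a spec'd mock reports the spec class as its `__class__`, so it passes the strategy's `isinstance` checks. A minimal sketch of the pattern with a plain stand-in class:

```python
from unittest.mock import Mock


class Engine:  # stands in for deepspeed.DeepSpeedEngine in this sketch
    pass


fake = Mock(spec=Engine)
assert isinstance(fake, Engine)  # spec'd mocks satisfy isinstance checks

# hence `model.modules.return_value = [model]` above is enough for the strategy
# to "discover" exactly one engine in the mocked module tree
```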
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
+    """Test that the strategy warns if there are keys in the user dict that collide with DeepSpeed's internal
+    state."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+    optimizer = Mock()
+    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
+    model.modules.return_value = [model]
+    # `mp_world_size` is an internal key
+    with pytest.warns(UserWarning, match="Your state has keys that collide with DeepSpeed's internal"):
+        strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_load_checkpoint_no_state(tmp_path):
+    """Test that DeepSpeed can't load the full state without access to a model instance from the user."""
+    strategy = DeepSpeedStrategy()
+    with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state=None")):
+        strategy.load_checkpoint(path=tmp_path, state=None)
+    with pytest.raises(ValueError, match=escape("Got DeepSpeedStrategy.load_checkpoint(..., state={})")):
+        strategy.load_checkpoint(path=tmp_path, state={})
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_load_checkpoint_one_deepspeed_engine_required(tmp_path):
+    """Test that the DeepSpeed strategy can only load one DeepSpeedEngine per checkpoint."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+
+    # missing DeepSpeedEngine
+    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
+        strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
+    with pytest.raises(ValueError, match="Could not find a DeepSpeed model in the provided checkpoint state."):
+        strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
+
+    # multiple DeepSpeedEngine
+    model1 = Mock(spec=torch.nn.Module)
+    model1.modules.return_value = [Mock(spec=DeepSpeedEngine)]
+    model2 = Mock(spec=torch.nn.Module)
+    model2.modules.return_value = [Mock(spec=DeepSpeedEngine)]
+    with pytest.raises(ValueError, match="Found multiple DeepSpeed engine modules in the given state."):
+        strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_load_checkpoint_client_state_missing(tmp_path):
+    """Test that the DeepSpeed strategy raises a custom error when client state couldn't be loaded by DeepSpeed."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+    optimizer = Mock()
+    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
+    model.modules.return_value = [model]
+
+    # If the DeepSpeed engine fails to load the checkpoint file (e.g., file not found), it prints a warning and
+    # returns None from its function call
+    model.load_checkpoint.return_value = [None, None]
+
+    # Check for our custom user error
+    with pytest.raises(RuntimeError, match="DeepSpeed was unable to load the checkpoint"):
+        strategy.load_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_load_checkpoint_state_updated_with_client_state(tmp_path):
+    """Test that the DeepSpeed strategy properly updates the state variables and returns additional metadata."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+    optimizer = Mock()
+    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
+    model.modules.return_value = [model]
+
+    # the client state contains the additional user data that was provided when saving, plus some deepspeed metadata
+    loaded_client_state = {"user_data": {"iteration": 5}, "deepspeed_metadata": "data"}
+    model.load_checkpoint.return_value = [None, loaded_client_state]
+
+    state = {"model": model, "user_data": {"iteration": 0}}
+    metadata = strategy.load_checkpoint(path=tmp_path, state=state)
+
+    # the user's state gets updated with the loaded value
+    assert state == {"model": model, "user_data": {"iteration": 5}}
+    # additional metadata gets separated from the client state
+    assert metadata == {"deepspeed_metadata": "data"}
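The merge semantics pinned down by the last test can be shown with plain dicts: keys already present in the user's `state` are popped out of the loaded client state, and the remainder is returned as metadata. The `"ds_version"` value below is illustrative:

```python
client_state = {"user_data": {"iteration": 5}, "ds_version": "x.y.z"}
state = {"user_data": {"iteration": 0}}  # the model entry is omitted in this sketch

for key in list(client_state):
    if key in state:
        state[key] = client_state.pop(key)

assert state == {"user_data": {"iteration": 5}}
assert client_state == {"ds_version": "x.y.z"}  # handed back to the caller as metadata
```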
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("optimizer_state_requested", [True, False])
+def test_deepspeed_load_checkpoint_optimizer_state_requested(optimizer_state_requested, tmp_path):
+    """Test that the DeepSpeed strategy loads the optimizer state only when requested."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy()
+    optimizer = Mock(spec=Optimizer)
+    model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
+    model.modules.return_value = [model]
+
+    # required, otherwise the mock cannot be unpacked
+    model.load_checkpoint.return_value = [None, {}]
+
+    state = {"model": model}
+    if optimizer_state_requested:
+        state["optimizer"] = optimizer
+
+    strategy.load_checkpoint(path=tmp_path, state=state)
+    model.load_checkpoint.assert_called_with(
+        tmp_path,
+        tag="checkpoint",
+        load_optimizer_states=optimizer_state_requested,
+        load_lr_scheduler_states=False,
+        load_module_strict=True,
+    )
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 2c5b3b09cd19a..ed6aea76bb84a 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -241,7 +241,7 @@ def _make_block(self):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_stage_3(tmpdir):
+def test_deepspeed_multigpu_stage_3():
     """Test to ensure ZeRO Stage 3 works with a parallel model."""
     fabric = ModelParallelClassification(
         strategy=DeepSpeedStrategy(stage=3),
@@ -280,7 +280,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_specific_gpu_device_index(tmpdir):
+def test_deepspeed_specific_gpu_device_index():
     """Test that the DeepSpeed strategy can run on specific device indices."""
 
     class RunFabric(BoringFabric):
@@ -296,7 +296,7 @@ def step(self, model, batch):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
-def test_deepspeed_with_bfloat16_precision(tmpdir):
+def test_deepspeed_with_bfloat16_precision():
     """Test that the DeepSpeed strategy works with bfloat16 precision."""
 
     class Model(nn.Module):
@@ -323,3 +323,88 @@ def step(self, model, batch):
     assert fabric._strategy.precision.precision == "bf16"
     assert fabric._strategy.config["zero_optimization"]["stage"] == 3
     fabric.run()
+
+
+def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
+    """Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the full
+    weights in float32 precision."""
+    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
+    assert isinstance(fabric.strategy, DeepSpeedStrategy)
+
+    # carry out the check only on rank 0
+    if fabric.is_global_zero:
+        if fabric.strategy.config["zero_optimization"]["stage"] in (2, 3):
+            single_ckpt_path = checkpoint_path / "single_model.pt"
+            # the tag is hardcoded in DeepSpeedStrategy
+            convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
+            state_dict = torch.load(single_ckpt_path)
+        else:
+            # 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
+            single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
+            state_dict = torch.load(single_ckpt_path)["module"]
+
+        model = model.cpu()
+
+        # assert model parameters are identical after loading
+        for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
+            # perform the equality check in the same precision
+            saved_model_param = saved_model_param.cpu().to(orig_param.dtype)
+            assert torch.equal(orig_param, saved_model_param)
+
+    fabric.barrier()
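The consolidation used by `_assert_saved_model_is_equal` can also be run by hand to collapse a sharded ZeRO checkpoint into a single `state_dict` file; a sketch assuming `deepspeed` is installed, a checkpoint written with the `"checkpoint"` tag hardcoded in this diff, and hypothetical paths:

```python
import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    "my-checkpoint",    # directory written by fabric.save()
    "consolidated.pt",  # hypothetical output file
    tag="checkpoint",   # must match the tag hardcoded in DeepSpeedStrategy
)
state_dict = torch.load("consolidated.pt")  # full fp32 weights, readable without DeepSpeed
```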
tag="checkpoint") + state_dict = torch.load(single_ckpt_path) + else: + # 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy + single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt" + state_dict = torch.load(single_ckpt_path)["module"] + + model = model.cpu() + + # assert model parameters are identical after loading + for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()): + # perform the equality check in the same precision + saved_model_param = saved_model_param.cpu().to(orig_param.dtype) + assert torch.equal(orig_param, saved_model_param) + + fabric.barrier() + + +@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True) +@pytest.mark.parametrize("stage", [1, 2, 3]) +def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path): + """Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully.""" + from deepspeed import DeepSpeedEngine + + fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16") + fabric.launch() + + checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint") + + with fabric.sharded_model(): + model = BoringModel() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) + model, optimizer = fabric.setup(model, optimizer) + assert isinstance(model._forward_module, DeepSpeedEngine) + + # TODO(fabric): The dtype on the model is not correct, should be torch.bfloat16 + assert model.dtype == torch.float32 + assert next(model.parameters()).dtype == torch.bfloat16 + + # dummy training step + output = model(torch.randn(1, 32).to(fabric.device)) + loss = output.sum() + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + + state = {"model": model, "optimizer": optimizer, "steps": 1} + fabric.save(checkpoint_path, state) + + fabric.barrier() + + # re-init all objects and resume + fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16") + fabric.launch() + with fabric.sharded_model(): + model = BoringModel() + + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) + model, optimizer = fabric.setup(model, optimizer) + state = {"model": model, "optimizer": optimizer, "steps": 0} + + metadata = fabric.load(checkpoint_path, state) + fabric.barrier() + + # check user data in state reloaded + assert state["steps"] == 1 + # the remainder of the deepspeed checkpoint contains metadata + assert "ds_version" in metadata + + _assert_saved_model_is_equal(fabric, model, checkpoint_path)