
Commit b9df0dd (parent: 70ed73b)

[pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

2 files changed: +9 −4 lines

src/lightning_fabric/strategies/deepspeed.py (6 additions, 2 deletions)

@@ -19,7 +19,7 @@
 from contextlib import contextmanager
 from itertools import chain
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union

 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -428,6 +428,7 @@ def load_checkpoint(
         torch.cuda.empty_cache()

         from deepspeed import DeepSpeedEngine
+
         modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module)))
         engines = [engine for engine in modules if isinstance(engine, DeepSpeedEngine)]
         print(list(modules))
@@ -448,7 +449,10 @@ def load_checkpoint(

         optimzer_state_requested = bool(len([item for item in state.values() if isinstance(item, Optimizer)]))
         _, client_state = engine.load_checkpoint(
-            path, tag="checkpoint", load_optimizer_states=optimzer_state_requested, load_lr_scheduler_states=False,
+            path,
+            tag="checkpoint",
+            load_optimizer_states=optimzer_state_requested,
+            load_lr_scheduler_states=False,
             load_module_strict=True,  # TODO: make strict loading configurable
         )
         if client_state is None:
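For readers skimming the diff, the surrounding checkpoint-loading logic can be summarized roughly as below. This is a minimal sketch, assuming a `state` mapping of names to modules/optimizers and a checkpoint directory `path`; the helper name `_load_deepspeed_checkpoint` is made up for illustration, and only the keyword arguments passed to `engine.load_checkpoint` are taken from the diff itself.

    # Illustrative sketch only; not the actual DeepSpeedStrategy.load_checkpoint body.
    from itertools import chain

    from deepspeed import DeepSpeedEngine
    from torch.nn import Module
    from torch.optim import Optimizer

    def _load_deepspeed_checkpoint(path, state):
        # Collect every DeepSpeedEngine wrapped inside the user-supplied modules.
        modules = chain(*(m.modules() for m in state.values() if isinstance(m, Module)))
        engines = [m for m in modules if isinstance(m, DeepSpeedEngine)]
        engine = engines[0]  # a single engine is assumed here

        # Restore optimizer states only if the caller actually passed an Optimizer.
        optimizer_state_requested = any(isinstance(v, Optimizer) for v in state.values())

        # The keyword arguments below mirror the reformatted call in the hunk above.
        _, client_state = engine.load_checkpoint(
            path,
            tag="checkpoint",
            load_optimizer_states=optimizer_state_requested,
            load_lr_scheduler_states=False,
            load_module_strict=True,
        )
        return client_state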

tests/tests_fabric/strategies/test_deepspeed_integration.py (3 additions, 2 deletions)

@@ -346,9 +346,10 @@ def step(self, model, batch):


 def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
-    """Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the
-    full weights in float32 precision."""
+    """Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the full
+    weights in float32 precision."""
     from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
     assert isinstance(fabric.strategy, DeepSpeedStrategy)

     # carry out the check only on rank 0
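The test helper touched above consolidates a sharded ZeRO checkpoint back into full float32 weights before comparing them against the live model. A rough sketch of that kind of check follows; it uses DeepSpeed's related utility get_fp32_state_dict_from_zero_checkpoint (which returns the consolidated state dict in memory rather than writing a file, unlike the convert_zero_checkpoint_to_fp32_state_dict import in the diff), and the helper name and comparison logic are assumptions, not the test's actual code.

    # Illustrative sketch only; names and the exact comparison are assumptions.
    import torch
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    def _check_consolidated_weights(model, checkpoint_path):
        # Gather the sharded ZeRO partitions into a single fp32 state dict.
        fp32_state = get_fp32_state_dict_from_zero_checkpoint(checkpoint_path, tag="checkpoint")
        for name, param in model.named_parameters():
            saved = fp32_state[name]
            assert torch.equal(saved.float().cpu(), param.detach().float().cpu())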
