19
19
from contextlib import contextmanager
20
20
from itertools import chain
21
21
from pathlib import Path
22
- from typing import Any , Dict , Generator , Iterable , List , Mapping , Optional , Tuple , TYPE_CHECKING , Union
22
+ from typing import Any , Dict , Generator , List , Mapping , Optional , Tuple , TYPE_CHECKING , Union
23
23
24
24
import torch
25
25
from lightning_utilities .core .imports import RequirementCache
@@ -428,6 +428,7 @@ def load_checkpoint(
428
428
torch .cuda .empty_cache ()
429
429
430
430
from deepspeed import DeepSpeedEngine
431
+
431
432
modules = chain (* (module .modules () for module in state .values () if isinstance (module , Module )))
432
433
engines = [engine for engine in modules if isinstance (engine , DeepSpeedEngine )]
433
434
print (list (modules ))
@@ -448,7 +449,10 @@ def load_checkpoint(
448
449
449
450
optimzer_state_requested = bool (len ([item for item in state .values () if isinstance (item , Optimizer )]))
450
451
_ , client_state = engine .load_checkpoint (
451
- path , tag = "checkpoint" , load_optimizer_states = optimzer_state_requested , load_lr_scheduler_states = False ,
452
+ path ,
453
+ tag = "checkpoint" ,
454
+ load_optimizer_states = optimzer_state_requested ,
455
+ load_lr_scheduler_states = False ,
452
456
load_module_strict = True , # TODO: make strict loading configurable
453
457
)
454
458
if client_state is None :
0 commit comments