Commit 15f054d

fix: engine initializes optimizer attributes at the beginning (#7410)
`destroy` accesses `self.optimizer`, but the error that triggers the call to `destroy` (via `__del__`) can happen in `__init__`, even before the optimizer and scheduler are configured. So we need to move the initialization of `self.optimizer` (and the related attributes) to the top of `__init__` to avoid triggering another exception, e.g.:

```logs
  File "deepspeed/runtime/engine.py", line 453, in _configure_tensor_parallel_states
    assert self.zero_optimization_stage(
AssertionError: Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated
Exception ignored in: <function DeepSpeedEngine.__del__ at 0x1516c0610820>
Traceback (most recent call last):
  File "deepspeed/runtime/engine.py", line 509, in __del__
    self.destroy()
  File "deepspeed/runtime/engine.py", line 512, in destroy
    if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
  File "deepspeed/runtime/engine.py", line 621, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DeepSpeedEngine' object has no attribute 'optimizer'
```

Signed-off-by: Hollow Man <[email protected]>
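For context, here is a minimal, self-contained sketch of the failure mode, using a hypothetical `Engine` class rather than DeepSpeed's actual code: because the engine defines a `__getattr__` that raises, an attribute that was never assigned (here, because `__init__` raised first) produces a second `AttributeError` when `__del__` calls `destroy`. Assigning the attribute at the top of `__init__` makes the destructor path safe:

```python
# Hypothetical minimal reproduction; Engine stands in for DeepSpeedEngine.

class Engine:
    def __init__(self, fail_early=False):
        # The fix: assign the attribute before anything that can raise,
        # so destroy() can always read it safely.
        self.optimizer = None

        if fail_early:
            # Stands in for the assertion in _configure_tensor_parallel_states.
            raise AssertionError("'autotp' and 'zero_stage = 3' not validated")

        self.optimizer = object()  # configured later in the real engine

    def __getattr__(self, name):
        # Mirrors the raising __getattr__ seen in the traceback above:
        # any attribute missing from the instance raises AttributeError.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __del__(self):
        self.destroy()

    def destroy(self):
        # Before the fix, this read of `self.optimizer` fell through to
        # __getattr__ and raised inside __del__ whenever __init__ failed
        # before the attribute was assigned.
        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
            self.optimizer.destroy()


try:
    Engine(fail_early=True)
except AssertionError:
    pass  # with the early assignment, no "Exception ignored in __del__" follows
```

Moving the `self.optimizer = None` assignment below the `raise` in this sketch reproduces the ignored `AttributeError` shown in the commit message.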
Parent: da60a87

1 file changed

deepspeed/runtime/engine.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -243,6 +243,9 @@ def __init__(self,
         self._global_grad_norm = None
         self.use_ds_comm = False  # False --> Use torch.dist, True --> Use ds.comm backend.
         self.checkpoint_engine = None
+        self.optimizer = None
+        self.basic_optimizer = None
+        self.lr_scheduler = None

         self._is_gradient_accumulation_boundary = None
         self.scale_wrt_gas = None
@@ -313,9 +316,6 @@ def __init__(self,
         self.training_dataloader = None

         # Configure optimizer and scheduler
-        self.optimizer = None
-        self.basic_optimizer = None
-        self.lr_scheduler = None
         has_optimizer = False

         if optimizer or self.optimizer_name():
```
