
[BUG]zero2 + autotp: IndexError: tuple index out of range #7249


Description

@lyx564

Describe the bug
I am using ZeRO-2 + AutoTP to run SFT on the Qwen2.5-7B model. When training reaches save_steps and the checkpoint is saved, the following error is raised:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/code/llm_train/transformers_deepspeed/train.py", line 296, in <module>
[rank0]:     train()
[rank0]:   File "/data/code/llm_train/transformers_deepspeed/train.py", line 288, in train
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/transformers/trainer.py", line 2245, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/transformers/trainer.py", line 2627, in _inner_training_loop
[rank0]:     self._maybe_log_save_evaluate(
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/transformers/trainer.py", line 3103, in _maybe_log_save_evaluate
[rank0]:     self._save_checkpoint(model, trial)
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/transformers/trainer.py", line 3200, in _save_checkpoint
[rank0]:     self.save_model(output_dir, _internal_call=True)
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/transformers/trainer.py", line 3887, in save_model
[rank0]:     state_dict = self.accelerator.get_state_dict(self.deepspeed)
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/accelerate/accelerator.py", line 3591, in get_state_dict
[rank0]:     model._consolidated_16bit_state_dict()
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3715, in _consolidated_16bit_state_dict
[rank0]:     return self._replace_module_consolidated_state_dict()
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3702, in _replace_module_consolidated_state_dict
[rank0]:     get_layer_state_dict(self.module, prefix="")
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3700, in get_layer_state_dict
[rank0]:     get_layer_state_dict(child, prefix + name + ".")
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3700, in get_layer_state_dict
[rank0]:     get_layer_state_dict(child, prefix + name + ".")
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3700, in get_layer_state_dict
[rank0]:     get_layer_state_dict(child, prefix + name + ".")
[rank0]:   [Previous line repeated 2 more times]
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3689, in get_layer_state_dict
[rank0]:     with GatherReplacedLayerParams(list(module.parameters(recurse=False)), module, enabled=True):
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/module_inject/layers.py", line 314, in __enter__
[rank0]:     self.params[0].gather_params(self.params)
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/python/lib/python3.10/site-packages/deepspeed/module_inject/layers.py", line 423, in gather_params
[rank0]:     param.shape[1],
[rank0]: IndexError: tuple index out of range
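
For reference, the failing line in gather_params indexes param.shape[1]. A minimal sketch of how that raises this IndexError, assuming the parameter being gathered is 1-D (for example a bias or norm weight that AutoTP leaves unsharded):

    import torch

    # Assumption: the parameter reached by gather_params is 1-D,
    # e.g. a bias or RMSNorm weight, so its shape has no index 1.
    param = torch.empty(4096)   # shape == torch.Size([4096])
    print(param.shape[0])       # 4096
    print(param.shape[1])       # IndexError: tuple index out of range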

The ds_config.json:

{
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true,
        "gather_16bit_weights_on_model_save": true,
        "offload_optimizer": {"device": "cpu"}
    },
    "tensor_parallel":{
        "autotp_size": 2
    },
    "offload_optimizer": {
        "device": "cpu",
        "pin_memory": true
    },
    "offload_param": {
        "device": "cpu",
        "pin_memory": true
    },
    "precision": "bfloat16",
    "bf16": {
        "enabled": true
    },
    "data_types": {
        "grad_accum_dtype": "bf16"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
          "lr": 2e-5,
          "betas": [
            0.9,
            0.999
          ],
          "eps": 1e-8,
          "weight_decay": 3e-7
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 1e-5,
            "warmup_max_lr": 2e-5,
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

The relevant part of train.py:

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True
    )
    config, kwargs = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        return_unused_kwargs=True,
        trust_remote_code=True,
    )
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args,
                      callbacks=[TimeInfoCallback(), TensorboardCallback()], train_dataset=train_dataset,
                      eval_dataset=eval_dataset, data_collator=data_collator)
    trainer.train()
    trainer.save_model(output_dir=training_args.output_dir)
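
As a debugging aid (not a fix), the parameters that would break the param.shape[1] access during the consolidated 16-bit save can be listed before saving. A hedged sketch, assuming trainer.deepspeed is the DeepSpeedEngine as shown in the traceback above:

    # List 1-D parameters of the wrapped model: gather_params indexes
    # param.shape[1], which only exists for tensors with dim() >= 2.
    # Assumes trainer.deepspeed is the DeepSpeedEngine, as in the traceback.
    for name, p in trainer.deepspeed.module.named_parameters():
        if p.dim() < 2:
            print(name, tuple(p.shape))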
