Commit 23af0d2

Fix: Adapt Llama injection policy for newer transformers versions
The LlamaAttention module in recent transformers versions stores configuration values such as 'num_heads' and 'rope_theta' in a 'config' object rather than as direct attributes. This change updates the LlamaLayerPolicy to fetch these attributes from the 'config' object first, falling back to the direct attribute to maintain backward compatibility. This resolves the AttributeError raised during kernel injection with newer transformers versions.

Signed-off-by: huanyuqu <[email protected]>
1 parent 70caefe commit 23af0d2

File tree

1 file changed (+9, -2 lines)
  • deepspeed/module_inject/containers


deepspeed/module_inject/containers/llama.py

Lines changed: 9 additions & 2 deletions
@@ -34,7 +34,10 @@ def create_module(self, config=None):
         _config.rotate_half = True
         _config.rotate_every_two = False
         _config.rotary_dim = self.hidden_size // self.num_attention_heads
-        _config.rope_theta = self.policy.client_module.self_attn.rope_theta
+        if hasattr(self.policy.client_module.self_attn, 'config'):
+            _config.rope_theta = self.policy.client_module.self_attn.config.rope_theta
+        else:
+            _config.rope_theta = self.policy.client_module.self_attn.rope_theta
         self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)

         return self.module
@@ -128,9 +131,13 @@ def __init__(self, client_module, inference=True):
             LLAMALayerPolicy._orig_layer_class = None

     def get_hidden_heads(self):
+        if hasattr(self.client_module.self_attn, 'config'):
+            num_heads = self.client_module.self_attn.config.num_attention_heads
+        else:
+            num_heads = self.client_module.self_attn.num_heads
         hidden_heads = (
             self.client_module.self_attn.q_proj.in_features,
-            self.client_module.self_attn.num_heads,
+            num_heads,
             self.client_module.input_layernorm.variance_epsilon,
             self.client_module.mlp.gate_proj.out_features,
         )
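For illustration only (not part of this commit), a minimal sketch of the config-first lookup pattern the diff applies; the helper name _attn_attr and its arguments are hypothetical and not DeepSpeed or transformers API:

# Hypothetical helper sketching the fallback used above: prefer the value on
# the attention module's config object (newer transformers), otherwise fall
# back to the legacy direct attribute (older transformers).
def _attn_attr(attn_module, config_name, legacy_name):
    if hasattr(attn_module, 'config') and hasattr(attn_module.config, config_name):
        return getattr(attn_module.config, config_name)
    return getattr(attn_module, legacy_name)

# Example usage mirroring the diff:
#   num_heads  = _attn_attr(self_attn, 'num_attention_heads', 'num_heads')
#   rope_theta = _attn_attr(self_attn, 'rope_theta', 'rope_theta')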
