Fix: Adapt Llama injection policy for newer transformers versions #7443
Conversation
Force-pushed from dab455e to 3a605b7.
Thank you for the adaptation, @huanyuqu
Your PR is almost perfect, but missing a test reproducing the failure - do you know why the existing llama injection policy tests didn't discover the change in HF Transformers (perhaps they were missing in the first place)?
Totally good to merge, but just flagging that potentially some tests are missing. Perhaps the authors of this feature should be tagged to add tests?
Thank you for the review and approval, @stas00. You're right about the missing test. After searching the test suite, I confirmed that none of the existing inference tests exercise the Llama kernel injection path, which is why the change in HF Transformers went unnoticed. To fix this, I am working on adding a new, focused test case (`TestLlamaInjection`) that reproduces the failure. As an alternative, we could also add Llama models to the main inference test suite.
Force-pushed from f9e9e9f to 23af0d2.
Following up on the suggestion from @stas00, I've pushed a new commit that adds a dedicated unit test (`TestLlamaInjection`). The new test case specifically targets the Llama kernel injection path.
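A rough sketch of the kind of check the new test performs; the tiny model hyperparameters below are illustrative assumptions, not the values used in the actual test:

```python
import torch
import deepspeed
from transformers import LlamaConfig, LlamaForCausalLM

# A deliberately tiny Llama model keeps the test CI-friendly
# (see the follow-up commit below); these sizes are placeholders.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=1000,
)
model = LlamaForCausalLM(config).half().eval()

# Kernel injection is exactly the path that previously raised
# AttributeError: 'LlamaAttention' object has no attribute 'num_heads'.
engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)
```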
The LlamaAttention module in recent transformers versions stores configuration like 'num_heads' and 'rope_theta' in a 'config' object, rather than as direct attributes. This change updates the LlamaLayerPolicy to fetch these attributes from the 'config' object first, falling back to the direct attribute to maintain backward compatibility. This resolves the AttributeError during kernel injection with newer transformers versions. Signed-off-by: huanyuqu <[email protected]>
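A minimal sketch of the lookup order this commit describes: read from the nested `config` object first, then fall back to the legacy direct attribute. The helper name `_get_llama_attr` and the usage lines are illustrative, not DeepSpeed's actual code:

```python
def _get_llama_attr(attn_module, config_name, legacy_name, default=None):
    # Newer transformers keep values such as num_attention_heads and
    # rope_theta on the module's nested config object...
    config = getattr(attn_module, "config", None)
    if config is not None and hasattr(config, config_name):
        return getattr(config, config_name)
    # ...older versions exposed them as direct module attributes.
    return getattr(attn_module, legacy_name, default)


# Hypothetical usage inside the injection policy:
# num_heads = _get_llama_attr(attn, "num_attention_heads", "num_heads")
# rope_theta = _get_llama_attr(attn, "rope_theta", "rope_theta", 10000.0)
```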
This commit adds a new test case, TestLlamaInjection, to the inference test suite. It specifically validates the fix from the previous commit by running kernel injection on a Llama model. This ensures that the AttributeError is resolved and helps prevent future regressions. Signed-off-by: huanyuqu <[email protected]>
The original Llama injection test used a large 7B model, which is not suitable for CI. This commit modifies the test to use a manually created, small Llama model with a standard configuration. This new approach successfully reproduces the original bug and passes after the fix. Signed-off-by: huanyuqu <[email protected]>
Force-pushed from b135876 to e327ca3.
This PR fixes an `AttributeError` that occurs during `deepspeed.init_inference` when using kernel injection (`replace_with_kernel_inject=True`) with Llama models from recent versions of `transformers`.

**The Bug:**
In newer `transformers` versions (e.g., `4.53.3`), configuration values like `num_heads` and `rope_theta` were moved from direct attributes of the `LlamaAttention` module into a nested `config` object. The current DeepSpeed injection policy tries to access these attributes at their old, direct location, causing initialization to fail with `AttributeError: 'LlamaAttention' object has no attribute 'num_heads'`.

**The Solution:**
This change updates the Llama injection logic to be more robust:
1. It first tries to read attributes like `num_heads` from the new `config` object location.
2. If that fails, it falls back to the legacy direct attribute path.
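For context, a minimal reproduction sketch of the failure mode; the model checkpoint is a placeholder, and any Llama-architecture model loaded with a recent `transformers` hits the same path:

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any Llama-architecture model will do.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Before this fix, with transformers 4.53.3 the call below raised:
#   AttributeError: 'LlamaAttention' object has no attribute 'num_heads'
# because the injection policy read num_heads directly off the module.
engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)
```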