Skip to content

Conversation

huanyuqu
Copy link
Contributor

This PR fixes an AttributeError that occurs during deepspeed.init_inference when using kernel injection (replace_with_kernel_inject=True) with Llama models from recent versions of transformers.

The Bug:

In newer transformers versions (e.g., 4.53.3), configurations like num_heads and rope_theta were moved from direct attributes of the LlamaAttention module into a nested config object.

The current DeepSpeed injection policy tries to access these attributes from their old, direct location, causing the initialization to fail with an AttributeError: 'LlamaAttention' object has no attribute 'num_heads'.

The Solution:

This change updates the Llama injection logic to be more robust:

  1. It first tries to read attributes like num_heads from the new config object location.
  2. If that fails, it falls back to the legacy direct attribute path.

@huanyuqu huanyuqu force-pushed the fix-llama-injection-compat branch from dab455e to 3a605b7 Compare July 23, 2025 13:32
@loadams loadams requested a review from stas00 July 23, 2025 17:07
Copy link
Collaborator

@stas00 stas00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the adaptation, @huanyuqu

Your PR is almost perfect, but missing a test reproducing the failure - do you know why the existing llama injection policy tests didn't discover the change in HF Transformers (perhaps they were missing in the first place)?

Totally good to merge, but just flagging that potentially some tests are missing. Perhaps the authors of this feature should be tagged to add tests?

@huanyuqu
Copy link
Contributor Author

Thank you for the review and approval, @stas00

You're right about the missing test. After searching for replace_with_kernel_inject across the tests folder, I noticed it's mainly used in tests/unit/inference/test_inference.py and test_checkpoint_sharding.py. I then confirmed that Llama models were not included in any of the existing test suites for kernel injection. The current tests primarily cover models like BERT, Roberta, GPT, and OPT, which explains why this regression wasn't caught by the CI.

To fix this, I am working on adding a new, focused test case (TestLlamaInjection) to that file. It will use a small Llama variant (e.g., huggyllama/llama-7b) to specifically validate that kernel injection works correctly after this patch. I'll push the new commit with this test shortly.

As an alternative, we could also add Llama models to the main TestModelTask suite for more comprehensive testing coverage. I've created a separate test case for now to keep it fast and focused, but I'm happy to discuss integrating it more broadly if you think that's a better approach.

@huanyuqu huanyuqu force-pushed the fix-llama-injection-compat branch from f9e9e9f to 23af0d2 Compare July 24, 2025 08:34
@huanyuqu
Copy link
Contributor Author

Following up on the suggestion from @stas00, I've pushed a new commit that adds a dedicated unit test (TestLlamaInjection) for this fix.

The new test case specifically targets the Llama kernel injection path:

  • Without my fix (e.g., when run against the master branch), the test correctly catches the AttributeError: 'LlamaAttention' object has no attribute 'num_heads'. It then calls pytest.skip, marking the test as skipped and confirming the presence of the bug. I have confirmed this behavior locally.
  • With my fix applied, the AttributeError is no longer raised, the except block is bypassed, and the test proceeds to validate the injection and inference, ultimately passing.

huanyuqu added 3 commits July 25, 2025 16:18
The LlamaAttention module in recent transformers versions stores configuration
like 'num_heads' and 'rope_theta' in a 'config' object, rather than as
direct attributes.

This change updates the LlamaLayerPolicy to fetch these attributes from
the 'config' object first, falling back to the direct attribute to maintain
backward compatibility. This resolves the AttributeError during kernel
injection with newer transformers versions.

Signed-off-by: huanyuqu <[email protected]>
This commit adds a new test case, TestLlamaInjection, to the
inference test suite.

It specifically validates the fix from the previous commit by running
kernel injection on a Llama model. This ensures that the AttributeError
is resolved and helps prevent future regressions.

Signed-off-by: huanyuqu <[email protected]>
The original Llama injection test was using a large 7B model, which is not suitable for CI.

This commit modifies the test to:
1. Use a manually created, small Llama model with a standard configuration.
2. This new approach successfully reproduces the original bug and passes after the fix.

Signed-off-by: huanyuqu <[email protected]>
@huanyuqu huanyuqu force-pushed the fix-llama-injection-compat branch from b135876 to e327ca3 Compare July 25, 2025 08:18
@stas00 stas00 merged commit 092625c into deepspeedai:master Jul 26, 2025
9 checks passed
lpnpcs pushed a commit to lpnpcs/DeepSpeed that referenced this pull request Jul 30, 2025
…epspeedai#7443)

This PR fixes an `AttributeError` that occurs during
`deepspeed.init_inference` when using kernel injection
(`replace_with_kernel_inject=True`) with Llama models from recent
versions of `transformers`.

**The Bug:**

In newer `transformers` versions (e.g., `4.53.3`), configurations like
`num_heads` and `rope_theta` were moved from direct attributes of the
`LlamaAttention` module into a nested `config` object.

The current DeepSpeed injection policy tries to access these attributes
from their old, direct location, causing the initialization to fail with
an `AttributeError: 'LlamaAttention' object has no attribute
'num_heads'`.

**The Solution:**

This change updates the Llama injection logic to be more robust:
1. It first tries to read attributes like `num_heads` from the new
`config` object location.
2. If that fails, it falls back to the legacy direct attribute path.

---------

Signed-off-by: huanyuqu <[email protected]>
LYMDLUT pushed a commit to LYMDLUT/DeepSpeed that referenced this pull request Aug 20, 2025
…epspeedai#7443)

This PR fixes an `AttributeError` that occurs during
`deepspeed.init_inference` when using kernel injection
(`replace_with_kernel_inject=True`) with Llama models from recent
versions of `transformers`.

**The Bug:**

In newer `transformers` versions (e.g., `4.53.3`), configurations like
`num_heads` and `rope_theta` were moved from direct attributes of the
`LlamaAttention` module into a nested `config` object.

The current DeepSpeed injection policy tries to access these attributes
from their old, direct location, causing the initialization to fail with
an `AttributeError: 'LlamaAttention' object has no attribute
'num_heads'`.

**The Solution:**

This change updates the Llama injection logic to be more robust:
1. It first tries to read attributes like `num_heads` from the new
`config` object location.
2. If that fails, it falls back to the legacy direct attribute path.

---------

Signed-off-by: huanyuqu <[email protected]>
Signed-off-by: lym <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants