
Conversation

@jeejeelee jeejeelee (Collaborator) commented Aug 12, 2025

Essential Elements of an Effective PR Description Checklist

Purpose

  • Decouple Glm4vForConditionalGeneration and Glm4vMoeForConditionalGeneration. On one hand, their packed_modules_mapping definitions differ (the root cause of [Bug]: GLM-4.1V lora trained model reports target_module mismatch error #22077); on the other hand, decoupling improves the maintainability of the model code.
  • Currently, Glm4vMoeForConditionalGeneration does not support LoRA and raises the error below (a minimal sketch of the failing check follows the traceback). I have removed the LoRA label for now and will investigate the issue ASAP.
(EngineCore_0 pid=1393314) (VllmWorker TP0 pid=1393320) (VllmWorker TP1 pid=1393322) (EngineCore_0 pid=1393314) (VllmWorker TP2 pid=1393324) ERROR 08-12 16:22:17 [multiproc_executor.py:596]   File "/root/anaconda3/envs/py310_vllm_dev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 08-12 16:22:17 [multiproc_executor.py:596]   File "/root/Code/vllm_dev/vllm/vllm/lora/ops/triton_ops/lora_shrink_op.py", line 139, in _lora_shrink
(VllmWorker TP3 pid=1393326) ERROR 08-12 16:22:17 [multiproc_executor.py:596]   File "/root/Code/vllm_dev/vllm/vllm/lora/ops/triton_ops/lora_shrink_op.py", line 139, in _lora_shrink
(EngineCore_0 pid=1393314) (EngineCore_0 pid=1393314) (VllmWorker TP0 pid=1393320) ERROR 08-12 16:22:17 [multiproc_executor.py:596]     assert inputs.dtype == lora_a_weights[0].dtype
(EngineCore_0 pid=1393314) (VllmWorker TP1 pid=1393322) ERROR 08-12 16:22:17 [multiproc_executor.py:596]     return func(*args, **kwargs)
(EngineCore_0 pid=1393314) (VllmWorker TP2 pid=1393324) ERROR 08-12 16:22:17 [multiproc_executor.py:596]     assert inputs.dtype == lora_a_weights[0].dtype
(EngineCore_0 pid=1393314) (VllmWorker TP3 pid=1393326) ERROR 08-12 16:22:17 [multiproc_executor.py:596]     assert inputs.dtype == lora_a_weights[0].dtype
(EngineCore_0 pid=1393314) (VllmWorker TP0 pid=1393320) ERROR 08-12 16:22:17 [multiproc_executor.py:596] AssertionError
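
For context, the failing line is the dtype guard in vLLM's LoRA shrink kernel (lora_shrink_op.py): the input activations and the stacked LoRA A weights must share a dtype. A minimal stand-alone sketch of that mismatch (the shapes and dtypes below are illustrative, not taken from the actual model):

import torch

# Illustrative only: activations produced in bfloat16 while the stacked
# LoRA A weights were loaded in float32 (hypothetical shapes).
inputs = torch.randn(4, 4096, dtype=torch.bfloat16)
lora_a_weights = [torch.randn(1, 16, 4096, dtype=torch.float32)]

try:
    # Mirrors the guard in _lora_shrink that produced the traceback above.
    assert inputs.dtype == lora_a_weights[0].dtype
except AssertionError:
    print(f"dtype mismatch: {inputs.dtype} vs {lora_a_weights[0].dtype}")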

cc @zRzRzRzRzRzRzR

Test Plan

Test Result

(Optional) Documentation Update

Signed-off-by: Jee Jee Li <[email protected]>
@jeejeelee jeejeelee requested a review from hmellor as a code owner August 12, 2025 16:47

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation and new-model labels Aug 12, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively decouples Glm4vForConditionalGeneration and Glm4vMoeForConditionalGeneration to address differences in their configurations and improve maintainability. The changes to the model registry, documentation, and packed_modules_mapping are appropriate. I have one suggestion to make the code more robust by explicitly disabling LoRA support for Glm4vMoeForConditionalGeneration to prevent potential runtime errors, aligning the code's behavior with the stated intention in the documentation and pull request description.

Comment on lines +1579 to +1589
packed_modules_mapping = {
    "qkv_proj": [
        "q_proj",
        "k_proj",
        "v_proj",
    ],
    "gate_up_proj": [
        "gate_proj",
        "up_proj",
    ],
}

Severity: high

The pull request description and documentation state that LoRA is not supported for Glm4vMoeForConditionalGeneration due to a bug. However, the class currently inherits SupportsLoRA from its parent, which will cause vLLM to attempt to apply LoRA adapters if provided, leading to a runtime error. To prevent this, I suggest adding an __init__ method that explicitly checks for and disallows LoRA configuration for this model. This will provide a clear error message to users and make the code's behavior consistent with the documentation.

Suggested change

Before:

packed_modules_mapping = {
    "qkv_proj": [
        "q_proj",
        "k_proj",
        "v_proj",
    ],
    "gate_up_proj": [
        "gate_proj",
        "up_proj",
    ],
}

After:

packed_modules_mapping = {
    "qkv_proj": [
        "q_proj",
        "k_proj",
        "v_proj",
    ],
    "gate_up_proj": [
        "gate_proj",
        "up_proj",
    ],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    # LoRA is not supported for this model yet.
    if vllm_config.lora_config:
        raise NotImplementedError(
            "LoRA is not currently supported for "
            "Glm4vMoeForConditionalGeneration."
        )
    super().__init__(vllm_config=vllm_config, prefix=prefix)
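
As additional context (not part of the suggestion above), packed_modules_mapping is what lets LoRA adapters trained against the per-projection Hugging Face module names (q_proj, k_proj, v_proj, gate_proj, up_proj) be applied to vLLM's fused qkv_proj and gate_up_proj layers. A simplified sketch of that translation, using a hypothetical stand-alone helper rather than the actual vLLM code path:

# Hypothetical helper, not vLLM's implementation: group an adapter's target
# modules under the fused module that replaces them at inference time.
PACKED_MODULES_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def resolve_lora_targets(adapter_target_modules: list[str]) -> dict[str, list[str]]:
    resolved: dict[str, list[str]] = {}
    for fused_name, sub_modules in PACKED_MODULES_MAPPING.items():
        matched = [m for m in adapter_target_modules if m in sub_modules]
        if matched:
            resolved[fused_name] = matched
    return resolved

# An adapter trained on the usual GLM projection names resolves cleanly:
print(resolve_lora_targets(["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]))
# -> {'qkv_proj': ['q_proj', 'k_proj', 'v_proj'], 'gate_up_proj': ['gate_proj', 'up_proj']}

If two model classes share one mapping while their actual module layouts differ, some adapter targets cannot be resolved, which appears to be the target_module mismatch reported in #22077; giving each class its own mapping, as this PR does, avoids that.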

@jeejeelee jeejeelee removed the new-model label Aug 12, 2025
@mergify mergify bot added the new-model label Aug 12, 2025
@Isotr0py Isotr0py enabled auto-merge (squash) August 12, 2025 17:03
@github-actions github-actions bot added the ready label Aug 12, 2025
@vllm-bot vllm-bot merged commit fde0b61 into vllm-project:main Aug 13, 2025
48 of 56 checks passed
@jeejeelee jeejeelee deleted the decouple-glm4v branch August 13, 2025 00:20
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
HeJunyan added a commit to HeJunyan/vllm-fork that referenced this pull request Aug 20, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
HeJunyan added a commit to HeJunyan/vllm-fork that referenced this pull request Sep 22, 2025

Successfully merging this pull request may close these issues.

[Bug]: GLM-4.1V lora trained model reports target_module mismatch error