
Conversation

@tjtanaa (Contributor) commented Aug 6, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR optimizes the mrope forward pass with a Triton kernel adapted from https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py, extended to support vLLM's flattened input tensors and a cos/sin cache of shape (3, num_tokens, head_dim // 2).

Related to #22293
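
For orientation, here is a minimal PyTorch sketch (illustrative only, not code from this PR; the tensor names and sizes are assumptions) of the flattened query/key layout and the (3, num_tokens, head_dim // 2) cos/sin cache described above:

import torch

num_tokens, num_heads, num_kv_heads, head_dim = 8, 28, 4, 128

# Flattened activations: all tokens in the batch are concatenated along dim 0.
query = torch.randn(num_tokens, num_heads * head_dim)
key = torch.randn(num_tokens, num_kv_heads * head_dim)

# Multimodal rotary positions: one row each for the T/H/W axes.
positions = torch.zeros(3, num_tokens, dtype=torch.long)

# Cos/sin cache layout targeted by the Triton kernel: one plane per axis.
cos = torch.randn(3, num_tokens, head_dim // 2)
sin = torch.randn(3, num_tokens, head_dim // 2)
print(query.shape, key.shape, cos.shape, sin.shape)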

Test Plan

  • mrope passes the kernel tests
  • lm_eval of Qwen2.5-VL-7B-Instruct on ChartQA

Test Result

Before PR:

{
    "explicit_prompt_relaxed_correctness": 0.8644,
    "anywhere_in_answer_relaxed_correctness": 0.8644
}

After PR:

{
    "explicit_prompt_relaxed_correctness": 0.8636,
    "anywhere_in_answer_relaxed_correctness": 0.8636
}

Benchmark command

server:

MIOPEN_USER_DB_PATH=/app/vl/miopen \
MIOPEN_FIND_MODE=FAST \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
SAFETENSORS_FAST_GPU=1 \
vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
--tensor_parallel_size=1 \
--trust_remote_code \
--port 7899

client:

#!/bin/bash
python3 benchmarks/benchmark_serving.py  \
--backend openai-chat   \
--model Qwen/Qwen2.5-VL-7B-Instruct   \
--endpoint /v1/chat/completions   \
--dataset-name hf   \
--dataset-path lmarena-ai/VisionArena-Chat   \
--hf-split train   \
--num-prompts 1000 \
--max-concurrency 64 \
--port 7899 \
> speedtest_torch_upstream.log 2>&1

Results:

| Metric | Torch Implementation | Triton Implementation | Difference | % Improvement |
| --- | --- | --- | --- | --- |
| Overall Performance | | | | |
| Successful requests | 1,000 | 1,000 | 0 | 0% |
| Benchmark duration (s) | 453.55 | 447.35 | -6.20 | 1.4% faster |
| Request throughput (req/s) | 2.20 | 2.24 | +0.04 | 1.8% higher |
| Output token throughput (tok/s) | 251.88 | 255.26 | +3.38 | 1.3% higher |
| Total token throughput (tok/s) | 459.85 | 466.12 | +6.27 | 1.4% higher |
| Time to First Token (TTFT) | | | | |
| Mean TTFT (ms) | 4,877.85 | 4,688.14 | -189.71 | 3.9% faster |
| Median TTFT (ms) | 3,353.20 | 3,090.05 | -263.15 | 7.8% faster |
| P99 TTFT (ms) | 23,555.70 | 21,858.97 | -1,696.73 | 7.2% faster |
| Time per Output Token (TPOT) | | | | |
| Mean TPOT (ms) | 226.51 | 223.63 | -2.88 | 1.3% faster |
| Median TPOT (ms) | 220.16 | 222.08 | +1.92 | 0.9% slower |
| P99 TPOT (ms) | 508.29 | 508.06 | -0.23 | 0.0% faster |
| Inter-token Latency (ITL) | | | | |
| Mean ITL (ms) | 331.43 | 333.89 | +2.46 | 0.7% slower |
| Median ITL (ms) | 11.02 | 11.10 | +0.08 | 0.7% slower |
| P99 ITL (ms) | 4,155.08 | 3,623.85 | -531.23 | 12.8% faster |

(Optional) Documentation Update

Signed-off-by: tjtanaa <[email protected]>

github-actions bot commented Aug 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify bot added the performance (Performance-related issues) label Aug 6, 2025
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a Triton kernel for mrope to improve performance, along with corresponding benchmarks and tests. I've identified two critical issues: a NameError in the benchmark script due to an out-of-scope variable, and incorrect pointer arithmetic in the Triton kernel that could lead to out-of-bounds memory access. Please review the detailed comments for fixes.

q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
is_neox_style = True
mrope_section = config.rope_scaling["mrope_section"]

critical

The mrope_section variable is defined only within this if __name__ == "__main__" block, making it inaccessible to the benchmark_mrope function. This will cause a NameError when benchmark_mrope is called on line 319. To fix this, pass config.rope_scaling["mrope_section"] to the benchmark_mrope function.

                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        rope_theta=rope_theta,
                        is_neox_style=is_neox_style,
                        rope_scaling=config.rope_scaling,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )
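
For illustration only (a hypothetical, simplified signature, not the PR's actual code), the idea behind this suggestion is that benchmark_mrope derives mrope_section from the rope_scaling dict it receives, instead of relying on a name defined elsewhere in the script:

# Hypothetical sketch of the suggested pattern: the caller passes the whole
# rope_scaling dict and the function extracts mrope_section itself.
def benchmark_mrope(rope_scaling: dict, **kwargs) -> None:
    mrope_section = rope_scaling["mrope_section"]
    print("mrope_section:", mrope_section)

benchmark_mrope({"mrope_section": [16, 24, 24]})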

Comment on lines +108 to +113
t_cos = cos + pid * half_hd
h_cos = t_cos + num_tokens * half_hd
w_cos = h_cos + num_tokens * half_hd
t_sin = sin + pid * half_hd
h_sin = t_sin + num_tokens * half_hd
w_sin = h_sin + num_tokens * half_hd

critical

The offsets for h_cos, w_cos, h_sin, and w_sin are calculated incorrectly, leading to out-of-bounds memory access. The correct offsets should account for the stride of the cos and sin tensors, which is num_tokens * half_hd.

The current implementation calculates offsets for the h and w dimensions incorrectly:

h_cos = t_cos + num_tokens * half_hd
w_cos = h_cos + num_tokens * half_hd

This results in h_cos pointing to an offset of (pid + num_tokens) * half_hd from the base cos pointer, which attempts to access cos[0, pid + num_tokens, :], an out-of-bounds read along the num_tokens dimension.

To fix this, calculate the offsets relative to the base pointer and consider the stride of the first dimension.

Suggested change
t_cos = cos + pid * half_hd
h_cos = t_cos + num_tokens * half_hd
w_cos = h_cos + num_tokens * half_hd
t_sin = sin + pid * half_hd
h_sin = t_sin + num_tokens * half_hd
w_sin = h_sin + num_tokens * half_hd
dim0_stride = num_tokens * half_hd
token_offset = pid * half_hd
t_cos = cos + token_offset
h_cos = cos + dim0_stride + token_offset
w_cos = cos + 2 * dim0_stride + token_offset
t_sin = sin + token_offset
h_sin = sin + dim0_stride + token_offset
w_sin = sin + 2 * dim0_stride + token_offset
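
For reference, a small sketch (assuming a contiguous cos/sin cache of shape (3, num_tokens, half_hd)) of the linear offset of element cos[i, t, 0], which is what the suggestion spells out via dim0_stride and token_offset:

import torch

num_tokens, half_hd = 5, 4
cos = torch.arange(3 * num_tokens * half_hd).reshape(3, num_tokens, half_hd)

def linear_offset(i: int, t: int) -> int:
    # offset of cos[i, t, 0] within the flat buffer of a contiguous tensor
    return i * num_tokens * half_hd + t * half_hd

pid = 2  # token index handled by one Triton program instance
for i, axis in enumerate(["t", "h", "w"]):
    off = linear_offset(i, pid)
    assert cos.flatten()[off] == cos[i, pid, 0]
    print(axis, off)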

Signed-off-by: tjtanaa <[email protected]>
@tjtanaa (Contributor, Author) commented Aug 6, 2025

/gemini review

@DarkLight1337 (Member)

cc @imkero @vadiklyutiy

@tjtanaa (Contributor, Author) commented Aug 7, 2025

CC. @wuhuikx

Comment on lines 17 to 34
# vLLM Native implementation of mrope forward pass
# used for benchmarking and unit testing
def mrope_forward_native(
        positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
        cos: torch.Tensor, sin: torch.Tensor, mrope_section: list[int],
        is_neox_style: bool, head_size: int,
        rotary_dim: int) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, head_dim // 2]
        sin: [3, num_tokens, head_dim // 2]
    """
Member

Can you move this to MRotaryEmbedding's forward_native method?

Contributor Author

added back the forward_native code path

@wuhuikx commented Aug 7, 2025

cc @sunway513 we have a triton kernel for mrope.

@tjtanaa (Contributor, Author) commented Aug 7, 2025

@Isotr0py Ready for another round of review. Thank you.

@Isotr0py (Member) commented Aug 8, 2025

Can you address the pre-commit issue?

@Isotr0py (Member) left a comment

LGTM now!

@Isotr0py added the ready (ONLY add when PR is ready to merge/full CI is needed) label Aug 8, 2025
@tjtanaa (Contributor, Author) commented Aug 8, 2025

@Isotr0py Thank you very much for reviewing this PR and helping me iterate through it at a very fast pace. Amazing!

@vadiklyutiy (Contributor) left a comment

PR looks good to me!

Only one thought: when I experimented with covering the vision part with torch.compile, the performance of mrope looked good as well. I'm not sure which would be better, this PR or torch.compile. The kernel itself here is better, but torch.compile benefits from fusion with the surrounding kernels.

@vllm-bot merged commit 42172ad into vllm-project:main Aug 9, 2025
35 of 42 checks passed
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Labels
performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)

6 participants