Skip to content

Conversation

houseroad
Copy link
Collaborator

Summary:
CUDA kernel variables matching the type (thread|block|grid).(Idx|Dim).(x|y|z) have the data type uint.

Many programmers mistakenly use implicit casts to turn these data types into int. In fact, the CUDA Programming Guide it self is inconsistent and incorrect in its use of data types in programming examples.

The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items.

While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Differential Revision: D71355454

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot
Copy link

This pull request was exported from Phabricator. Differential Revision: D71355454

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Also cc @tlrmchlsmth

@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 20, 2025
@DarkLight1337
Copy link
Member

Can you update this after #15282 has been merged?

@houseroad
Copy link
Collaborator Author

rebased. hope the CI can be fixed. Keep running into timeout... sigh.

@DarkLight1337
Copy link
Member

Kernels tests aren't failing on main, so maybe there is indeed something wrong with this PR?

@houseroad
Copy link
Collaborator Author

kernel test-2:

attempt 1, failed on

[2025-03-22T08:06:54Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors1-128-quant_type1-256-128] PASSED

  | [2025-03-22T09:27:08Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors2--1-quant_type1-256-128] # Received cancellation signal, interrupting

attempt 2, failed on

[2025-03-23T00:44:47Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors1-128-quant_type1-256-128] PASSED

  | [2025-03-23T02:03:09Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors2--1-quant_type1-256-128] # Received cancellation signal, interrupting
  | [2025-03-23T02:03:11Z] 🚨 Error: The command exited with status -1

attempt3 failed on

[2025-03-23T06:29:36Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors1-64-quant_type0-256-128] PASSED

  | [2025-03-23T06:29:36Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors1-128-quant_type1-256-128] PASSED
  | [2025-03-23T07:50:10Z] kernels/test_marlin_gemm.py::test_gptq_marlin_gemm[False-False-False-False-mnk_factors2--1-quant_type1-256-128] # Received cancellation signal, interrupting
  | [2025-03-23T07:50:12Z] 🚨 Error: The command exited with status -1

Seems problematic, let me take a close look

@houseroad houseroad marked this pull request as draft March 24, 2025 05:11
@houseroad houseroad marked this pull request as ready for review March 24, 2025 07:06
@houseroad
Copy link
Collaborator Author

I think I found the problematic place:

int delta_first = iters * blockIdx.x - col_first; // this shouldn't be changed to auto. Since the type deduce to uint, where iters and col_first are int, and blockIdx.x is uint.

r-barnes and others added 2 commits March 24, 2025 09:26
…/awq_marlin_repack.cu +10

Summary:
CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables).

Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) it self is inconsistent and incorrect in its use of data types in programming examples.

The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items.

While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Differential Revision: D71355454

Signed-off-by: Lu Fang <[email protected]>
@houseroad houseroad closed this Mar 25, 2025
@houseroad houseroad reopened this Mar 25, 2025
@houseroad
Copy link
Collaborator Author

The problems should be fixed, cc: @DarkLight1337

@DarkLight1337
Copy link
Member

Pre-commit is failing

@houseroad
Copy link
Collaborator Author

It shows all 60 checks passed

@DarkLight1337
Copy link
Member

Weird, the mobile app showed that it's failing. I'm back on my PC now and everything looks fine, sorry for the confusion!

@DarkLight1337 DarkLight1337 merged commit 051da7e into vllm-project:main Mar 25, 2025
61 of 62 checks passed
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
…/awq_marlin_repack.cu +10 (vllm-project#15160)

Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Richard Barnes <[email protected]>
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
…/awq_marlin_repack.cu +10 (vllm-project#15160)

Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Richard Barnes <[email protected]>
Signed-off-by: Wes Medford <[email protected]>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…/awq_marlin_repack.cu +10 (vllm-project#15160)

Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Richard Barnes <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…/awq_marlin_repack.cu +10 (vllm-project#15160)

Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Richard Barnes <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…/awq_marlin_repack.cu +10 (vllm-project#15160)

Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Richard Barnes <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…/awq_marlin_repack.cu +10 (vllm-project#15160)

Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Richard Barnes <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants