Skip to content

Conversation

WoosukKwon
Copy link
Collaborator

@WoosukKwon WoosukKwon commented Jul 24, 2025

Currently, when num_experts is not divisible evenly by ep_size, the experts are distributed unevenly among ranks. For instance, if there are 30 experts and an ep_size of 8, the existing logic results in an unbalanced distribution: [3, 3, 3, 3, 3, 3, 3, 9], assigning all remainder experts to the last rank.

This PR addresses this issue by distributing the experts in a more balanced manner. After applying this fix, the experts are distributed as [4, 4, 4, 4, 4, 4, 3, 3].

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a more balanced expert sharding strategy when the number of experts is not evenly divisible by the expert parallelism size. The previous implementation assigned all remaining experts to the last rank, which could create a performance bottleneck. The new implementation distributes the remainder experts one by one to the initial ranks, resulting in a much more even distribution. The logic is correct, clear, and improves performance by balancing the workload. I have no further comments and approve these changes.

@WoosukKwon WoosukKwon requested a review from tlrmchlsmth July 24, 2025 04:57
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice :)

Comment on lines +577 to +590
# Distribute experts as evenly as possible to each rank.
base_experts = global_num_experts // ep_size
remainder = global_num_experts % ep_size
if ep_rank < remainder:
local_num_experts = base_experts + 1
else:
local_num_experts = base_experts

# Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
# Create a expert map for the local experts
if ep_rank < (ep_size - 1):
# Each non-last rank gets local_num_experts experts.
expert_map[ep_rank * local_num_experts:
(ep_rank + 1) * local_num_experts] = \
torch.arange(0, local_num_experts, dtype=torch.int32)
else:
# All remaining experts are assigned to the last rank.
local_num_experts = (global_num_experts - ep_rank * local_num_experts)

expert_map[-local_num_experts:] = \
torch.arange(0, local_num_experts, dtype=torch.int32)
start_idx = ep_rank * base_experts + min(ep_rank, remainder)
expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
0, local_num_experts, dtype=torch.int32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion for more concise version:

    # Distribute experts as evenly as possible to each rank.
    base_experts = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    start_idx = lambda r: r * base_experts + min(r, remainder)

    # Create a tensor of size num_experts filled with -1
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    # Create a expert map for the local experts
    expert_map[start_idx(ep_rank):start_idx(ep_rank + 1)] = torch.arange(
        0, local_num_experts, dtype=torch.int32)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I'm not sure if the suggested code is simpler. 😅 We need to compute local_num_experts anyways because it's one of the return values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon my bad, I missed that we also return that value.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 24, 2025
@WoosukKwon WoosukKwon merged commit fe56180 into main Jul 24, 2025
76 of 78 checks passed
@WoosukKwon WoosukKwon deleted the woosuk/fix-ep branch July 24, 2025 22:56
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants