Skip to content

Conversation

jmswen
Copy link
Contributor

@jmswen jmswen commented Jun 3, 2025

Purpose

This change allows an AsyncLLMEngine to target a specific DP rank. This is useful in contexts where load balancing to DP ranks happens externally (e.g., via Dynamo).

Test Plan

Added an example. To run the example, run:

python examples/online_serving/multi_instance_data_parallel.py

vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
    --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
    --data-parallel-size-local 1 --enforce-eager --headless

Test Result

With some additional local logging, verified that the request in the above example is directed to the correct DP rank.

Copy link

github-actions bot commented Jun 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added documentation Improvements or additions to documentation v1 labels Jun 3, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @jmswen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

This pull request introduces the ability to target a specific data parallel (DP) rank when using AsyncLLMEngine.generate. This is intended for scenarios where load balancing across DP ranks is managed externally, such as with Dynamo. The changes involve adding a data_parallel_rank parameter to the generate and add_request_async methods in AsyncLLMEngine and related classes, and updating the engine core client to route requests to the appropriate engine based on the specified DP rank. An example has also been added to demonstrate the new functionality.

Highlights

  • New Feature: Introduces the ability to target a specific DP rank in AsyncLLMEngine.generate.
  • Parameter Addition: Adds data_parallel_rank parameter to generate and add_request_async methods.
  • Routing Logic: Updates engine core client to route requests to the correct engine based on data_parallel_rank.
  • Example: Adds a new example multi_instance_data_parallel.py to demonstrate the new feature.

Changelog

Click here to see the changelog
  • examples/online_serving/multi_instance_data_parallel.py
    • Added a new example demonstrating how to target a specific DP rank with AsyncLLMEngine.generate (lines 1-57).
  • vllm/engine/arg_utils.py
    • Added data_parallel_start_rank as an optional argument to EngineArgs (line 293).
  • vllm/engine/async_llm_engine.py
    • Added data_parallel_rank parameter to add_request_async method (lines 445, 460, 478).
    • Added data_parallel_rank parameter to add_request method (line 941).
    • Added data_parallel_rank parameter to generate method (lines 988, 1008, 1067).
  • vllm/v1/engine/init.py
    • Added data_parallel_rank to the EngineCoreRequest class (line 58).
  • vllm/v1/engine/async_llm.py
    • Added data_parallel_rank parameter to add_request method (line 232).
    • Pass data_parallel_rank to self.processor.process_inputs (line 249).
    • Added data_parallel_rank parameter to generate method (line 295).
    • Pass data_parallel_rank to self.add_request (line 326).
  • vllm/v1/engine/core_client.py
    • Modified get_core_engine_for_request to select the engine based on dp_rank (lines 986-993).
    • Pass request.data_parallel_rank to get_core_engine_for_request in add_request_async (line 1031).
  • vllm/v1/engine/processor.py
    • Added data_parallel_rank parameter to process_inputs method (line 215).
    • Pass data_parallel_rank to EngineCoreRequest (line 332).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


Across the parallel plane,
A rank-specific call we attain,
Dynamo's dance so fine,
Each instance in its line,
Load balanced, no longer in vain.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a valuable feature allowing AsyncLLMEngine.generate to target a specific Data Parallel (DP) rank. This is well-motivated for scenarios involving external load balancing. The changes are propagated consistently across the relevant engine components, and a clear example is provided to demonstrate the new functionality.

Overall, the implementation looks solid. I have a couple of suggestions for improvement, primarily around error handling in the core client and robustness in the new example.

Summary of Findings

  • Example Robustness: The new example multi_instance_data_parallel.py could be made more robust by initializing final_output and checking its state before usage to prevent potential UnboundLocalError.
  • Error Handling for Invalid DP Rank: In vllm/v1/engine/core_client.py, when a specified dp_rank does not uniquely identify an EngineCore, an AssertionError is raised. Consider whether a ValueError might be more appropriate for handling invalid dp_rank inputs, especially if the rank can be provided by external systems.

Merge Readiness

The pull request introduces a useful feature and is generally well-implemented. However, there are a couple of medium-severity suggestions related to error handling and example robustness that should be considered. Addressing these would improve the overall quality and resilience of the changes. I am unable to approve this pull request myself; please have other reviewers take a look and approve before merging, especially after considering the suggested changes.

@jmswen jmswen force-pushed the dpgen branch 4 times, most recently from f2ac64a to a17b606 Compare June 3, 2025 20:53
@jmswen jmswen marked this pull request as ready for review June 3, 2025 20:59
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jmswen, LGTM

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 4, 2025
Copy link

mergify bot commented Jun 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jmswen.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 4, 2025
@njhill
Copy link
Member

njhill commented Jun 4, 2025

An additional change I realized that would be nice is a check in the case that dp isn't enabled. Would be good to fail if this new arg is set in that case rather than silently ignore it. But I think that's fine to do in a follow-on PR.

@njhill njhill merged commit c8dcc15 into vllm-project:main Jun 4, 2025
70 checks passed
@jmswen jmswen deleted the dpgen branch June 4, 2025 15:35
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jun 5, 2025
christian-pinto pushed a commit to christian-pinto/vllm that referenced this pull request Jun 6, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants