
Conversation

youkaichao (Member) commented on Feb 21, 2025:

Continuation of #12071 in v1.

Some changes to note:

  • We need to disable VLLM_ENABLE_V1_MULTIPROCESSING so that the engine lives in the same process as the LLM class, which is required by the RLHF framework https://github.com/volcengine/verl. This also reduces scheduling non-determinism. (cc @robertgshaw2-redhat to confirm: in this case, can we guarantee that all calls to llm.generate produce the same scheduling decisions?) A minimal usage sketch follows after this list.
  • Some miscellaneous changes to fix user-interface compatibility, plus fixes for code that did not account for ExecutorWithExternalLauncher.
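For context, a minimal sketch of the usage pattern this enables (assuming the external-launcher executor from #12071; the env-var settings are the ones discussed in this PR, while the model, prompt, and world size are placeholders):

```python
# Hedged sketch, not verbatim from verl: drive vLLM v1 in-process under
# torchrun with the external launcher. Model, prompt, and world size are
# illustrative placeholders.
import os

os.environ["VLLM_USE_V1"] = "1"                     # opt into the v1 engine
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"  # keep the engine in this process

from vllm import LLM, SamplingParams

# Launch with: torchrun --nproc-per-node=2 this_script.py
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0))
```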


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small but essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@@ -47,6 +47,9 @@ def test_consistent_across_ranks(obj):
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
youkaichao (Member, Author):
This is to test whether we can directly access the model via llm.llm_engine.model_executor.driver_worker.worker.model_runner.model. That access pattern is used in https://github.com/volcengine/verl/blob/0a1b16f800c25ac80504038fd8b8be4282d6c606/verl/workers/sharding_manager/fsdp_vllm.py#L84
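A hedged sketch of how an RLHF framework might use that attribute chain to push updated weights into the in-process model (only the attribute chain is taken from this PR's test; load_weights is the usual vLLM model-side hook, but the surrounding state-dict plumbing here is an assumption, not verl's exact code):

```python
# Hedged sketch: sync trained weights into vLLM's in-process model, in the
# style of verl's fsdp_vllm sharding manager. Only the attribute chain is
# taken from this PR's test; the rest is illustrative.
def sync_weights_into_vllm(llm, trained_state_dict):
    # Reach the nn.Module that vLLM actually runs for inference.
    model = (llm.llm_engine.model_executor
             .driver_worker.worker.model_runner.model)
    # vLLM models accept an iterable of (name, tensor) pairs here.
    model.load_weights(trained_state_dict.items())
```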

Collaborator (WoosukKwon):
Maybe worth a comment?

robertgshaw2-redhat (Collaborator):

Yes, this should cause deterministic scheduling.

Separately, do you think we can switch from an ENV variable to an EngineArg?

@@ -567,6 +567,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
self.worker = worker_class(**kwargs)
assert self.worker is not None

def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
youkaichao (Member, Author):
@ruisearch42 @comaniac FYI: if a method needs to send different arguments to different ranks, the indexing should use self.rpc_rank, and it should happen in this WorkerWrapperBase.
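A minimal sketch of the pattern being described, trimmed down from the diff above (the class body is simplified; the inner worker's method name and the config type are stand-ins):

```python
# Hedged sketch of per-rank argument indexing inside WorkerWrapperBase.
# The initialize_from_config body mirrors the diff above; everything else
# is simplified for illustration.
from typing import Any, List


class WorkerWrapperBase:
    def __init__(self, rpc_rank: int = 0):
        # rpc_rank selects this rank's entry from any per-rank argument list.
        self.rpc_rank = rpc_rank
        self.worker = None

    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
        # Every rank receives the full list and picks out only its own config.
        kv_cache_config = kv_cache_configs[self.rpc_rank]
        self.worker.initialize_from_config(kv_cache_config)
```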

youkaichao (Member, Author):

> Separately, do you think we can switch from an ENV variable to an EngineArg?

I don't have a strong opinion here.

@@ -151,7 +152,7 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
youkaichao (Member, Author):
@WoosukKwon we need a base class for the workers so that we can deduplicate this code.

Right now I just changed both of them, but we should do the unification in the future.
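For illustration, a hedged sketch of the kind of shared base class being suggested (names and structure are assumptions, not the eventual vLLM design; only the rank-0 gating mirrors the diff above):

```python
# Hypothetical worker base class to deduplicate execute_model across the
# concrete workers; the rank-0 gating below mirrors the diff above, the rest
# is illustrative.
from typing import Optional


class WorkerBase:
    def __init__(self, rank: int, model_runner):
        self.rank = rank
        self.model_runner = model_runner

    def execute_model(self, scheduler_output) -> Optional["ModelRunnerOutput"]:
        output = self.model_runner.execute_model(scheduler_output)
        # With an external launcher every rank executes this method, so only
        # rank 0 returns the output to avoid duplicated results upstream.
        return output if self.rank == 0 else None
```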

Collaborator (WoosukKwon):
Got it. Filed #13711 to track the issue.

WoosukKwon (Collaborator) left a review:

LGTM! Thanks for the PR! I only left minor comments.


youkaichao enabled auto-merge (squash) on February 23, 2025, 07:20
github-actions bot added the ready label (ONLY add when PR is ready to merge / full CI is needed) on Feb 23, 2025
youkaichao disabled auto-merge on February 23, 2025, 14:47
youkaichao merged commit eb24dc4 into vllm-project:main on Feb 23, 2025 (56 of 58 checks passed)
youkaichao deleted the v1_torchrun branch on February 23, 2025, 14:47
Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025