Skip to content

Commit 3b37f9b

Browse files
njhill and paulpak58
authored and committed
[Misc] Support more collective_rpc return types (vllm-project#21845)
Signed-off-by: Nick Hill <[email protected]> Signed-off-by: Paul Pak <[email protected]>
1 parent ede9955 commit 3b37f9b

File tree

5 files changed

+121
-6
lines changed

5 files changed

+121
-6
lines changed

tests/v1/engine/test_engine_core_client.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import signal
77
import time
88
import uuid
9+
from dataclasses import dataclass
910
from threading import Thread
10-
from typing import Optional
11+
from typing import Optional, Union
1112
from unittest.mock import MagicMock
1213

1314
import pytest
@@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
292293
client.shutdown()
293294

294295

296+
@dataclass
class MyDataclass:
    """Custom (non-native) type returned by the dummy utility function."""

    # Text carried by the instance.
    message: str
299+
300+
301+
# Dummy utility function to monkey-patch into engine core.
302+
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
    """Echo ``msg`` back wrapped in ``MyDataclass``.

    Args:
        msg: Text to place in each returned ``MyDataclass``.
        return_list: When true, return a list of three instances instead
            of a single one.

    Returns:
        One ``MyDataclass``, or a list of three of them.
    """
    print(f"echo dc util function called: {msg}")
    # A dataclass is returned on purpose: it verifies support for custom
    # return types (which need special handling to work with msgspec).
    if not return_list:
        return MyDataclass(msg)
    return [MyDataclass(msg) for _ in range(3)]
312+
313+
314+
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
        monkeypatch: pytest.MonkeyPatch):
    """Check that a utility method may return a custom dataclass or a list
    of them through the async multiprocess engine core client."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        config = args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_cls = Executor.get_class(config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=config,
                executor_class=executor_cls,
                log_stats=True,
            )

        try:
            # Test utility method returning custom / non-native data type.
            core_client: AsyncMPClient = client

            # Single-object return.
            single = await core_client.call_utility_async(
                "echo_dc", "testarg2", False)
            assert isinstance(single, MyDataclass)
            assert single.message == "testarg2"

            # List-of-objects return.
            many = await core_client.call_utility_async(
                "echo_dc", "testarg2", True)
            assert isinstance(many, list)
            for item in many:
                assert isinstance(item, MyDataclass)
                assert item.message == "testarg2"
        finally:
            client.shutdown()
356+
357+
295358
@pytest.mark.parametrize(
296359
"multiprocessing_mode,publisher_config",
297360
[(True, "tcp"), (False, "inproc")],

vllm/v1/engine/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ def finished(self) -> bool:
123123
return self.finish_reason is not None
124124

125125

126+
class UtilityResult:
    """Box around a utility call's return value.

    Wrapping the value gives it a distinct type that can receive special
    handling when serializing/deserializing.
    """

    def __init__(self, r: Any = None) -> None:
        # The wrapped value; stays ``None`` when nothing is supplied.
        self.result: Any = r
131+
132+
126133
class UtilityOutput(
127134
msgspec.Struct,
128135
array_like=True, # type: ignore[call-arg]
@@ -132,7 +139,7 @@ class UtilityOutput(
132139

133140
# Non-None implies the call failed, result should be None.
134141
failure_message: Optional[str] = None
135-
result: Any = None
142+
result: Optional[UtilityResult] = None
136143

137144

138145
class EngineCoreOutputs(

vllm/v1/engine/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
3737
EngineCoreRequestType,
3838
ReconfigureDistributedRequest, ReconfigureRankType,
39-
UtilityOutput)
39+
UtilityOutput, UtilityResult)
4040
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
4141
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
4242
from vllm.v1.executor.abstract import Executor
@@ -715,8 +715,8 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
715715
output = UtilityOutput(call_id)
716716
try:
717717
method = getattr(self, method_name)
718-
output.result = method(
719-
*self._convert_msgspec_args(method, args))
718+
result = method(*self._convert_msgspec_args(method, args))
719+
output.result = UtilityResult(result)
720720
except BaseException as e:
721721
logger.exception("Invocation of %s method failed", method_name)
722722
output.failure_message = (f"Call to {method_name} method"

vllm/v1/engine/core_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
552552
if output.failure_message is not None:
553553
future.set_exception(Exception(output.failure_message))
554554
else:
555-
future.set_result(output.result)
555+
assert output.result is not None
556+
future.set_result(output.result.result)
556557

557558

558559
class SyncMPClient(MPClient):

vllm/v1/serial_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import dataclasses
5+
import importlib
56
import pickle
67
from collections.abc import Sequence
78
from inspect import isclass
89
from types import FunctionType
910
from typing import Any, Optional, Union
1011

1112
import cloudpickle
13+
import msgspec
1214
import numpy as np
1315
import torch
1416
import zmq
@@ -22,6 +24,7 @@
2224
MultiModalFlatField, MultiModalKwargs,
2325
MultiModalKwargsItem,
2426
MultiModalSharedField, NestedTensors)
27+
from vllm.v1.engine import UtilityResult
2528

2629
logger = init_logger(__name__)
2730

@@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
4649
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")
4750

4851

52+
def _typestr(t: type):
53+
return t.__module__, t.__qualname__
54+
55+
4956
class MsgpackEncoder:
5057
"""Encoder with custom torch tensor and numpy array serialization.
5158
@@ -122,6 +129,18 @@ def enc_hook(self, obj: Any) -> Any:
122129
for itemlist in mm._items_by_modality.values()
123130
for item in itemlist]
124131

132+
if isinstance(obj, UtilityResult):
133+
result = obj.result
134+
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
135+
return None, result
136+
# Since utility results are not strongly typed, we also encode
137+
# the type (or a list of types in the case it's a list) to
138+
# help with correct msgspec deserialization.
139+
cls = result.__class__
140+
return _typestr(cls) if cls is not list else [
141+
_typestr(type(v)) for v in result
142+
], result
143+
125144
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
126145
raise TypeError(f"Object of type {type(obj)} is not serializable"
127146
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
@@ -237,8 +256,33 @@ def dec_hook(self, t: type, obj: Any) -> Any:
237256
k: self._decode_nested_tensors(v)
238257
for k, v in obj.items()
239258
})
259+
if t is UtilityResult:
260+
return self._decode_utility_result(obj)
240261
return obj
241262

263+
def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its encoded ``(type info, value)`` pair.

    ``obj`` unpacks into ``result_type`` and ``result``. When
    ``result_type`` is ``None`` the value is wrapped as-is. Otherwise
    ``result_type`` is either a single ``[module, qualname]`` pair
    describing ``result`` itself, or a list of such pairs, one per element
    of a list ``result``.

    Raises:
        TypeError: If type info is present but
            ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set, or if the
            type info / value do not have the expected list shape.
        ValueError: If the per-element type list and the value list differ
            in length.
    """
    result_type, result = obj
    if result_type is None:
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # Explicit checks rather than ``assert``: this validates decoded wire
    # data and must not disappear when Python runs with -O.
    if not isinstance(result_type, list):
        raise TypeError("Utility result type info must be a list, got "
                        f"{type(result_type).__name__}")

    # A two-element list whose first item is a string is a single
    # [module, qualname] pair describing the whole result.
    if len(result_type) == 2 and isinstance(result_type[0], str):
        return UtilityResult(self._convert_result(result_type, result))

    # Otherwise there is one [module, qualname] pair per list element.
    if not isinstance(result, list):
        raise TypeError("Utility result must be a list when per-element "
                        f"types are given, got {type(result).__name__}")
    if len(result_type) != len(result):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            f"Utility result has {len(result)} elements but "
            f"{len(result_type)} element types were provided")
    return UtilityResult([
        self._convert_result(rt, r) for rt, r in zip(result_type, result)
    ])
279+
280+
def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded value into the class named by ``result_type``.

    Args:
        result_type: ``(module name, qualified name)`` pair identifying
            the target class.
        result: The decoded value to convert.

    Returns:
        ``result`` converted to the target class via ``msgspec.convert``,
        using this decoder's ``dec_hook`` for custom types.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the qualified name cannot be resolved in it.
    """
    mod_name, qualname = result_type
    # The name is a ``__qualname__``, which is dotted for nested classes
    # ("Outer.Inner"), so resolve it one attribute at a time instead of
    # with a single getattr on the module.
    target: Any = importlib.import_module(mod_name)
    for part in qualname.split("."):
        target = getattr(target, part)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)
285+
242286
def _decode_ndarray(self, arr: Any) -> np.ndarray:
243287
dtype, shape, data = arr
244288
# zero-copy decode. We assume the ndarray will not be kept around,

0 commit comments

Comments
 (0)