2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
3
4
4
import dataclasses
5
+ import importlib
5
6
import pickle
6
7
from collections .abc import Sequence
7
8
from inspect import isclass
8
9
from types import FunctionType
9
10
from typing import Any , Optional , Union
10
11
11
12
import cloudpickle
13
+ import msgspec
12
14
import numpy as np
13
15
import torch
14
16
import zmq
22
24
MultiModalFlatField , MultiModalKwargs ,
23
25
MultiModalKwargsItem ,
24
26
MultiModalSharedField , NestedTensors )
27
+ from vllm .v1 .engine import UtilityResult
25
28
26
29
logger = init_logger (__name__ )
27
30
@@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
46
49
"VLLM_ALLOW_INSECURE_SERIALIZATION=1" )
47
50
48
51
52
+ def _typestr (t : type ):
53
+ return t .__module__ , t .__qualname__
54
+
55
+
49
56
class MsgpackEncoder :
50
57
"""Encoder with custom torch tensor and numpy array serialization.
51
58
@@ -122,6 +129,18 @@ def enc_hook(self, obj: Any) -> Any:
122
129
for itemlist in mm ._items_by_modality .values ()
123
130
for item in itemlist ]
124
131
132
+ if isinstance (obj , UtilityResult ):
133
+ result = obj .result
134
+ if not envs .VLLM_ALLOW_INSECURE_SERIALIZATION or result is None :
135
+ return None , result
136
+ # Since utility results are not strongly typed, we also encode
137
+ # the type (or a list of types in the case it's a list) to
138
+ # help with correct msgspec deserialization.
139
+ cls = result .__class__
140
+ return _typestr (cls ) if cls is not list else [
141
+ _typestr (type (v )) for v in result
142
+ ], result
143
+
125
144
if not envs .VLLM_ALLOW_INSECURE_SERIALIZATION :
126
145
raise TypeError (f"Object of type { type (obj )} is not serializable"
127
146
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
@@ -237,8 +256,33 @@ def dec_hook(self, t: type, obj: Any) -> Any:
237
256
k : self ._decode_nested_tensors (v )
238
257
for k , v in obj .items ()
239
258
})
259
+ if t is UtilityResult :
260
+ return self ._decode_utility_result (obj )
240
261
return obj
241
262
263
def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its encoded ``(type info, result)`` pair.

    Args:
        obj: Two-element sequence. The first element is ``None`` when no
            type information was encoded. Otherwise it is either a
            ``[module, qualname]`` pair describing the type of the whole
            result, or a list of such pairs (one per element) when the
            result itself is a list.

    Returns:
        A ``UtilityResult`` wrapping the raw result (no type information)
        or the result converted to the described type(s).

    Raises:
        TypeError: If type information is present but
            ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set, or if the
            payload does not have the expected shape.
    """
    result_type, result = obj
    if result_type is None:
        # Untyped result: pass it through unchanged.
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # The payload comes off the wire, so validate it with real exceptions
    # rather than `assert`, which is stripped when Python runs with -O.
    if not isinstance(result_type, list):
        raise TypeError("Utility result type info must be a list, got "
                        f"{type(result_type).__name__}")

    # A two-element list whose first item is a string is a single
    # [module, qualname] pair; anything else is a per-element list of pairs.
    if len(result_type) == 2 and isinstance(result_type[0], str):
        return UtilityResult(self._convert_result(result_type, result))

    if not isinstance(result, list):
        raise TypeError("Utility result must be a list when per-element "
                        f"type info is given, got {type(result).__name__}")
    if len(result_type) != len(result):
        # zip() would silently drop trailing elements on a mismatch.
        raise TypeError(
            f"Utility result has {len(result)} elements but "
            f"{len(result_type)} element types were provided")

    return UtilityResult([
        self._convert_result(elem_type, elem)
        for elem_type, elem in zip(result_type, result)
    ])
279
+
280
def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded value to the type named by ``result_type``.

    Args:
        result_type: ``(module name, qualified name)`` pair identifying the
            target type.
        result: The decoded value to convert.

    Returns:
        ``result`` converted to the resolved type via ``msgspec.convert``,
        using this decoder's ``dec_hook`` for custom types.

    Raises:
        ModuleNotFoundError: If the named module cannot be imported.
        AttributeError: If the qualified name cannot be resolved inside
            the module.
    """
    mod_name, name = result_type
    # NOTE(security): this imports a module named by the incoming payload.
    # The visible caller only reaches this point when
    # VLLM_ALLOW_INSECURE_SERIALIZATION is set; keep it behind that gate.
    target: Any = importlib.import_module(mod_name)
    # A qualified name is dotted for nested classes ("Outer.Inner"), so
    # resolve it one attribute at a time instead of with a single getattr.
    for attr in name.split("."):
        target = getattr(target, attr)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)
285
+
242
286
def _decode_ndarray (self , arr : Any ) -> np .ndarray :
243
287
dtype , shape , data = arr
244
288
# zero-copy decode. We assume the ndarray will not be kept around,
0 commit comments