 from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
 from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -171,6 +172,8 @@ class OpType(Enum):
     SGMV_EXPAND = auto()
     BGMV_EXPAND = auto()
     BGMV_EXPAND_SLICE = auto()
+    V1_SHRINK = auto()
+    V1_EXPAND = auto()

     @staticmethod
     def from_str(s: str) -> "OpType":
@@ -184,28 +187,43 @@ def from_str(s: str) -> "OpType":
             return OpType.BGMV_EXPAND
         if s.lower() == "bgmv_expand_slice":
             return OpType.BGMV_EXPAND_SLICE
+        if s.lower() == "v1_shrink":
+            return OpType.V1_SHRINK
+        if s.lower() == "v1_expand":
+            return OpType.V1_EXPAND
         raise ValueError(f"Unrecognized str {s} to convert to OpType")

     def is_shrink_fn(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
+        ]

     def is_expand_fn(self) -> bool:
-        return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
+        return self in [
+            OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
+        ]

     def is_prefill_op(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
+            OpType.V1_EXPAND
+        ]

     def is_decode_op(self) -> bool:
         return self in [
-            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
+            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
+            OpType.V1_SHRINK, OpType.V1_EXPAND
         ]

     def is_expand_slice_fn(self) -> bool:
         return self in [OpType.BGMV_EXPAND_SLICE]

     def num_slices(self) -> list[int]:
-        if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
-            # SGMV kernels supports slices
+        if self in [
+                OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
+                OpType.V1_EXPAND
+        ]:
+            # SGMV kernels and v1 kernels support slices
             return [1, 2, 3]
         if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
             return [1]
@@ -250,11 +268,13 @@ def matmul_shapes(
         m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

         b_shape = (num_loras, n, k)  # col-major
-        if self == OpType.SGMV_SHRINK:
-            # SGMV shrink supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
+            # SGMV shrink and V1 shrink kernels support num_slices inherently
+            # in the kernel.
             return ((m, k), b_shape, (num_slices, m, n))
-        if self == OpType.SGMV_EXPAND:
-            # SGMV expand supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
+            # SGMV expand and V1 expand kernels support num_slices inherently
+            # in the kernel.
             return ((num_slices, m, k), b_shape, (m, n * num_slices))
         if self == OpType.BGMV_SHRINK:
             return ((m, k), b_shape, (m, n))
@@ -281,25 +301,30 @@ def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
             return bgmv_expand
         if self == OpType.BGMV_EXPAND_SLICE:
             return emulate_bgmv_expand_slice
+        if self == OpType.V1_SHRINK:
+            return v1_shrink
+        if self == OpType.V1_EXPAND:
+            return v1_expand
+
         raise ValueError(f"Unrecognized optype {self}")

     def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                            lora_weights: list[torch.Tensor],
                            **kwargs) -> Callable:
-        """Each benchmark operation expected the input, lora_weights and outputs
+        """Each benchmark operation expects the input, lora_weights and outputs
         in a slightly different format. Refer to self.matmul_shapes().
         run_ref_group_gemm accounts for those differences in executing a
         reference group gemm for correctness testing.
         """
         w_dtype = lora_weights[0].dtype
         num_slices = len(lora_weights)
-        if self == OpType.SGMV_SHRINK:
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
             for slice_idx in range(num_slices):
                 ref_group_gemm(ref_out=output[slice_idx, :],
                                input=input,
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        if self == OpType.SGMV_EXPAND:
+        elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -308,19 +333,19 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                     input=input[slice_idx].clone().to(dtype=w_dtype),
                     lora_weights=lora_weights[slice_idx],
                     **kwargs)
-        if self == OpType.BGMV_SHRINK:
+        elif self == OpType.BGMV_SHRINK:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input,
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND:
+        elif self == OpType.BGMV_EXPAND:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input.clone().to(dtype=w_dtype),
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND_SLICE:
+        elif self == OpType.BGMV_EXPAND_SLICE:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -329,7 +354,8 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                     input=input[slice_idx].clone().to(dtype=w_dtype),
                     lora_weights=lora_weights[slice_idx],
                     **kwargs)
-        raise ValueError(f"Unrecognized optype {self}")
+        else:
+            raise ValueError(f"Unrecognized optype {self}")


 @dataclass
@@ -390,6 +416,8 @@ class BenchmarkTensors:
     seq_start_loc: torch.Tensor
     prompt_lora_mapping: torch.Tensor
     token_lora_mapping: torch.Tensor
+    # v1 kernel metadata
+    v1_kernel_meta: Optional[V1KernelMeta] = None

     def io_types(self) -> str:
         return (f"{dtype_to_str(self.input.dtype)}x"
@@ -432,10 +460,19 @@ def make(ctx: BenchmarkContext,
             total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
             seq_len_tensor, "cpu")

+        v1_kernel_meta = None
+        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
+            v1_kernel_meta = V1KernelMeta.make(
+                max_loras=ctx.num_loras,
+                max_num_tokens=token_lora_indices_tensor.size(0),
+                device="cpu")
+            v1_kernel_meta.prepare_tensors(
+                token_lora_mapping=token_lora_indices_tensor)
+
         return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                 seq_len_tensor, seq_start_loc_tensor,
                                 prompt_lora_indices_tensor,
-                                token_lora_indices_tensor)
+                                token_lora_indices_tensor, v1_kernel_meta)

     def sanity_check(self) -> None:
         """
@@ -468,6 +505,13 @@ def to_device(tensor: torch.Tensor):
         for i in range(len(self.lora_weights_lst)):
             self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

+        # v1 kernel metadata tensors
+        if self.v1_kernel_meta:
+            for field_name in V1KernelMeta.__dataclass_fields__:
+                field = getattr(self.v1_kernel_meta, field_name)
+                assert isinstance(field, torch.Tensor)
+                setattr(self.v1_kernel_meta, field_name, to_device(field))
+
     def metadata(self) -> tuple[int, int, int]:
         """
         Return num_seqs, num_tokens and max_seq_len
@@ -667,6 +711,78 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
             })
         return {'kwargs_list': kwargs_list}

+    def as_v1_shrink_kwargs(self) -> dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape: [num_tokens, hidden_size]
+        assert len(i_shape) == 2
+        assert i_shape[0] == num_tokens
+        hidden_size = i_shape[1]
+        # Expected lora weight shape: [num_loras, lora_rank, hidden_size]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == hidden_size
+        lora_rank = lw_shape[1]
+        # Expected output shape: [num_slices, num_tokens, lora_rank]
+        assert len(o_shape) == 3
+        assert o_shape == (num_slices, num_tokens, lora_rank)
+
+        return {
+            'inputs': self.input,
+            'lora_a_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'scaling': 1.0,
+        }
+
+    def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape: [num_slices, num_tokens, lora_rank]
+        assert len(i_shape) == 3
+        assert i_shape[0] == num_slices
+        assert i_shape[1] == num_tokens
+        lora_rank = i_shape[2]
+        # Expected lora weight shape: [num_loras, hidden_size, lora_rank]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == lora_rank
+        hidden_size = lw_shape[1]
+        # Expected output shape: [num_tokens, hidden_size * num_slices]
+        assert len(o_shape) == 2
+        assert o_shape == (num_tokens, hidden_size * num_slices)
+
+        return {
+            'inputs': self.input,
+            'lora_b_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'offset_start': 0,
+            'add_inputs': add_inputs,
+        }
+
     def bench_fn_kwargs(self,
                         op_type: OpType,
                         add_inputs: Optional[bool] = None) -> dict[str, Any]:
@@ -685,6 +801,10 @@ def bench_fn_kwargs(self,
             return self.as_bgmv_expand_kwargs(add_inputs)
         if op_type == OpType.BGMV_EXPAND_SLICE:
             return self.as_bgmv_expand_slice_kwargs(add_inputs)
+        if op_type == OpType.V1_SHRINK:
+            return self.as_v1_shrink_kwargs()
+        if op_type == OpType.V1_EXPAND:
+            return self.as_v1_expand_kwargs(add_inputs)
         raise ValueError(f"Unrecognized optype {self}")

     def test_correctness(self, op_type: OpType,
@@ -872,12 +992,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
     timers = []
     for bench_ctx in bench_ctxs:
         for seq_len in args.seq_lengths:
-            bench_ops: list[OpType] = []
-            if seq_len == 1:
-                # bench all decode ops
-                bench_ops = [op for op in args.op_types if op.is_decode_op()]
-            else:
-                # bench all prefill ops
+            bench_ops: list[OpType] = args.op_types
+            if seq_len > 1:
+                # bench only prefill ops
                 bench_ops = [op for op in args.op_types if op.is_prefill_op()]

             seq_len_timers = []
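
For reference, here is a rough sketch (not part of the diff) of how the benchmark ends up driving the new v1 ops. The keyword names and tensor shapes mirror V1KernelMeta.make()/prepare_tensors() and as_v1_shrink_kwargs() above; the concrete sizes, dtypes, and the CUDA device are illustrative assumptions, not values taken from the benchmark.

import torch

from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_shrink

# Illustrative sizes; the benchmark derives these from BenchmarkContext.
num_tokens, hidden_size, lora_rank, num_loras, num_slices = 32, 2048, 16, 4, 2

# Per-token LoRA assignment (BenchmarkTensors.make builds this mapping).
token_lora_mapping = torch.randint(0, num_loras, (num_tokens, ),
                                   dtype=torch.int32, device="cuda")

meta = V1KernelMeta.make(max_loras=num_loras,
                         max_num_tokens=num_tokens,
                         device="cuda")
meta.prepare_tensors(token_lora_mapping=token_lora_mapping)

inputs = torch.rand((num_tokens, hidden_size),
                    dtype=torch.float16, device="cuda")
lora_a_weights = [
    torch.rand((num_loras, lora_rank, hidden_size),
               dtype=torch.float16, device="cuda") for _ in range(num_slices)
]
shrink_out = torch.zeros((num_slices, num_tokens, lora_rank),
                         dtype=torch.float32, device="cuda")

# Keyword names follow as_v1_shrink_kwargs(); the dtypes are assumptions.
v1_shrink(inputs=inputs,
          lora_a_weights=lora_a_weights,
          output_tensor=shrink_out,
          token_lora_mapping=meta.token_lora_mapping,
          token_indices_sorted_by_lora_ids=meta.token_indices_sorted_by_lora_ids,
          num_tokens_per_lora=meta.num_tokens_per_lora,
          lora_token_start_loc=meta.lora_token_start_loc,
          lora_ids=meta.active_lora_ids,
          scaling=1.0)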
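
The metadata fields forwarded above (token_indices_sorted_by_lora_ids, num_tokens_per_lora, lora_token_start_loc, active_lora_ids) are produced by V1KernelMeta.prepare_tensors(). The snippet below is only a plausible reading of their semantics, inferred from the field names; it is a hypothetical stand-in, not the actual implementation in vllm/lora/ops/triton_ops/v1.

import torch

def describe_v1_meta(token_lora_mapping: torch.Tensor, max_loras: int):
    # Group token indices so tokens that share a LoRA id become contiguous.
    token_indices_sorted_by_lora_ids = torch.argsort(token_lora_mapping,
                                                     stable=True)
    # Number of tokens routed to each LoRA id.
    num_tokens_per_lora = torch.bincount(token_lora_mapping,
                                         minlength=max_loras)
    # Exclusive prefix sum: where each LoRA's tokens start in the sorted list.
    lora_token_start_loc = torch.zeros(max_loras + 1, dtype=torch.long)
    lora_token_start_loc[1:] = torch.cumsum(num_tokens_per_lora, dim=0)
    return (token_indices_sorted_by_lora_ids, num_tokens_per_lora,
            lora_token_start_loc)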