
Commit f64306d

Varun Sundar Rabindranath authored and committed
[V1] LoRA - Add triton kernels for V1 (vllm-project#13096)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 01062b8 commit f64306d

File tree

11 files changed: +1162 −188 lines


benchmarks/kernels/benchmark_lora.py

Lines changed: 141 additions & 24 deletions
@@ -23,6 +23,7 @@
 from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
 from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
 from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
 from vllm.utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -171,6 +172,8 @@ class OpType(Enum):
     SGMV_EXPAND = auto()
     BGMV_EXPAND = auto()
     BGMV_EXPAND_SLICE = auto()
+    V1_SHRINK = auto()
+    V1_EXPAND = auto()
 
     @staticmethod
     def from_str(s: str) -> "OpType":
@@ -184,28 +187,43 @@ def from_str(s: str) -> "OpType":
             return OpType.BGMV_EXPAND
         if s.lower() == "bgmv_expand_slice":
             return OpType.BGMV_EXPAND_SLICE
+        if s.lower() == "v1_shrink":
+            return OpType.V1_SHRINK
+        if s.lower() == "v1_expand":
+            return OpType.V1_EXPAND
         raise ValueError(f"Unrecognized str {s} to convert to OpType")
 
     def is_shrink_fn(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
+        ]
 
     def is_expand_fn(self) -> bool:
-        return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
+        return self in [
+            OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
+        ]
 
     def is_prefill_op(self) -> bool:
-        return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
+        return self in [
+            OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
+            OpType.V1_EXPAND
+        ]
 
     def is_decode_op(self) -> bool:
         return self in [
-            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
+            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
+            OpType.V1_SHRINK, OpType.V1_EXPAND
         ]
 
     def is_expand_slice_fn(self) -> bool:
         return self in [OpType.BGMV_EXPAND_SLICE]
 
     def num_slices(self) -> list[int]:
-        if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
-            # SGMV kernels supports slices
+        if self in [
+                OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
+                OpType.V1_EXPAND
+        ]:
+            # SGMV and V1 kernels support slices
             return [1, 2, 3]
         if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
             return [1]
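The predicates above now classify the V1 ops as both prefill- and decode-capable, and as multi-slice capable like the SGMV ops. A minimal illustrative check of that behavior (not part of the commit; assumes benchmark_lora.py is importable, e.g. when run from benchmarks/kernels/):

    from benchmark_lora import OpType  # assumed import path, illustrative only

    op = OpType.from_str("v1_shrink")
    assert op is OpType.V1_SHRINK
    assert op.is_shrink_fn() and op.is_prefill_op() and op.is_decode_op()
    assert op.num_slices() == [1, 2, 3]  # V1 kernels, like SGMV, handle 1-3 slices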
@@ -250,11 +268,13 @@ def matmul_shapes(
         m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
 
         b_shape = (num_loras, n, k)  # col-major
-        if self == OpType.SGMV_SHRINK:
-            # SGMV shrink supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
+            # SGMV shrink and V1 shrink kernels support num_slices inherently
+            # in the kernel.
             return ((m, k), b_shape, (num_slices, m, n))
-        if self == OpType.SGMV_EXPAND:
-            # SGMV expand supports num_slices inherently in the kernel
+        if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
+            # SGMV expand and V1 expand kernels support num_slices inherently
+            # in the kernel.
             return ((num_slices, m, k), b_shape, (m, n * num_slices))
         if self == OpType.BGMV_SHRINK:
             return ((m, k), b_shape, (m, n))
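To make the returned shapes concrete, here is a small sketch (illustrative sizes, not from the commit) of the tensors matmul_shapes describes for the V1 shrink and expand paths, using the m, k, n values produced by self.mkn() and the column-major lora-weight layout noted in the code:

    import torch

    num_loras, num_slices, m, k, n = 4, 2, 8, 128, 16
    # V1_SHRINK: input (m, k), weights (num_loras, n, k), output (num_slices, m, n)
    shrink_in = torch.randn(m, k)
    shrink_w = torch.randn(num_loras, n, k)   # col-major b_shape
    shrink_out = torch.empty(num_slices, m, n)
    # V1_EXPAND: input (num_slices, m, k), weights (num_loras, n, k),
    #            output (m, n * num_slices) -- slices laid out side by side
    expand_in = torch.randn(num_slices, m, k)
    expand_out = torch.empty(m, n * num_slices)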
@@ -281,25 +301,30 @@ def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
             return bgmv_expand
         if self == OpType.BGMV_EXPAND_SLICE:
             return emulate_bgmv_expand_slice
+        if self == OpType.V1_SHRINK:
+            return v1_shrink
+        if self == OpType.V1_EXPAND:
+            return v1_expand
+
         raise ValueError(f"Unrecognized optype {self}")
 
     def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                            lora_weights: list[torch.Tensor],
                            **kwargs) -> Callable:
-        """Each benchmark operation expected the input, lora_weights and outputs
+        """Each benchmark operation expects the input, lora_weights and outputs
         in a slightly different format. Refer to self.matmul_shapes().
         run_ref_group_gemm accounts for those differences in executing a
         reference group gemm for correctness testing.
         """
         w_dtype = lora_weights[0].dtype
         num_slices = len(lora_weights)
-        if self == OpType.SGMV_SHRINK:
+        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
             for slice_idx in range(num_slices):
                 ref_group_gemm(ref_out=output[slice_idx, :],
                                input=input,
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        if self == OpType.SGMV_EXPAND:
+        elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -308,19 +333,19 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                                input=input[slice_idx].clone().to(dtype=w_dtype),
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        if self == OpType.BGMV_SHRINK:
+        elif self == OpType.BGMV_SHRINK:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input,
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND:
+        elif self == OpType.BGMV_EXPAND:
             assert num_slices == 1
             ref_group_gemm(ref_out=output,
                            input=input.clone().to(dtype=w_dtype),
                            lora_weights=lora_weights[0],
                            **kwargs)
-        if self == OpType.BGMV_EXPAND_SLICE:
+        elif self == OpType.BGMV_EXPAND_SLICE:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size
@@ -329,7 +354,8 @@ def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                                input=input[slice_idx].clone().to(dtype=w_dtype),
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        raise ValueError(f"Unrecognized optype {self}")
+        else:
+            raise ValueError(f"Unrecognized optype {self}")
 
 
 @dataclass
@@ -390,6 +416,8 @@ class BenchmarkTensors:
     seq_start_loc: torch.Tensor
     prompt_lora_mapping: torch.Tensor
     token_lora_mapping: torch.Tensor
+    # v1 kernel metadata
+    v1_kernel_meta: Optional[V1KernelMeta] = None
 
     def io_types(self) -> str:
         return (f"{dtype_to_str(self.input.dtype)}x"
@@ -432,10 +460,19 @@ def make(ctx: BenchmarkContext,
             total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
             seq_len_tensor, "cpu")
 
+        v1_kernel_meta = None
+        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
+            v1_kernel_meta = V1KernelMeta.make(
+                max_loras=ctx.num_loras,
+                max_num_tokens=token_lora_indices_tensor.size(0),
+                device="cpu")
+            v1_kernel_meta.prepare_tensors(
+                token_lora_mapping=token_lora_indices_tensor)
+
         return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                 seq_len_tensor, seq_start_loc_tensor,
                                 prompt_lora_indices_tensor,
-                                token_lora_indices_tensor)
+                                token_lora_indices_tensor, v1_kernel_meta)
 
     def sanity_check(self) -> None:
         """
@@ -468,6 +505,13 @@ def to_device(tensor: torch.Tensor):
         for i in range(len(self.lora_weights_lst)):
             self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
 
+        # v1 meta
+        if self.v1_kernel_meta:
+            for field_name in V1KernelMeta.__dataclass_fields__:
+                field = getattr(self.v1_kernel_meta, field_name)
+                assert isinstance(field, torch.Tensor)
+                setattr(self.v1_kernel_meta, field_name, to_device(field))
+
     def metadata(self) -> tuple[int, int, int]:
         """
         Return num_seqs, num_tokens and max_seq_len
@@ -667,6 +711,78 @@ def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
             })
         return {'kwargs_list': kwargs_list}
 
+    def as_v1_shrink_kwargs(self) -> dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape [num_tokens, hidden_size]
+        assert len(i_shape) == 2
+        assert i_shape[0] == num_tokens
+        hidden_size = i_shape[1]
+        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == hidden_size
+        lora_rank = lw_shape[1]
+        # Expected output shape [num_slices, num_tokens, lora_rank]
+        assert len(o_shape) == 3
+        assert o_shape == (num_slices, num_tokens, lora_rank)
+
+        return {
+            'inputs': self.input,
+            'lora_a_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'scaling': 1.0,
+        }
+
+    def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
+        assert self.v1_kernel_meta is not None
+        self.sanity_check()
+        self.to_device(self.input.device)
+
+        _, num_tokens, _, num_slices = self.metadata()
+
+        # Sanity check matrix shapes.
+        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
+            0].shape, self.output.shape
+        # Expected input shape : [num_slices, num_tokens, lora_rank]
+        assert len(i_shape) == 3
+        assert i_shape[0] == num_slices
+        assert i_shape[1] == num_tokens
+        lora_rank = i_shape[2]
+        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
+        assert len(lw_shape) == 3
+        assert lw_shape[2] == lora_rank
+        hidden_size = lw_shape[1]
+        # Expected output shape : [num_tokens, hidden_size * num_slices]
+        assert len(o_shape) == 2
+        assert o_shape == (num_tokens, hidden_size * num_slices)
+
+        return {
+            'inputs': self.input,
+            'lora_b_weights': self.lora_weights_lst,
+            'output_tensor': self.output,
+            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_indices_sorted_by_lora_ids':
+            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            'offset_start': 0,
+            'add_inputs': add_inputs,
+        }
+
     def bench_fn_kwargs(self,
                         op_type: OpType,
                         add_inputs: Optional[bool] = None) -> dict[str, Any]:
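Putting the pieces together, the benchmark ends up invoking the new kernels by unpacking these dicts. The keyword names below simply mirror the dicts above, and bench_tensors stands in for a hypothetical BenchmarkTensors instance (a sketch, not code from the commit):

    from vllm.lora.ops.triton_ops.v1 import v1_expand, v1_shrink

    # shrink: per slice s, output[s] is roughly scaling * (inputs @ lora_a_weights[s]^T)
    v1_shrink(**bench_tensors.as_v1_shrink_kwargs())

    # expand: slice outputs are written side by side along the hidden dimension;
    # add_inputs controls whether results are accumulated into output_tensor
    v1_expand(**bench_tensors.as_v1_expand_kwargs(add_inputs=True))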
@@ -685,6 +801,10 @@ def bench_fn_kwargs(self,
             return self.as_bgmv_expand_kwargs(add_inputs)
         if op_type == OpType.BGMV_EXPAND_SLICE:
             return self.as_bgmv_expand_slice_kwargs(add_inputs)
+        if op_type == OpType.V1_SHRINK:
+            return self.as_v1_shrink_kwargs()
+        if op_type == OpType.V1_EXPAND:
+            return self.as_v1_expand_kwargs(add_inputs)
         raise ValueError(f"Unrecognized optype {self}")
 
     def test_correctness(self, op_type: OpType,
@@ -872,12 +992,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
     timers = []
     for bench_ctx in bench_ctxs:
         for seq_len in args.seq_lengths:
-            bench_ops: list[OpType] = []
-            if seq_len == 1:
-                # bench all decode ops
-                bench_ops = [op for op in args.op_types if op.is_decode_op()]
-            else:
-                # bench all prefill ops
+            bench_ops: list[OpType] = args.op_types
+            if seq_len > 1:
+                # bench only prefill ops
                 bench_ops = [op for op in args.op_types if op.is_prefill_op()]
 
             seq_len_timers = []
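The effect of the revised selection logic: for seq_len == 1 every requested op type is benchmarked (the V1 ops now count as both prefill and decode ops), while for longer sequences only prefill-capable ops are kept. A small illustration with a hypothetical op list:

    op_types = [OpType.V1_SHRINK, OpType.V1_EXPAND, OpType.BGMV_SHRINK]
    decode_bench_ops = op_types                                        # seq_len == 1
    prefill_bench_ops = [op for op in op_types if op.is_prefill_op()]  # seq_len > 1
    # -> [OpType.V1_SHRINK, OpType.V1_EXPAND]; BGMV_SHRINK is decode-only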
