@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
133
133
use_tensor_cores = (
134
134
(num_query_heads // num_kv_heads ) > 4 )
135
135
)
136
- wrapper .begin_forward (kv_indptr ,
137
- kv_indices ,
138
- kv_last_page_lens ,
139
- num_query_heads ,
140
- num_kv_heads ,
141
- head_size ,
142
- block_size ,
143
- "NONE" ,
144
- data_type = dtype )
145
-
146
- output = wrapper .forward (query , key_value_cache , logits_soft_cap = soft_cap )
136
+ wrapper .plan (kv_indptr ,
137
+ kv_indices ,
138
+ kv_last_page_lens ,
139
+ num_query_heads ,
140
+ num_kv_heads ,
141
+ head_size ,
142
+ block_size ,
143
+ "NONE" ,
144
+ q_data_type = dtype ,
145
+ kv_data_type = dtype ,
146
+ logits_soft_cap = soft_cap )
147
+
148
+ output = wrapper .run (query , key_value_cache )
147
149
148
150
ref_output = ref_paged_attn (query = query ,
149
151
key_cache = key_cache ,
@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
228
230
workspace_buffer = torch .empty (128 * 1024 * 1024 , dtype = torch .int8 )
229
231
wrapper = flashinfer .BatchPrefillWithPagedKVCacheWrapper (
230
232
workspace_buffer , "NHD" )
231
- wrapper .begin_forward (
233
+ wrapper .plan (
232
234
qo_indptr ,
233
235
kv_indptr ,
234
236
kv_indices ,
@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
237
239
num_kv_heads ,
238
240
head_size ,
239
241
block_size ,
242
+ q_data_type = dtype ,
243
+ kv_data_type = dtype ,
244
+ logits_soft_cap = soft_cap ,
240
245
)
241
246
242
- output = wrapper .forward (
247
+ output = wrapper .run (
243
248
query ,
244
249
key_value_cache ,
245
- logits_soft_cap = soft_cap ,
246
250
)
247
251
248
252
ref_output = ref_paged_attn (query = query ,
@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
253
257
block_tables = block_tables ,
254
258
scale = scale ,
255
259
soft_cap = soft_cap )
256
- torch .testing .assert_close (output , ref_output , atol = 1e -2 , rtol = 1e-2 ), \
260
+ torch .testing .assert_close (output , ref_output , atol = 5e -2 , rtol = 1e-2 ), \
257
261
f"{ torch .max (torch .abs (output - ref_output ))} "
258
262
259
263
@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
332
336
workspace_buffer = torch .empty (128 * 1024 * 1024 , dtype = torch .int8 )
333
337
wrapper = flashinfer .BatchPrefillWithPagedKVCacheWrapper (
334
338
workspace_buffer , "NHD" )
335
- wrapper .begin_forward (
339
+ wrapper .plan (
336
340
qo_indptr ,
337
341
kv_indptr ,
338
342
kv_indices ,
@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
341
345
num_kv_heads ,
342
346
head_size ,
343
347
block_size ,
348
+ q_data_type = dtype ,
349
+ kv_data_type = kv_cache_dtype ,
350
+ logits_soft_cap = soft_cap ,
344
351
)
345
352
346
- output = wrapper .forward (query ,
347
- kv_cache_fp8 ,
348
- logits_soft_cap = soft_cap ,
349
- k_scale = k_scale ,
350
- v_scale = v_scale )
353
+ output = wrapper .run (query , kv_cache_fp8 , k_scale = k_scale , v_scale = v_scale )
351
354
352
355
ref_output = ref_paged_attn (query = query ,
353
356
key_cache = key_cache .squeeze (1 ),
@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
360
363
del query
361
364
del block_tables
362
365
# verify prefill fp8
363
- torch .testing .assert_close (output , ref_output , atol = 1e -2 , rtol = 1e-2 ), \
366
+ torch .testing .assert_close (output , ref_output , atol = 5e -2 , rtol = 1e-2 ), \
364
367
f"{ torch .max (torch .abs (output - ref_output ))} "
365
368
366
369
@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
439
442
wrapper = flashinfer .\
440
443
BatchDecodeWithPagedKVCacheWrapper (workspace_buffer , "NHD" ,
441
444
use_tensor_cores = use_tensor_cores )
442
- wrapper .begin_forward (kv_indptr ,
443
- kv_indices ,
444
- kv_last_page_lens ,
445
- num_query_heads ,
446
- num_kv_heads ,
447
- head_size ,
448
- block_size ,
449
- "NONE" ,
450
- data_type = dtype ,
451
- q_data_type = dtype )
452
- output = wrapper .forward (query ,
453
- kv_cache_fp8 ,
454
- logits_soft_cap = soft_cap ,
455
- k_scale = k_scale ,
456
- v_scale = v_scale )
445
+ wrapper .plan (kv_indptr ,
446
+ kv_indices ,
447
+ kv_last_page_lens ,
448
+ num_query_heads ,
449
+ num_kv_heads ,
450
+ head_size ,
451
+ block_size ,
452
+ "NONE" ,
453
+ q_data_type = dtype ,
454
+ kv_data_type = kv_cache_dtype ,
455
+ logits_soft_cap = soft_cap )
456
+ output = wrapper .run (query , kv_cache_fp8 , k_scale = k_scale , v_scale = v_scale )
457
457
key_cache = key_value_cache [:, 0 , :, :, :].squeeze (1 )
458
458
value_cache = key_value_cache [:, 1 , :, :, :].squeeze (1 )
459
459
0 commit comments