
Commit 2bc3fbb

abmfy and youkaichao authored
[FlashInfer] Upgrade to 0.2.0 (#11194)
Signed-off-by: Bowen Wang <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
1 parent 3f1fc74 commit 2bc3fbb


10 files changed: +257 additions, -75 deletions


.buildkite/test-pipeline.yaml

Lines changed: 10 additions & 1 deletion
@@ -183,7 +183,16 @@ steps:
   - vllm/
   - tests/v1
   commands:
-    - VLLM_USE_V1=1 pytest -v -s v1
+    # split the test to avoid interference
+    - VLLM_USE_V1=1 pytest -v -s v1/core
+    - VLLM_USE_V1=1 pytest -v -s v1/engine
+    - VLLM_USE_V1=1 pytest -v -s v1/sample
+    - VLLM_USE_V1=1 pytest -v -s v1/worker
+    - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
+    - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
+    # TODO: accuracy does not match, whether setting
+    # VLLM_USE_FLASHINFER_SAMPLER or not on H100.
+    - VLLM_USE_V1=1 pytest -v -s v1/e2e
 
 - label: Examples Test # 25min
   working_dir: "/vllm-workspace/examples"

Dockerfile

Lines changed: 21 additions & 2 deletions
@@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
-FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
+# TODO: Restore to base image after FlashInfer AOT wheel fixed
+FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
 ARG CUDA_VERSION=12.4.1
 ARG PYTHON_VERSION=3.12
 WORKDIR /vllm-workspace
@@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install dist/*.whl --verbose
 
+# How to build this FlashInfer wheel:
+# $ export FLASHINFER_ENABLE_AOT=1
+# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
+# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
+# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
+# $ cd flashinfer
+# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
+# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
+
 RUN --mount=type=cache,target=/root/.cache/pip \
     . /etc/environment && \
     if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
-        python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
+        python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
     fi
 COPY examples examples
+
+# Although we build Flashinfer with AOT mode, there's still
+# some issues w.r.t. JIT compilation. Therefore we need to
+# install build dependencies for JIT compilation.
+# TODO: Remove this once FlashInfer AOT wheel is fixed
+COPY requirements-build.txt requirements-build.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    python3 -m pip install -r requirements-build.txt
+
 #################### vLLM installation IMAGE ####################
 
 #################### TEST IMAGE ####################
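The pinned wheel above is built from FlashInfer commit 524304395bd1d8cd7d07db083859523fcaa246a4 and served from wheels.vllm.ai, while requirements-build.txt stays in the image so JIT fallbacks can still compile. A quick sanity check one might run inside the built image is sketched below; it is not part of the commit, and the distribution name flashinfer-python is inferred from the wheel filename, so treat it as an assumption.

# sanity_check_flashinfer.py -- illustrative only, not part of this commit.
# Confirms the pinned FlashInfer wheel imports and reports its version.
from importlib.metadata import PackageNotFoundError, version

import flashinfer  # raises ImportError if the wheel is missing or broken

try:
    # Distribution name inferred from the wheel filename (flashinfer_python-*.whl).
    print("flashinfer-python", version("flashinfer-python"))
except PackageNotFoundError:
    # Fall back to the module attribute if the distribution is named differently.
    print("flashinfer", getattr(flashinfer, "__version__", "unknown"))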

tests/basic_correctness/test_basic_correctness.py

Lines changed: 3 additions & 2 deletions
@@ -61,9 +61,10 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
-    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+    if backend in ("XFORMERS",
+                   "FLASHINFER") and model == "google/gemma-2-2b-it":
         pytest.skip(
-            "XFORMERS does not support gemma2 with full context length.")
+            f"{backend} does not support gemma2 with full context length.")
 
     os.environ["VLLM_ATTENTION_BACKEND"] = backend

tests/compile/test_basic_correctness.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ class TestSetting:
             model_args=["--task", "embed"],
             pp_size=1,
             tp_size=1,
-            attn_backend="FLASHINFER",
+            attn_backend="FLASH_ATTN",
             method="encode",
             fullgraph=True,
         ),

tests/kernels/test_flashinfer.py

Lines changed: 37 additions & 37 deletions
@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
         use_tensor_cores=(
             (num_query_heads//num_kv_heads) > 4)
     )
-    wrapper.begin_forward(kv_indptr,
-                          kv_indices,
-                          kv_last_page_lens,
-                          num_query_heads,
-                          num_kv_heads,
-                          head_size,
-                          block_size,
-                          "NONE",
-                          data_type=dtype)
-
-    output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
+    wrapper.plan(kv_indptr,
+                 kv_indices,
+                 kv_last_page_lens,
+                 num_query_heads,
+                 num_kv_heads,
+                 head_size,
+                 block_size,
+                 "NONE",
+                 q_data_type=dtype,
+                 kv_data_type=dtype,
+                 logits_soft_cap=soft_cap)
+
+    output = wrapper.run(query, key_value_cache)
 
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, "NHD")
-    wrapper.begin_forward(
+    wrapper.plan(
         qo_indptr,
         kv_indptr,
         kv_indices,
@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
         num_kv_heads,
         head_size,
         block_size,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        logits_soft_cap=soft_cap,
     )
 
-    output = wrapper.forward(
+    output = wrapper.run(
         query,
         key_value_cache,
-        logits_soft_cap=soft_cap,
     )
 
     ref_output = ref_paged_attn(query=query,
@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                 block_tables=block_tables,
                                 scale=scale,
                                 soft_cap=soft_cap)
-    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
 
@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         workspace_buffer, "NHD")
-    wrapper.begin_forward(
+    wrapper.plan(
         qo_indptr,
         kv_indptr,
         kv_indices,
@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
         num_kv_heads,
         head_size,
         block_size,
+        q_data_type=dtype,
+        kv_data_type=kv_cache_dtype,
+        logits_soft_cap=soft_cap,
     )
 
-    output = wrapper.forward(query,
-                             kv_cache_fp8,
-                             logits_soft_cap=soft_cap,
-                             k_scale=k_scale,
-                             v_scale=v_scale)
+    output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
 
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache.squeeze(1),
@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
     del query
     del block_tables
     # verify prefill fp8
-    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
 
 
@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
                                            use_tensor_cores=use_tensor_cores)
-    wrapper.begin_forward(kv_indptr,
-                          kv_indices,
-                          kv_last_page_lens,
-                          num_query_heads,
-                          num_kv_heads,
-                          head_size,
-                          block_size,
-                          "NONE",
-                          data_type=dtype,
-                          q_data_type=dtype)
-    output = wrapper.forward(query,
-                             kv_cache_fp8,
-                             logits_soft_cap=soft_cap,
-                             k_scale=k_scale,
-                             v_scale=v_scale)
+    wrapper.plan(kv_indptr,
+                 kv_indices,
+                 kv_last_page_lens,
+                 num_query_heads,
+                 num_kv_heads,
+                 head_size,
+                 block_size,
+                 "NONE",
+                 q_data_type=dtype,
+                 kv_data_type=kv_cache_dtype,
+                 logits_soft_cap=soft_cap)
+    output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
     key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
     value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
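Taken together, the test changes show the FlashInfer 0.2.0 migration pattern: plan() replaces begin_forward(), run() replaces forward(), and q_data_type/kv_data_type plus logits_soft_cap move from the forward call to plan time. A minimal decode helper mirroring the updated test is sketched below; shapes, dtypes, and the 128 MB workspace follow the test fixtures above, and the helper itself is illustrative rather than part of the commit.

# flashinfer_decode_sketch.py -- illustrative only, mirrors the updated test.
import torch
import flashinfer


def flashinfer_decode(query, key_value_cache, kv_indptr, kv_indices,
                      kv_last_page_lens, num_query_heads, num_kv_heads,
                      head_size, block_size, dtype, soft_cap=None):
    # 128 MB workspace buffer, as used by the tests above.
    workspace_buffer = torch.empty(128 * 1024 * 1024,
                                   dtype=torch.int8,
                                   device=query.device)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD",
        use_tensor_cores=(num_query_heads // num_kv_heads) > 4)
    # Metadata that 0.1.6 passed to begin_forward()/forward() now goes to plan().
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 q_data_type=dtype,
                 kv_data_type=dtype,
                 logits_soft_cap=soft_cap)
    # run() only takes the tensors; everything else was captured by plan().
    return wrapper.run(query, key_value_cache)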
