Merged
Changes from all commits (55 commits)
269f965
[misc] remove deprecated call to `end_forward` in flashinfer backend
abmfy Dec 14, 2024
8c375a3
[flashinfer] upgrade to flashinfer 0.2.0
abmfy Dec 20, 2024
a62b854
[style] fix yapf check
abmfy Dec 20, 2024
b37ff55
[FlashInfer] Pass inferred global hyperparameters to `plan`
abmfy Dec 31, 2024
72bdf7e
[FlashInfer] Cache inferred global hyperparameters
abmfy Dec 31, 2024
97dcedc
[Misc] Use `typing.Optional` for Python 3.9 compatibility
abmfy Dec 31, 2024
56798c5
[Style] Fix lint errors
abmfy Dec 31, 2024
706a6f6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 22, 2025
dacb6af
[FlashInfer] Cache global hyperparameters in AttentionMetadataBuilder…
abmfy Jan 22, 2025
06fa7cc
[Style] Fix ruff
abmfy Jan 22, 2025
bc480b0
[FlashInfer] Get per layer params from vllm config
abmfy Jan 23, 2025
5a70aac
[FlashInfer] Store vllm config in attention state
abmfy Jan 23, 2025
e0397e9
[CI] Update FlashInfer version
abmfy Jan 23, 2025
ec49257
format
youkaichao Jan 23, 2025
500ff5b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 24, 2025
bde6807
[Misc] Add space in assert message
abmfy Jan 24, 2025
69d7c8d
[FlashInfer] Warn on models with interleaved attention
abmfy Jan 24, 2025
d4d63dc
[Test] Change backend to flash_attn for gemma in compile tests
abmfy Jan 24, 2025
6e7e933
fix inconsistent vllm config
youkaichao Jan 25, 2025
0b47067
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 25, 2025
f6e33a7
[Test] Skip tests for Gemma2 with FlashInfer backend
abmfy Jan 25, 2025
847a4d6
[CI] Build FlashInfer from source
abmfy Jan 25, 2025
5b0fe64
[CI] Fix FlashInfer build command
abmfy Jan 25, 2025
69445cd
[CI] Fix Dockerfile
abmfy Jan 25, 2025
963aff7
[CI] Fix FlashInfer AOT build in Dockerfile
abmfy Jan 25, 2025
ae9da66
fix flashinfer docker build
youkaichao Jan 26, 2025
afa377c
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 26, 2025
269e1eb
fix build command
youkaichao Jan 26, 2025
2e50ab8
move command
youkaichao Jan 26, 2025
0fe979d
unify to use setup.py
youkaichao Jan 26, 2025
3dd209c
fix cd
youkaichao Jan 26, 2025
bcd04fd
fix recursive clone
youkaichao Jan 26, 2025
bb44221
comment
youkaichao Jan 26, 2025
5ca67ae
[CI] Use precompiled FlashInfer AOT wheel
abmfy Jan 26, 2025
3c89bfb
[CI] Temporarily switch to CUDA develop image for vllm-base
abmfy Jan 26, 2025
293fdd6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
5d8ad22
also install jit build dependency
youkaichao Jan 26, 2025
4d57ef9
[FlashInfer] Fix type of k_scale and v_scale
abmfy Jan 26, 2025
33ff07b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
ef15977
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 26, 2025
21efc67
fix reshape_and_cache_flash
youkaichao Jan 27, 2025
a6b6fe8
use new flashinfer
youkaichao Jan 27, 2025
1f13235
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 27, 2025
f17dbc3
update v1 tests
youkaichao Jan 27, 2025
506b641
refactor test
youkaichao Jan 27, 2025
2e476a2
revert
youkaichao Jan 27, 2025
95b5493
add comments
youkaichao Jan 27, 2025
55b55d3
only check compile when loading
youkaichao Jan 27, 2025
1f80aee
test in ci?
youkaichao Jan 27, 2025
5be3783
fix one test
youkaichao Jan 27, 2025
071a68e
fix test_flashinfer_prefill_with_paged_kv
youkaichao Jan 27, 2025
0e0f57f
relax test for prefill
youkaichao Jan 27, 2025
2134e77
fix test_flashinfer_prefill_with_paged_fp8_kv
youkaichao Jan 27, 2025
8e42297
relax test for prefill
youkaichao Jan 27, 2025
b4a7992
fix test_flashinfer_decode_with_paged_fp8_kv
youkaichao Jan 27, 2025
11 changes: 10 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -183,7 +183,16 @@ steps:
- vllm/
- tests/v1
commands:
- VLLM_USE_V1=1 pytest -v -s v1
# split the test to avoid interference
- VLLM_USE_V1=1 pytest -v -s v1/core
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
- VLLM_USE_V1=1 pytest -v -s v1/e2e

- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"
23 changes: 21 additions & 2 deletions Dockerfile
@@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \

#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
# TODO: Restore to base image after FlashInfer AOT wheel fixed
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
@@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

# How to build this FlashInfer wheel:
# $ export FLASHINFER_ENABLE_AOT=1
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
# $ cd flashinfer
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose

RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
fi
COPY examples examples

# Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt

#################### vLLM installation IMAGE ####################

#################### TEST IMAGE ####################
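The FlashInfer install above pins the prebuilt AOT wheel flashinfer_python-0.2.0.post1 (built from commit 524304395bd1d8cd7d07db083859523fcaa246a4) in place of the old 0.1.6 release wheel. A quick post-build sanity check for the image could look like the following; this is an illustrative sketch, not part of the PR, and it assumes the installed distribution is named flashinfer-python, as the wheel filename suggests.

# Hypothetical check that the image contains the pinned FlashInfer wheel.
# Assumes the distribution name "flashinfer-python" (from the wheel filename);
# not part of this PR.
from importlib.metadata import PackageNotFoundError, version

try:
    installed = version("flashinfer-python")
except PackageNotFoundError:
    raise SystemExit("FlashInfer is not installed in this image")

if not installed.startswith("0.2.0"):
    raise SystemExit(f"expected a FlashInfer 0.2.0.x wheel, found {installed}")
print(f"flashinfer-python {installed} is installed")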
5 changes: 3 additions & 2 deletions tests/basic_correctness/test_basic_correctness.py
@@ -61,9 +61,10 @@ def test_models(
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")

if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
if backend in ("XFORMERS",
"FLASHINFER") and model == "google/gemma-2-2b-it":
pytest.skip(
"XFORMERS does not support gemma2 with full context length.")
f"{backend} does not support gemma2 with full context length.")

os.environ["VLLM_ATTENTION_BACKEND"] = backend

2 changes: 1 addition & 1 deletion tests/compile/test_basic_correctness.py
@@ -58,7 +58,7 @@ class TestSetting:
model_args=["--task", "embed"],
pp_size=1,
tp_size=1,
attn_backend="FLASHINFER",
attn_backend="FLASH_ATTN",
method="encode",
fullgraph=True,
),
74 changes: 37 additions & 37 deletions tests/kernels/test_flashinfer.py
@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4)
)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)

output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)

output = wrapper.run(query, key_value_cache)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
@@ -237,12 +239,14 @@
num_kv_heads,
head_size,
block_size,
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
)

output = wrapper.forward(
output = wrapper.run(
query,
key_value_cache,
logits_soft_cap=soft_cap,
)

ref_output = ref_paged_attn(query=query,
Expand All @@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
@@ -341,13 +345,12 @@
num_kv_heads,
head_size,
block_size,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap,
)

output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1),
@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
del query
del block_tables
# verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

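Taken together, the kernel-test changes above track the FlashInfer 0.2.0 API: begin_forward()/forward() are replaced by plan()/run(), the query and KV data types plus logits_soft_cap move into plan(), and the per-call FP8 scale factors (k_scale, v_scale) remain arguments of run(). The sketch below shows the new decode call pattern, modeled on the updated test_flashinfer_decode_with_paged_kv; the shapes and hyperparameters are illustrative only, and it assumes FlashInfer 0.2.0 on a CUDA device rather than reproducing any code from this PR.

# Illustrative FlashInfer 0.2.0 decode sketch (assumes a CUDA device); the
# sizes below are made up for demonstration and are not taken from the tests.
import torch
import flashinfer

num_seqs, num_query_heads, num_kv_heads = 2, 8, 2
head_size, block_size, blocks_per_seq = 128, 16, 4
dtype, device = torch.float16, "cuda"

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD",
    use_tensor_cores=(num_query_heads // num_kv_heads) > 4)

# Paged-KV metadata: each sequence owns blocks_per_seq pages; last page is full.
kv_indptr = torch.arange(num_seqs + 1, dtype=torch.int32, device=device) * blocks_per_seq
kv_indices = torch.arange(num_seqs * blocks_per_seq, dtype=torch.int32, device=device)
kv_last_page_lens = torch.full((num_seqs,), block_size, dtype=torch.int32, device=device)

# plan() replaces begin_forward(); dtypes and logits_soft_cap are given here.
wrapper.plan(kv_indptr, kv_indices, kv_last_page_lens,
             num_query_heads, num_kv_heads, head_size, block_size,
             "NONE",                # positional pos_encoding_mode, as in the tests
             q_data_type=dtype,
             kv_data_type=dtype,
             logits_soft_cap=None)  # no logit soft-capping in this sketch

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype, device=device)
# NHD paged KV cache: [num_pages, 2 (K/V), page_size, num_kv_heads, head_size]
key_value_cache = torch.randn(num_seqs * blocks_per_seq, 2, block_size,
                              num_kv_heads, head_size, dtype=dtype, device=device)

# run() replaces forward(); FP8 k_scale/v_scale would be passed here instead.
output = wrapper.run(query, key_value_cache)
print(output.shape)  # torch.Size([2, 8, 128])

The prefill wrapper follows the same pattern, with qo_indptr passed as the extra first argument to plan(), as in the updated test_flashinfer_prefill_with_paged_kv above.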