[TPU] Add support for online w8a8 quantization #22425
Merged
tests/quantization/test_tpu_int8.py (new file, 73 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.

Run `pytest tests/quantization/test_tpu_int8.py`.
"""
import pytest

from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import (
    TPUInt8LinearMethod)
from vllm.platforms import current_platform

from ...models.registry import HF_EXAMPLE_MODELS

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            'quantization_config': {
                'quant_method': 'tpu_int8',
                'activation_scheme': 'dynamic'
            }
        }
    ])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
                        hf_overrides: dict, monkeypatch) -> None:
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_transformers_version(on_fail="skip")

    activation_scheme = hf_overrides.get('quantization_config',
                                         {}).get('activation_scheme')
    quantize_activation = activation_scheme == 'dynamic'

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        "or, being injured, not kill, except in",
        "without the heart, one can only see wrongly.",
        "but in rising every time we fall. - Nelson"
    ]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        def check_model(model):
            for name, module in model.named_modules():
                if not isinstance(module, LinearBase):
                    continue
                quant_method = module.quant_method
                assert isinstance(quant_method, TPUInt8LinearMethod)
                assert quant_method.quantize_activation == quantize_activation

        vllm.apply_model(check_model)
        outputs = vllm.generate_greedy(prompts, max_tokens)
        for (_, output), answer in zip(outputs, answers):
            assert answer in output
```
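For context, here is a minimal sketch of enabling the same online w8a8 quantization outside the test harness. The `LLM` constructor and its `hf_overrides` argument are standard vLLM API, and the `quantization_config` values mirror the test above; treat the rest as illustrative rather than as the PR's documented usage:

```python
# Sketch: turning on online w8a8 (tpu_int8) quantization directly.
# The quantization_config keys come from the test above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    dtype="bfloat16",
    hf_overrides={
        "quantization_config": {
            "quant_method": "tpu_int8",
            # 'dynamic' quantizes activations at runtime (w8a8);
            # omit it for weight-only int8.
            "activation_scheme": "dynamic",
        }
    },
)
outputs = llm.generate(
    ["A robot may not injure a human being"],
    SamplingParams(temperature=0.0, max_tokens=10),
)
```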
Review comment:

The current `try...except` block only catches an `ImportError` if `torch_xla` or the `custom_kernel` module is missing. However, it doesn't handle the case where `torch_xla` is installed but is an older version that doesn't have the `quantized_matmul_int8` operator. In that scenario, an `AttributeError` will be raised later when the operator is called, which can be confusing for the user.

To provide a better user experience and a more informative error message, it's best to check for the operator's existence within the `try` block and catch both `ImportError` and `AttributeError`. This ensures that users with an incompatible `torch_xla` version get a clear message on how to resolve the issue.
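A minimal sketch of the suggested guard. The import path `torch_xla.experimental.custom_kernel` and the error wording are assumptions based on the comment, not the exact PR code:

```python
# Hedged sketch of the reviewer's suggestion; module path is assumed.
try:
    from torch_xla.experimental import custom_kernel

    # Touch the operator inside the try block: on an older torch_xla
    # this raises AttributeError here rather than failing at call time.
    quantized_matmul_int8 = custom_kernel.quantized_matmul_int8
except (ImportError, AttributeError) as err:
    raise ImportError(
        "TPU int8 quantization requires a torch_xla build that provides "
        "the quantized_matmul_int8 operator; please upgrade torch_xla."
    ) from err
```

Re-raising as `ImportError` with `from err` keeps the original traceback while giving the user a single actionable message for both failure modes.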