2 changes: 2 additions & 0 deletions .buildkite/scripts/hardware_ci/run-tpu-v1-test-part2.sh
@@ -139,6 +139,8 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
run_and_track_test 7 "test_tpu_int8.py" \
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"

# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
73 changes: 73 additions & 0 deletions tests/v1/tpu/test_tpu_int8.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether TPU Int8 computation is enabled correctly.

Run `pytest tests/v1/tpu/test_tpu_int8.py`.
"""
import pytest

from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.tpu_int8 import (
    TPUInt8LinearMethod)
from vllm.platforms import current_platform

from ...models.registry import HF_EXAMPLE_MODELS

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="TPU Int8 is only enabled for TPUs.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize(
    "hf_overrides",
    [
        # w8a8 dynamic activation
        {
            'quantization_config': {
                'quant_method': 'tpu_int8',
                'activation_scheme': 'dynamic'
            }
        }
    ])
def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
                        hf_overrides: dict, monkeypatch) -> None:
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_transformers_version(on_fail="skip")

    activation_scheme = hf_overrides.get('quantization_config',
                                         {}).get('activation_scheme')
    quantize_activation = activation_scheme == 'dynamic'

    # Allows using apply_model
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    # Prevent error from re-initializing cache
    monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "")

    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        "or, being injured, not kill, except in",
        "without the heart, one can only see wrongly.",
        "but in rising every time we fall. - Nelson"
    ]

    with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm:

        def check_model(model):
            for name, module in model.named_modules():
                if not isinstance(module, LinearBase):
                    continue
                quant_method = module.quant_method
                assert isinstance(quant_method, TPUInt8LinearMethod)
                assert quant_method.quantize_activation == quantize_activation

        vllm.apply_model(check_model)
        outputs = vllm.generate_greedy(prompts, max_tokens)
        for (_, output), answer in zip(outputs, answers):
            assert answer in output
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/quantization/tpu_int8.py
@@ -13,7 +13,7 @@
    QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter

-ACTIVATION_SCHEMES = ["none"]
+ACTIVATION_SCHEMES = ["none", "dynamic"]


class Int8TpuConfig(QuantizationConfig):
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: Int8TpuConfig):
        self.quant_config = quant_config
+        self.quantize_activation = False
+        if self.quant_config.activation_scheme == 'dynamic':
+            self.quantize_activation = True

    def create_weights(self, layer: Module, input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
@@ -107,15 +110,16 @@ def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        try:
-            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
+            import torch_xla.experimental.custom_kernel  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "Please install torch_xla by following the instructions at "
                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                "to run vLLM on TPU.") from err
Comment on lines 112 to 118 (Contributor, severity: high):

The current try...except block only catches an ImportError if torch_xla or the custom_kernel module is missing. However, it doesn't handle the case where torch_xla is installed but is an older version that doesn't have the quantized_matmul_int8 operator. In that scenario, an AttributeError will be raised later when the operator is called, which can be confusing for the user.

To provide a better user experience and a more informative error message, it's best to check for the operator's existence within the try block and catch both ImportError and AttributeError. This ensures that users with an incompatible torch_xla version get a clear message on how to resolve the issue.

        try:
            import torch_xla.experimental.custom_kernel  # noqa: F401
            # Eagerly check for the op to provide a better error message.
            _ = torch.ops.xla.quantized_matmul_int8
        except (ImportError, AttributeError) as err:
            raise ImportError(
                "torch_xla is not installed or is too old to support w8a8 "
                "quantization. Please install/update torch_xla by following "
                "the instructions at "
                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                "to run vLLM on TPU.") from err

        weight = layer.weight
        scale = layer.scale
-        out = torch.ops.xla.quantized_matmul(x, weight, scale)
+        out = torch.ops.xla.quantized_matmul_int8(
+            x, weight, scale, quantize_activation=self.quantize_activation)
        if bias is not None:
            out = out + bias
        return out
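
For context beyond the diff: the test above drives the new w8a8 dynamic path through the vllm_runner fixture. The snippet below is a minimal sketch of how the same configuration could be enabled from user code via hf_overrides, assuming the same plumbing the test exercises; the model name, prompt, and sampling settings are illustrative only, and it requires a TPU host with a torch_xla build that provides quantized_matmul_int8.

# Minimal sketch (not part of this PR): enable TPU int8 weights with dynamic
# int8 activation quantization via hf_overrides, mirroring the test above.
# Model name, prompt, and sampling settings are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    dtype="bfloat16",
    hf_overrides={
        "quantization_config": {
            "quant_method": "tpu_int8",
            "activation_scheme": "dynamic",
        }
    },
)
outputs = llm.generate(["A robot may not injure a human being"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)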