6 changes: 3 additions & 3 deletions examples/offline_inference/offline_inference.py
@@ -8,15 +8,15 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=256, max_num_seqs=16, tensor_parallel_size=4)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
42 changes: 42 additions & 0 deletions llava_tpu.py
@@ -0,0 +1,42 @@
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Initialize LLaVA model
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
max_model_len=2048,
max_num_seqs=2,
dtype="bfloat16",
)

# Load sample image
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

# Create two different prompts
prompts = [
"What do you see in this image?",
"What colors are most prominent in this image?",
]

# Format prompts according to LLaVA's requirements
formatted_inputs = [
    {
        "prompt": f"USER: <image>\n{prompt}\nASSISTANT:",
        "multi_modal_data": {"image": image}
    }
    for prompt in prompts
]

# Set up sampling parameters
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=64,
)

# Generate responses
outputs = llm.generate(formatted_inputs, sampling_params=sampling_params)

# Print results
for i, output in enumerate(outputs):
print(f"\nPrompt {i + 1}: {prompts[i]}")
print(f"Response: {output.outputs[0].text}")
1 change: 1 addition & 0 deletions requirements-tpu.txt
@@ -9,6 +9,7 @@ setuptools-scm>=8
wheel
jinja2
ray[default]
ray[adag] # TODO: Remove this

# Install torch_xla
--pre
13 changes: 10 additions & 3 deletions tests/entrypoints/openai/test_accuracy.py
@@ -66,14 +66,21 @@ def run_test(more_args):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="V1 currently only supported on CUDA")
@pytest.mark.skipif(not current_platform.is_cuda()
                    and not current_platform.is_tpu(),
                    reason="V1 currently only supported on CUDA and TPU")
def test_lm_eval_accuracy_v1_engine(monkeypatch):
"""Run with the V1 Engine."""

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
run_test([])
more_args = []

# Limit compilation time for V1
if current_platform.is_tpu():
more_args = ["--max-num-seqs", "64"]

run_test(more_args)


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
7 changes: 7 additions & 0 deletions vllm/attention/selector.py
@@ -163,6 +163,13 @@ def _cached_get_attn_backend(
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.PALLAS_VLLM_V1:
logger.info("Using Pallas backend.")
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend
as
PallasAttentionBackendV1
)
return PallasAttentionBackendV1
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -32,6 +32,7 @@ class _Backend(enum.Enum):
    FLASHINFER = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    PALLAS_VLLM_V1 = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()

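The new PALLAS_VLLM_V1 enum value ties the selector branch above to the TPU platform override in vllm/platforms/tpu.py below. A minimal sketch of how the pieces resolve at runtime, assuming a TPU host with this branch installed; the printed value is only the expected result, not part of the diff:

import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vLLM

from vllm.platforms import current_platform  # resolves to TpuPlatform on TPU
from vllm.platforms.interface import _Backend

# With V1 enabled, the platform overrides any requested backend to
# PALLAS_VLLM_V1, which the selector above maps to the V1 Pallas backend.
backend = current_platform.get_default_attn_backend(_Backend.PALLAS)
print(backend)  # expected: _Backend.PALLAS_VLLM_V1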
56 changes: 45 additions & 11 deletions vllm/platforms/tpu.py
@@ -11,6 +11,8 @@
else:
    VllmConfig = None

import vllm.envs as envs

logger = init_logger(__name__)


@@ -23,9 +25,15 @@ class TpuPlatform(Platform):

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        return _Backend.PALLAS
        if envs.VLLM_USE_V1:
            if selected_backend != _Backend.PALLAS_VLLM_V1:
                logger.info("[V1] Cannot use %s backend on TPU.",
                            selected_backend)
            return _Backend.PALLAS_VLLM_V1
        else:
            if selected_backend != _Backend.PALLAS:
                logger.info("Cannot use %s backend on TPU.", selected_backend)
            return _Backend.PALLAS

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
@@ -37,7 +45,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return True
        return not envs.VLLM_USE_V1

    @classmethod
    def inference_mode(cls):
@@ -52,11 +60,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
            cache_config.block_size = 16

        compilation_config = vllm_config.compilation_config
        if compilation_config.level == CompilationLevel.NO_COMPILATION:
            # TPU does not support NO_COMPILATION

        # TPU only supports DYNAMO_ONCE compilation level
        if (compilation_config.level == CompilationLevel.NO_COMPILATION
                or compilation_config.level == CompilationLevel.PIECEWISE):
            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        assert compilation_config.level < CompilationLevel.PIECEWISE,\
            "TPU does not support Inductor."
            ("Current compilation level = {} but needs to be less"
             " than {}".format(
                 compilation_config.level,
                 CompilationLevel.PIECEWISE))

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"
@@ -67,8 +82,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
            else:
                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
                if scheduler_config.is_multi_step:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
                else:
                    parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"

        # Adjust scheduler config for V1
        # TODO: Add support for these
        if envs.VLLM_USE_V1:
            if vllm_config.cache_config.enable_prefix_caching:
                logger.info("[V1][TPU] Disable prefix caching")
                vllm_config.cache_config.enable_prefix_caching = False

            if vllm_config.scheduler_config.chunked_prefill_enabled:
                logger.info("[V1][TPU] Disable chunked prefill")
                vllm_config.scheduler_config.chunked_prefill_enabled = False

    @classmethod
    def is_pin_memory_available(cls):
        logger.warning("Pin memory is not supported on TPU.")
        return False
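Taken together, these changes route TPU runs with VLLM_USE_V1=1 through the V1 engine: the platform override selects the PALLAS_VLLM_V1 attention backend and the vllm.v1.worker.tpu_worker.TPUWorker, while prefix caching and chunked prefill are disabled for now. A minimal end-to-end sketch, assuming a TPU host with this branch installed; the model name and limits are illustrative and mirror the modified offline example above:

import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vLLM

from vllm import LLM, SamplingParams

# Sizes mirror the offline_inference.py change above; adjust for your TPU slice.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=256, max_num_seqs=16)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(max_tokens=32))
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")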