163 changes: 163 additions & 0 deletions tests/entrypoints/test_renderer.py
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.entrypoints.renderer import CompletionRenderer


@dataclass
class MockModelConfig:
    max_model_len: int = 100
    encoder_config: Optional[dict] = None


class MockTokenizerResult:

    def __init__(self, input_ids):
        self.input_ids = input_ids


@pytest.fixture
def mock_model_config():
    return MockModelConfig()


@pytest.fixture
def mock_tokenizer():
    tokenizer = MagicMock()
    return tokenizer


@pytest.fixture
def mock_async_tokenizer():
    async_tokenizer = AsyncMock()
    return async_tokenizer


@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
    return CompletionRenderer(model_config=mock_model_config,
                              tokenizer=mock_tokenizer,
                              async_tokenizer_pool={})


class TestRenderPrompt:
    """Test Category A: Basic Functionality Tests"""

    @pytest.mark.asyncio
    async def test_token_input(self, renderer):
        tokens = [101, 7592, 2088]
        results = await renderer.render_prompt(prompt_or_prompts=tokens,
                                               max_length=100)

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    @pytest.mark.asyncio
    async def test_token_list_input(self, renderer):
        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        results = await renderer.render_prompt(prompt_or_prompts=token_lists,
                                               max_length=100)

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    @pytest.mark.asyncio
    async def test_text_input(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
                                               max_length=100)

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        mock_async_tokenizer.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_list_input(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        text_list_input = ["Hello world", "How are you?", "Good morning"]
        results = await renderer.render_prompt(
            prompt_or_prompts=text_list_input, max_length=100)

        assert len(results) == 3
        for result in results:
            assert result["prompt_token_ids"] == [101, 7592, 2088]
        assert mock_async_tokenizer.call_count == 3

    @pytest.mark.asyncio
    async def test_no_truncation(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
                                               max_length=100)

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert "truncation" not in call_args.kwargs or call_args.kwargs[
            "truncation"] is False

    @pytest.mark.asyncio
    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088])  # Truncated
        renderer.async_tokenizer_pool[
            renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
                                               max_length=100,
                                               truncate_prompt_tokens=50)

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 50

    @pytest.mark.asyncio
    async def test_token_truncation_last_elements(self, renderer):
        # Test that token truncation keeps the last N elements
        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
                       109]  # 10 tokens
        results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
                                               max_length=100,
                                               truncate_prompt_tokens=5)

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    @pytest.mark.asyncio
    async def test_max_length_exceeded(self, renderer):
        long_tokens = list(range(150))  # Exceeds max_model_len=100

        with pytest.raises(ValueError, match="maximum context length"):
            await renderer.render_prompt(prompt_or_prompts=long_tokens,
                                         max_length=100)

    @pytest.mark.asyncio
    async def test_no_tokenizer_for_text(self, mock_model_config):
        renderer_no_tokenizer = CompletionRenderer(
            model_config=mock_model_config,
            tokenizer=None,
            async_tokenizer_pool={})

        with pytest.raises(ValueError, match="No tokenizer available"):
            await renderer_no_tokenizer.render_prompt(
                prompt_or_prompts="Hello world", max_length=100)
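
Taken together, the tests above exercise the renderer API this PR introduces: build a CompletionRenderer from a model config, a tokenizer, and a shared async-tokenizer pool, then call render_prompt to turn text or token IDs into engine prompts. Below is a minimal sketch of that flow outside the test harness, assuming the serving layer already holds model_config and tokenizer objects; the parameter values are illustrative only.

# Hypothetical usage sketch of the renderer API exercised by the tests above.
# `model_config` and `tokenizer` are assumed to be supplied by the caller.
from vllm.entrypoints.renderer import CompletionRenderer


async def render_example(model_config, tokenizer):
    renderer = CompletionRenderer(model_config=model_config,
                                  tokenizer=tokenizer,
                                  async_tokenizer_pool={})

    # Text input is tokenized asynchronously; lists of token IDs are passed
    # through (and truncated to the last N tokens when requested).
    engine_prompts = await renderer.render_prompt(
        prompt_or_prompts="Hello world",
        max_length=model_config.max_model_len,
        truncate_prompt_tokens=50,
        add_special_tokens=True)

    return [prompt["prompt_token_ids"] for prompt in engine_prompts]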
20 changes: 15 additions & 5 deletions vllm/entrypoints/openai/serving_engine.py
@@ -62,8 +62,10 @@
TranslationRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
# yapf: enable
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
@@ -243,6 +245,16 @@ def __init__(
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack

def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool)

def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
@@ -1098,7 +1110,7 @@ def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
def _log_inputs(
self,
request_id: str,
inputs: RequestPrompt,
inputs: Union[RequestPrompt, PromptType],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
@@ -1110,11 +1122,9 @@ def _log_inputs(
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
elif "prompt_embeds" in inputs:
prompt_embeds = inputs.get("prompt_embeds")
else:
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
prompt = getattr(inputs, 'prompt', None)
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

self.request_logger.log_inputs(
request_id,
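
The _get_renderer helper added here is intended to be called from the individual serving handlers, as the serving_pooling.py and serving_tokenization.py changes below show. A minimal sketch of that handler-side pattern, assuming a serving object with the attributes used in those diffs (the function name and request fields are placeholders):

# Hypothetical helper showing the handler-side renderer pattern; `serving` is
# assumed to be an OpenAIServing instance and `request` a completion-style
# request with `input` and `add_special_tokens` fields.
async def render_and_log(serving, request, tokenizer, request_id,
                         lora_request):
    renderer = serving._get_renderer(tokenizer)
    engine_prompts = await renderer.render_prompt(
        prompt_or_prompts=request.input,
        max_length=serving.max_model_len,
        add_special_tokens=request.add_special_tokens,
        cache_salt=getattr(request, 'cache_salt', None),
    )
    for i, engine_prompt in enumerate(engine_prompts):
        # Each engine prompt can be logged directly.
        serving._log_inputs(f"{request_id}-{i}",
                            engine_prompt,
                            params=None,
                            lora_request=lora_request)
    return engine_prompts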
26 changes: 12 additions & 14 deletions vllm/entrypoints/openai/serving_pooling.py
@@ -4,7 +4,7 @@
import asyncio
import base64
import time
from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator
from typing import Final, Literal, Optional, Union, cast

import jinja2
@@ -26,7 +26,7 @@
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
@@ -104,6 +104,7 @@ async def create_pooling(
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
renderer = self._get_renderer(tokenizer)

if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
@@ -126,14 +127,11 @@

engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)

elif isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
@@ -149,13 +147,13 @@
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
add_special_tokens=request.add_special_tokens,
)
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
max_length=self.max_model_len,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=getattr(request, 'cache_salt', None),
)
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
@@ -177,7 +175,7 @@
request_id_item = f"{request_id}-{i}"

self._log_inputs(request_id_item,
request_prompts[i],
engine_prompt,
params=pooling_params,
lora_request=lora_request)

15 changes: 7 additions & 8 deletions vllm/entrypoints/openai/serving_tokenization.py
@@ -65,6 +65,7 @@ async def create_tokenize(
lora_request = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
renderer = self._get_renderer(tokenizer)

if isinstance(request, TokenizeChatRequest):
tool_dicts = (None if request.tools is None else
@@ -87,21 +88,19 @@
add_special_tokens=request.add_special_tokens,
)
else:
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
add_special_tokens=request.add_special_tokens,
cache_salt=getattr(request, 'cache_salt', None),
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")

input_ids: list[int] = []
for i, engine_prompt in enumerate(engine_prompts):
self._log_inputs(request_id,
request_prompts[i],
engine_prompt,
params=None,
lora_request=lora_request)
