diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py
new file mode 100644
index 000000000000..54b5271ba67a
--- /dev/null
+++ b/tests/entrypoints/test_renderer.py
@@ -0,0 +1,163 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+from typing import Optional
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from vllm.entrypoints.renderer import CompletionRenderer
+
+
+@dataclass
+class MockModelConfig:
+    max_model_len: int = 100
+    encoder_config: Optional[dict] = None
+
+
+class MockTokenizerResult:
+
+    def __init__(self, input_ids):
+        self.input_ids = input_ids
+
+
+@pytest.fixture
+def mock_model_config():
+    return MockModelConfig()
+
+
+@pytest.fixture
+def mock_tokenizer():
+    tokenizer = MagicMock()
+    return tokenizer
+
+
+@pytest.fixture
+def mock_async_tokenizer():
+    async_tokenizer = AsyncMock()
+    return async_tokenizer
+
+
+@pytest.fixture
+def renderer(mock_model_config, mock_tokenizer):
+    return CompletionRenderer(model_config=mock_model_config,
+                              tokenizer=mock_tokenizer,
+                              async_tokenizer_pool={})
+
+
+class TestRenderPrompt:
+    """Test Category A: Basic Functionality Tests"""
+
+    @pytest.mark.asyncio
+    async def test_token_input(self, renderer):
+        tokens = [101, 7592, 2088]
+        results = await renderer.render_prompt(prompt_or_prompts=tokens,
+                                               max_length=100)
+
+        assert len(results) == 1
+        assert results[0]["prompt_token_ids"] == tokens
+
+    @pytest.mark.asyncio
+    async def test_token_list_input(self, renderer):
+        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
+        results = await renderer.render_prompt(prompt_or_prompts=token_lists,
+                                               max_length=100)
+
+        assert len(results) == 3
+        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
+        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
+        assert results[2]["prompt_token_ids"] == [103, 4567]
+
+    @pytest.mark.asyncio
+    async def test_text_input(self, renderer, mock_async_tokenizer):
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=100)
+
+        assert len(results) == 1
+        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
+        mock_async_tokenizer.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_text_list_input(self, renderer, mock_async_tokenizer):
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        text_list_input = ["Hello world", "How are you?", "Good morning"]
+        results = await renderer.render_prompt(
+            prompt_or_prompts=text_list_input, max_length=100)
+
+        assert len(results) == 3
+        for result in results:
+            assert result["prompt_token_ids"] == [101, 7592, 2088]
+        assert mock_async_tokenizer.call_count == 3
+
+    @pytest.mark.asyncio
+    async def test_no_truncation(self, renderer, mock_async_tokenizer):
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=100)
+
+        assert len(results) == 1
+        call_args = mock_async_tokenizer.call_args
+        assert "truncation" not in call_args.kwargs or call_args.kwargs[
+            "truncation"] is False
+
+    @pytest.mark.asyncio
+    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])  # Truncated
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=100,
+                                               truncate_prompt_tokens=50)
+
+        assert len(results) == 1
+        call_args = mock_async_tokenizer.call_args
+        assert call_args.kwargs["truncation"] is True
+        assert call_args.kwargs["max_length"] == 50
+
+    @pytest.mark.asyncio
+    async def test_token_truncation_last_elements(self, renderer):
+        # Test that token truncation keeps the last N elements
+        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
+                       109]  # 10 tokens
+        results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
+                                               max_length=100,
+                                               truncate_prompt_tokens=5)
+
+        assert len(results) == 1
+        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
+        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
+
+    @pytest.mark.asyncio
+    async def test_max_length_exceeded(self, renderer):
+        long_tokens = list(range(150))  # Exceeds max_model_len=100
+
+        with pytest.raises(ValueError, match="maximum context length"):
+            await renderer.render_prompt(prompt_or_prompts=long_tokens,
+                                         max_length=100)
+
+    @pytest.mark.asyncio
+    async def test_no_tokenizer_for_text(self, mock_model_config):
+        renderer_no_tokenizer = CompletionRenderer(
+            model_config=mock_model_config,
+            tokenizer=None,
+            async_tokenizer_pool={})
+
+        with pytest.raises(ValueError, match="No tokenizer available"):
+            await renderer_no_tokenizer.render_prompt(
+                prompt_or_prompts="Hello world", max_length=100)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index f506f7de1682..a218f6882f8c 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -62,8 +62,10 @@
                                                TranslationRequest)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
+from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
 # yapf: enable
 from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -243,6 +245,16 @@ def __init__(
                                          AsyncMicrobatchTokenizer] = {}
         self.log_error_stack = log_error_stack
 
+    def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
+        """
+        Get a Renderer instance with the provided tokenizer.
+        Uses shared async tokenizer pool for efficiency.
+        """
+        return CompletionRenderer(
+            model_config=self.model_config,
+            tokenizer=tokenizer,
+            async_tokenizer_pool=self._async_tokenizer_pool)
+
     def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
         """
         Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
@@ -1098,7 +1110,7 @@ def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
     def _log_inputs(
         self,
         request_id: str,
-        inputs: RequestPrompt,
+        inputs: Union[RequestPrompt, PromptType],
         params: Optional[Union[SamplingParams, PoolingParams,
                                BeamSearchParams]],
         lora_request: Optional[LoRARequest],
@@ -1110,11 +1122,9 @@ def _log_inputs(
             prompt = inputs
         elif isinstance(inputs, list):
             prompt_token_ids = inputs
-        elif "prompt_embeds" in inputs:
-            prompt_embeds = inputs.get("prompt_embeds")
         else:
-            prompt = inputs["prompt"]
-            prompt_token_ids = inputs["prompt_token_ids"]
+            prompt = getattr(inputs, 'prompt', None)
+            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
 
         self.request_logger.log_inputs(
             request_id,
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 685c98c817c3..c08c0743ffca 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -4,7 +4,7 @@
 import asyncio
 import base64
 import time
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator
 from typing import Final, Literal, Optional, Union, cast
 
 import jinja2
@@ -26,7 +26,7 @@
                                                PoolingRequest, PoolingResponse,
                                                PoolingResponseData, UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
@@ -104,6 +104,7 @@ async def create_pooling(
         else:
             tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                )
+        renderer = self._get_renderer(tokenizer)
 
         if getattr(request, "dimensions", None) is not None:
             return self.create_error_response(
@@ -126,14 +127,11 @@ async def create_pooling(
             engine_prompts = await self.io_processor.pre_process_async(
                 prompt=validated_prompt, request_id=request_id)
 
-            request_prompts: Sequence[RequestPrompt] = [
-                ""
-            ] * len(engine_prompts)
-
         elif isinstance(request, PoolingChatRequest):
             (
                 _,
-                request_prompts,
+                _,
                 engine_prompts,
             ) = await self._preprocess_chat(
                 request,
@@ -149,13 +147,13 @@ async def create_pooling(
                 add_special_tokens=request.add_special_tokens,
             )
         elif isinstance(request, PoolingCompletionRequest):
-            (request_prompts,
-             engine_prompts) = await self._preprocess_completion(
-                 request,
-                 tokenizer,
-                 request.input,
-                 add_special_tokens=request.add_special_tokens,
-             )
+            engine_prompts = await renderer.render_prompt(
+                prompt_or_prompts=request.input,
+                max_length=self.max_model_len,
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+                cache_salt=getattr(request, 'cache_salt', None),
+            )
         else:
             raise ValueError(
                 f"Unsupported request of type {type(request)}")
@@ -177,7 +175,7 @@ async def create_pooling(
             request_id_item = f"{request_id}-{i}"
 
             self._log_inputs(request_id_item,
-                             request_prompts[i],
+                             engine_prompt,
                              params=pooling_params,
                              lora_request=lora_request)
 
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 2f258255d5f1..70cb6c21b221 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -65,6 +65,7 @@ async def create_tokenize(
         lora_request = self._maybe_get_adapters(request)
 
         tokenizer = await self.engine_client.get_tokenizer(lora_request)
+        renderer = self._get_renderer(tokenizer)
 
         if isinstance(request, TokenizeChatRequest):
             tool_dicts = (None if request.tools is None else
@@ -87,13 +88,11 @@ async def create_tokenize(
                 add_special_tokens=request.add_special_tokens,
             )
         else:
-            (request_prompts,
-             engine_prompts) = await self._preprocess_completion(
-                 request,
-                 tokenizer,
-                 request.prompt,
-                 add_special_tokens=request.add_special_tokens,
-             )
+            engine_prompts = await renderer.render_prompt(
+                prompt_or_prompts=request.prompt,
+                add_special_tokens=request.add_special_tokens,
+                cache_salt=getattr(request, 'cache_salt', None),
+            )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(f"{e} {e.__cause__}")
@@ -101,7 +100,7 @@ async def create_tokenize(
         input_ids: list[int] = []
         for i, engine_prompt in enumerate(engine_prompts):
             self._log_inputs(request_id,
-                             request_prompts[i],
+                             engine_prompt,
                              params=None,
                              lora_request=lora_request)
 
diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py
new file mode 100644
index 000000000000..29200dda8998
--- /dev/null
+++ b/vllm/entrypoints/renderer.py
@@ -0,0 +1,219 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import asyncio
+from abc import ABC, abstractmethod
+from typing import Annotated, Optional, Union
+
+from pydantic import Field
+
+from vllm.config import ModelConfig
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.parse import parse_and_batch_prompt
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import AsyncMicrobatchTokenizer
+
+
+class BaseRenderer(ABC):
+    """
+    Base class for unified input processing and rendering.
+
+    The Renderer serves as a unified input processor that consolidates
+    tokenization, chat template formatting, and multimodal input handling
+    into a single component.
+    It converts high-level API requests (OpenAI-style JSON) into token IDs and
+    multimodal features ready for engine consumption.
+
+    Key responsibilities:
+    - Convert text prompts to token sequences with proper special tokens
+    - Apply chat templates and format conversations
+    - Handle multimodal inputs (images, audio, etc.) when applicable
+    - Manage prompt truncation and length validation
+    - Provide clean separation between API layer and engine core
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        tokenizer: Optional[AnyTokenizer] = None,
+    ):
+        super().__init__()
+        self.model_config = model_config
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    async def render_prompt(
+        self,
+        prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
+        max_length: Optional[int] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
+        add_special_tokens: Optional[bool] = True,
+        cache_salt: Optional[str] = None,
+    ) -> list[EngineTokensPrompt]:
+        """
+        Convert input prompts into tokenized format for engine processing.
+
+        This is the core method that transforms various input formats into
+        standardized TokensPrompt objects. Implementations should handle
+        tokenization, special token insertion, truncation, and validation
+        according to model requirements.
+
+        Args:
+            prompt_or_prompts: Input data in various formats:
+                - str: Single text prompt
+                - list[str]: Batch of text prompts
+                - list[int]: Pre-tokenized sequence
+                - list[list[int]]: Batch of pre-tokenized sequences
+            max_length: Maximum sequence length (endpoint-specific behavior)
+            truncate_prompt_tokens: Truncate to last N tokens
+                (None=no truncation, 0=empty)
+            add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
+                to text inputs
+            cache_salt: Optional string to disambiguate cached prompts
+
+        Returns:
+            list[EngineTokensPrompt]: Tokenized prompts ready for engine
+                consumption
+
+        Raises:
+            ValueError: If input format is invalid or length limits exceeded
+        """
+        raise NotImplementedError
+
+
+class CompletionRenderer(BaseRenderer):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        tokenizer: Optional[AnyTokenizer] = None,
+        async_tokenizer_pool: Optional[dict[AnyTokenizer,
+                                            AsyncMicrobatchTokenizer]] = None,
+    ):
+        super().__init__(model_config, tokenizer)
+        self.async_tokenizer_pool = async_tokenizer_pool or {}
+        self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None
+
+    async def render_prompt(
+        self,
+        prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
+        max_length: Optional[int] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
+        add_special_tokens: Optional[bool] = True,
+        cache_salt: Optional[str] = None,
+    ) -> list[EngineTokensPrompt]:
+        """Implementation of prompt rendering for completion-style requests.
+
+        Uses async tokenizer pooling for improved performance. See base class
+        for detailed parameter documentation.
+        """
+        if truncate_prompt_tokens is not None:
+            if max_length is not None:
+                assert 0 <= truncate_prompt_tokens <= max_length
+            if truncate_prompt_tokens == 0:
+                return []
+
+        # Parse and batch the input prompts
+        batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
+
+        rendered_prompts: list[EngineTokensPrompt] = []
+        tokenize_tasks = []
+        for prompt_input in batch_inputs:
+            if prompt_input["is_tokens"] is True:
+                # Token input
+                token_ids = self._maybe_apply_truncation(
+                    prompt_input["content"], truncate_prompt_tokens)
+                rendered_prompts.append(
+                    self._create_tokens_prompt(token_ids, max_length,
+                                               cache_salt))
+            else:
+                # Text input
+                tokenize_task = asyncio.create_task(
+                    self._tokenize(prompt_input["content"], max_length,
+                                   truncate_prompt_tokens, add_special_tokens,
+                                   cache_salt))
+                tokenize_tasks.append(tokenize_task)
+
+        # Wait for all text tokenization to finish
+        if tokenize_tasks:
+            tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
+            rendered_prompts.extend(tokenized_text_prompts)
+
+        return rendered_prompts
+
+    def _maybe_apply_truncation(
+            self, token_ids: list[int],
+            truncate_prompt_tokens: Optional[int]) -> list[int]:
+        """Apply truncation to token sequence."""
+        if truncate_prompt_tokens is None:
+            return token_ids
+        if truncate_prompt_tokens >= len(token_ids):
+            return token_ids
+
+        return token_ids[-truncate_prompt_tokens:]
+
+    async def _tokenize(
+        self,
+        text: str,
+        max_length: Optional[int],
+        truncate_prompt_tokens: Optional[int],
+        add_special_tokens: Optional[bool],
+        cache_salt: Optional[str],
+    ) -> EngineTokensPrompt:
+        """Tokenize text input asynchronously."""
+        async_tokenizer = self._get_async_tokenizer()
+
+        # Handle encoder-specific preprocessing
+        if (self.model_config.encoder_config is not None
+                and self.model_config.encoder_config.get(
+                    "do_lower_case", False)):
+            text = text.lower()
+
+        # Tokenize texts
+        if truncate_prompt_tokens is None:
+            encoded = await async_tokenizer(
+                text, add_special_tokens=add_special_tokens)
+        else:
+            encoded = await async_tokenizer(
+                text,
+                add_special_tokens=add_special_tokens,
+                truncation=True,
+                max_length=truncate_prompt_tokens)
+
+        return self._create_tokens_prompt(encoded.input_ids, max_length,
+                                          cache_salt)
+
+    def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
+        """Get or create async tokenizer using shared pool."""
+        if self.async_tokenizer is not None:
+            return self.async_tokenizer
+        if self.tokenizer is None:
+            raise ValueError(
+                "No tokenizer available for text input processing")
+
+        # Check shared pool first
+        if self.tokenizer in self.async_tokenizer_pool:
+            return self.async_tokenizer_pool[self.tokenizer]
+
+        # Create new async tokenizer and add to pool
+        self.async_tokenizer = AsyncMicrobatchTokenizer(self.tokenizer)
+        self.async_tokenizer_pool[self.tokenizer] = self.async_tokenizer
+        return self.async_tokenizer
+
+    def _create_tokens_prompt(
+        self,
+        token_ids: list[int],
+        max_length: Optional[int] = None,
+        cache_salt: Optional[str] = None,
+    ) -> EngineTokensPrompt:
+        """Create validated EngineTokensPrompt."""
+        if max_length is not None and len(token_ids) > max_length:
+            raise ValueError(
+                f"This model's maximum context length is {max_length} tokens. "
+                f"However, your request has {len(token_ids)} input tokens. "
+                "Please reduce the length of the input messages.")
+
+        tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
+        if cache_salt is not None:
+            tokens_prompt["cache_salt"] = cache_salt
+        return tokens_prompt
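
For reviewers who want to exercise the new interface by hand, the sketch below shows how an endpoint is expected to drive CompletionRenderer after this patch. It mirrors the unit tests above rather than a live server: DemoModelConfig and demo() are illustrative stand-ins and not part of the patch, the tokenizer is a MagicMock because the pre-tokenized path never calls it, and the only assumption is a vLLM checkout with this diff applied.

# Illustrative usage sketch (not part of the patch). Assumes a vLLM checkout
# with this diff applied; DemoModelConfig and demo() are hypothetical names.
import asyncio
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock

from vllm.entrypoints.renderer import CompletionRenderer


@dataclass
class DemoModelConfig:
    # Minimal stand-in for vllm.config.ModelConfig; only the fields the
    # renderer reads on the token-input path are provided.
    max_model_len: int = 100
    encoder_config: Optional[dict] = None


async def demo() -> None:
    renderer = CompletionRenderer(model_config=DemoModelConfig(),
                                  tokenizer=MagicMock(),
                                  async_tokenizer_pool={})

    # Pre-tokenized input: validated against max_length and truncated to the
    # last N tokens without ever touching the tokenizer.
    prompts = await renderer.render_prompt(
        prompt_or_prompts=[[101, 7592, 2088, 2003, 2026]],
        max_length=100,
        truncate_prompt_tokens=3,
        cache_salt="demo-salt",
    )
    for prompt in prompts:
        # Each result is an engine TokensPrompt (a TypedDict).
        print(prompt["prompt_token_ids"], prompt.get("cache_salt"))


if __name__ == "__main__":
    asyncio.run(demo())

Text inputs take the other branch and go through the pooled AsyncMicrobatchTokenizer instead, which is why the unit tests inject an AsyncMock into renderer.async_tokenizer_pool before calling render_prompt on strings.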