|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +from contextlib import suppress |
| 4 | +from dataclasses import dataclass, field |
| 5 | +from http import HTTPStatus |
| 6 | +from typing import Optional |
| 7 | +from unittest.mock import MagicMock |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from vllm.config import MultiModalConfig |
| 12 | +from vllm.engine.multiprocessing.client import MQLLMEngineClient |
| 13 | +from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse |
| 14 | +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion |
| 15 | +from vllm.entrypoints.openai.serving_models import (BaseModelPath, |
| 16 | + OpenAIServingModels) |
| 17 | +from vllm.lora.request import LoRARequest |
| 18 | +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry |
| 19 | +from vllm.transformers_utils.tokenizer import get_tokenizer |
| 20 | + |
# HF model id used both as the served model name and the tokenizer source.
MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Key under which the mock resolver is (un)registered in the global registry.
MOCK_RESOLVER_NAME = "mock_test_resolver"
| 25 | + |
| 26 | + |
@dataclass
class MockHFConfig:
    """Stand-in for a HF config object; only ``model_type`` is provided."""
    model_type: str = "any"
| 30 | + |
| 31 | + |
@dataclass
class MockModelConfig:
    """Minimal mock ModelConfig for testing."""
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_mode: str = "auto"
    max_model_len: int = 100
    tokenizer_revision: Optional[str] = None
    multimodal_config: MultiModalConfig = field(
        default_factory=MultiModalConfig)
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)
    logits_processor_pattern: Optional[str] = None
    diff_sampling_param: Optional[dict] = None
    allowed_local_media_path: str = ""
    # NOTE: intentionally left unannotated so it stays a plain class
    # attribute rather than becoming a dataclass field (it never appears
    # in the generated __init__).
    encoder_config = None
    generation_config: str = "auto"

    def get_diff_sampling_param(self):
        # Mirrors ModelConfig's helper: returns {} when unset so callers
        # can merge/iterate without a None check.
        return self.diff_sampling_param or {}
| 52 | + |
| 53 | + |
class MockLoRAResolver(LoRAResolver):
    """Resolver stub that recognizes exactly two adapter names.

    ``test-lora`` resolves to an adapter the mocked engine will accept;
    ``invalid-lora`` resolves to one the mocked engine will reject at
    add time. Every other name is unresolvable (returns None).
    """

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        # Static table of the adapters this mock can "discover":
        # name -> (lora_int_id, lora_local_path).
        known_adapters = {
            "test-lora": (1, "/fake/path/test-lora"),
            "invalid-lora": (2, "/fake/path/invalid-lora"),
        }
        entry = known_adapters.get(lora_name)
        if entry is None:
            return None
        adapter_id, adapter_path = entry
        return LoRARequest(lora_name=lora_name,
                           lora_int_id=adapter_id,
                           lora_local_path=adapter_path)
| 67 | + |
| 68 | + |
@pytest.fixture(autouse=True)
def register_mock_resolver():
    """Register the mock LoRA resolver around each test, then remove it."""
    LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME,
                                           MockLoRAResolver())
    yield
    # Teardown: drop the resolver so it cannot leak into unrelated tests.
    LoRAResolverRegistry.resolvers.pop(MOCK_RESOLVER_NAME, None)
| 78 | + |
| 79 | + |
@pytest.fixture
def mock_serving_setup():
    """Build a mocked engine client plus a serving-completion layer on top.

    Returns (mock_engine, serving_completion). The engine's ``add_lora``
    raises for the "invalid-lora" adapter and succeeds for anything else.
    """
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    def fake_add_lora(lora_request: LoRARequest):
        """Simulate engine behavior when adding LoRAs."""
        name = lora_request.lora_name
        if name == "invalid-lora":
            # Simulate failure during addition (e.g. invalid format)
            raise ValueError(f"Simulated failure adding LoRA: {name}")
        # "test-lora" (or anything else) is accepted silently.

    mock_engine.add_lora.side_effect = fake_add_lora
    # Clear call history so each test starts from a clean slate
    # (side_effect set above survives reset_mock by default).
    mock_engine.generate.reset_mock()
    mock_engine.add_lora.reset_mock()

    model_config = MockModelConfig()
    serving_models = OpenAIServingModels(engine_client=mock_engine,
                                         base_model_paths=BASE_MODEL_PATHS,
                                         model_config=model_config)

    serving_completion = OpenAIServingCompletion(mock_engine,
                                                 model_config,
                                                 serving_models,
                                                 request_logger=None)

    return mock_engine, serving_completion
| 112 | + |
| 113 | + |
@pytest.mark.asyncio
async def test_serving_completion_with_lora_resolver(mock_serving_setup,
                                                     monkeypatch):
    """A resolvable adapter is added to the engine and used for generation."""
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    adapter_name = "test-lora"
    request = CompletionRequest(
        model=adapter_name,
        prompt="Generate with LoRA",
    )

    # Suppress potential errors during the mocked generate call,
    # as we are primarily checking for add_lora and generate calls
    with suppress(Exception):
        await serving_completion.create_completion(request)

    # The resolved adapter must have been registered with the engine...
    mock_engine.add_lora.assert_called_once()
    added_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(added_request, LoRARequest)
    assert added_request.lora_name == adapter_name

    # ...and then passed through to the generate call.
    mock_engine.generate.assert_called_once()
    generate_lora = mock_engine.generate.call_args[1]['lora_request']
    assert isinstance(generate_lora, LoRARequest)
    assert generate_lora.lora_name == adapter_name
| 141 | + |
| 142 | + |
@pytest.mark.asyncio
async def test_serving_completion_resolver_not_found(mock_serving_setup,
                                                     monkeypatch):
    """An unresolvable adapter name yields a 404 without touching the engine."""
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    missing_model = "non-existent-lora-adapter"
    request = CompletionRequest(
        model=missing_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(request)

    # Neither LoRA loading nor generation should be attempted.
    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()

    assert isinstance(response, ErrorResponse)
    assert response.code == HTTPStatus.NOT_FOUND.value
    assert missing_model in response.message
| 164 | + |
| 165 | + |
@pytest.mark.asyncio
async def test_serving_completion_resolver_add_lora_fails(
        mock_serving_setup, monkeypatch):
    """If the engine rejects the resolved adapter, the request gets a 400."""
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    invalid_model = "invalid-lora"
    request = CompletionRequest(
        model=invalid_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(request)

    # add_lora was attempted with the resolved request before failing.
    mock_engine.add_lora.assert_called_once()
    attempted_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(attempted_request, LoRARequest)
    assert attempted_request.lora_name == invalid_model

    # Generation never starts once the adapter fails to load.
    mock_engine.generate.assert_not_called()

    # The failure surfaces as a client error naming the bad adapter.
    assert isinstance(response, ErrorResponse)
    assert response.code == HTTPStatus.BAD_REQUEST.value
    assert invalid_model in response.message
| 194 | + |
| 195 | + |
@pytest.mark.asyncio
async def test_serving_completion_flag_not_set(mock_serving_setup):
    """Without VLLM_ALLOW_RUNTIME_LORA_UPDATING, no resolution is attempted."""
    mock_engine, serving_completion = mock_serving_setup

    request = CompletionRequest(
        model="test-lora",
        prompt="Generate with LoRA",
    )

    await serving_completion.create_completion(request)

    # The resolver path is gated off, so the engine is never touched.
    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()
0 commit comments