Commit fdcb850

Authored by Angky William (angkywilliam)
[Misc] Enable vLLM to Dynamically Load LoRA from a Remote Server (#10546)
Signed-off-by: Angky William <[email protected]> Co-authored-by: Angky William <[email protected]>
1 parent 54a66e5 commit fdcb850

File tree

6 files changed: +505 −6 lines changed

docs/source/features/lora.md

Lines changed: 56 additions & 5 deletions
@@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \
 
 ## Dynamically serving LoRA Adapters
 
-In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
-LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
-to change models on-the-fly is needed.
+In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature can be particularly useful when the flexibility to change models on-the-fly is needed.
 
 Note: Enabling this feature in production environments is risky as users may participate in model adapter management.
 
-To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
-is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
+To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
+is set to `True`.
 
 ```bash
 export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
 ```
 
+### Using API Endpoints
 Loading a LoRA Adapter:
 
 To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary

@@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
 }'
 ```
 
+### Using Plugins
+Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins enable you to load LoRA adapters from both local and remote sources such as local file system and S3. On every request, when there's a new model name that hasn't been loaded yet, the LoRAResolver will try to resolve and load the corresponding LoRA adapter.
+
+You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
+
+You can either install existing plugins or implement your own.
+
+Steps to implement your own LoRAResolver plugin:
+
+1. Implement the LoRAResolver interface.
+
+    Example of a simple S3 LoRAResolver implementation:
+
+    ```python
+    import os
+
+    import s3fs
+
+    from vllm.lora.request import LoRARequest
+    from vllm.lora.resolver import LoRAResolver
+
+
+    class S3LoRAResolver(LoRAResolver):
+
+        def __init__(self):
+            self.s3 = s3fs.S3FileSystem()
+            self.s3_path_format = os.getenv("S3_PATH_TEMPLATE")
+            self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE")
+
+        async def resolve_lora(self, base_model_name, lora_name):
+            s3_path = self.s3_path_format.format(
+                base_model_name=base_model_name, lora_name=lora_name)
+            local_path = self.local_path_format.format(
+                base_model_name=base_model_name, lora_name=lora_name)
+
+            # Download the LoRA from S3 to the local path
+            await self.s3._get(s3_path, local_path, recursive=True, maxdepth=1)
+
+            lora_request = LoRARequest(lora_name=lora_name,
+                                       lora_path=local_path,
+                                       lora_int_id=abs(hash(lora_name)))
+            return lora_request
+    ```
+
+2. Register LoRAResolver plugin.
+
+    ```python
+    from vllm.lora.resolver import LoRAResolverRegistry
+
+    s3_resolver = S3LoRAResolver()
+    LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver)
+    ```
+
+For more details, refer to the [vLLM's Plugins System](../design/plugin_system.md).
+
 ## New format for `--lora-modules`
 
 In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
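One detail worth flagging in the S3 resolver snippet above: `abs(hash(lora_name))` uses Python's built-in `hash()`, which is salted per process for strings, so the derived `lora_int_id` differs across server restarts. If a stable ID matters for your deployment, a small sketch of a deterministic alternative (the helper name `stable_lora_int_id` is hypothetical, not part of vLLM):

```python
import hashlib


def stable_lora_int_id(lora_name: str) -> int:
    """Derive a deterministic, positive, non-zero integer ID from an adapter name."""
    digest = hashlib.sha256(lora_name.encode("utf-8")).hexdigest()
    # Take 60 bits of the digest; add 1 so the ID is never zero.
    return int(digest[:15], 16) + 1


# The same name maps to the same ID across processes and restarts.
print(stable_lora_int_id("test-lora") == stable_lora_int_id("test-lora"))  # -> True
```

This avoids surprises from `PYTHONHASHSEED`-dependent string hashing while keeping the "one ID per adapter name" property the original code relies on.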
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@

```python
# SPDX-License-Identifier: Apache-2.0

from contextlib import suppress
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Optional
from unittest.mock import MagicMock

import pytest

from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

MOCK_RESOLVER_NAME = "mock_test_resolver"


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Minimal mock ModelConfig for testing."""
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_mode: str = "auto"
    max_model_len: int = 100
    tokenizer_revision: Optional[str] = None
    multimodal_config: MultiModalConfig = field(
        default_factory=MultiModalConfig)
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)
    logits_processor_pattern: Optional[str] = None
    diff_sampling_param: Optional[dict] = None
    allowed_local_media_path: str = ""
    encoder_config = None
    generation_config: str = "auto"

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


class MockLoRAResolver(LoRAResolver):

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        if lora_name == "test-lora":
            return LoRARequest(lora_name="test-lora",
                               lora_int_id=1,
                               lora_local_path="/fake/path/test-lora")
        elif lora_name == "invalid-lora":
            return LoRARequest(lora_name="invalid-lora",
                               lora_int_id=2,
                               lora_local_path="/fake/path/invalid-lora")
        return None


@pytest.fixture(autouse=True)
def register_mock_resolver():
    """Fixture to register and unregister the mock LoRA resolver."""
    resolver = MockLoRAResolver()
    LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, resolver)
    yield
    # Cleanup: remove the resolver after the test runs
    if MOCK_RESOLVER_NAME in LoRAResolverRegistry.resolvers:
        del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]


@pytest.fixture
def mock_serving_setup():
    """Provides a mocked engine and serving completion instance."""
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    def mock_add_lora_side_effect(lora_request: LoRARequest):
        """Simulate engine behavior when adding LoRAs."""
        if lora_request.lora_name == "test-lora":
            # Simulate successful addition
            return
        elif lora_request.lora_name == "invalid-lora":
            # Simulate failure during addition (e.g. invalid format)
            raise ValueError(f"Simulated failure adding LoRA: "
                             f"{lora_request.lora_name}")

    mock_engine.add_lora.side_effect = mock_add_lora_side_effect
    mock_engine.generate.reset_mock()
    mock_engine.add_lora.reset_mock()

    mock_model_config = MockModelConfig()
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)

    serving_completion = OpenAIServingCompletion(mock_engine,
                                                 mock_model_config,
                                                 models,
                                                 request_logger=None)

    return mock_engine, serving_completion


@pytest.mark.asyncio
async def test_serving_completion_with_lora_resolver(mock_serving_setup,
                                                     monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    lora_model_name = "test-lora"
    req_found = CompletionRequest(
        model=lora_model_name,
        prompt="Generate with LoRA",
    )

    # Suppress potential errors during the mocked generate call,
    # as we are primarily checking for add_lora and generate calls
    with suppress(Exception):
        await serving_completion.create_completion(req_found)

    mock_engine.add_lora.assert_called_once()
    called_lora_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == lora_model_name

    mock_engine.generate.assert_called_once()
    called_lora_request = mock_engine.generate.call_args[1]['lora_request']
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == lora_model_name


@pytest.mark.asyncio
async def test_serving_completion_resolver_not_found(mock_serving_setup,
                                                     monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    non_existent_model = "non-existent-lora-adapter"
    req = CompletionRequest(
        model=non_existent_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(req)

    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()

    assert isinstance(response, ErrorResponse)
    assert response.code == HTTPStatus.NOT_FOUND.value
    assert non_existent_model in response.message


@pytest.mark.asyncio
async def test_serving_completion_resolver_add_lora_fails(
        mock_serving_setup, monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

    mock_engine, serving_completion = mock_serving_setup

    invalid_model = "invalid-lora"
    req = CompletionRequest(
        model=invalid_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(req)

    # Assert add_lora was called before the failure
    mock_engine.add_lora.assert_called_once()
    called_lora_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == invalid_model

    # Assert generate was *not* called due to the failure
    mock_engine.generate.assert_not_called()

    # Assert the correct error response
    assert isinstance(response, ErrorResponse)
    assert response.code == HTTPStatus.BAD_REQUEST.value
    assert invalid_model in response.message


@pytest.mark.asyncio
async def test_serving_completion_flag_not_set(mock_serving_setup):
    mock_engine, serving_completion = mock_serving_setup

    lora_model_name = "test-lora"
    req_found = CompletionRequest(
        model=lora_model_name,
        prompt="Generate with LoRA",
    )

    await serving_completion.create_completion(req_found)

    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()
```
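The fixture in the file above leans on `unittest.mock`'s `side_effect`: assigning a function to a mock attribute makes each call run that function, so one input can succeed while another raises through the mock. The mechanism in isolation (the names `good-lora`/`bad-lora` are illustrative only):

```python
from unittest.mock import MagicMock


def add_lora_side_effect(name: str):
    # Succeed for a known-good adapter, fail for a known-bad one.
    if name == "good-lora":
        return
    raise ValueError(f"Simulated failure adding LoRA: {name}")


engine = MagicMock()
engine.add_lora.side_effect = add_lora_side_effect

engine.add_lora("good-lora")       # returns None; the call is still recorded
try:
    engine.add_lora("bad-lora")    # side_effect raises through the mock
except ValueError as e:
    print(e)                       # -> Simulated failure adding LoRA: bad-lora

engine.add_lora.assert_called_with("bad-lora")
assert engine.add_lora.call_count == 2
```

Because calls are recorded even when `side_effect` raises, the tests can assert both that `add_lora` was attempted and that `generate` was never reached.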

tests/lora/test_resolver.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@

```python
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest

from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


class DummyLoRAResolver(LoRAResolver):
    """A dummy LoRA resolver for testing."""

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        if lora_name == "test_lora":
            return LoRARequest(
                lora_name=lora_name,
                lora_path=f"/dummy/path/{base_model_name}/{lora_name}",
                lora_int_id=abs(hash(lora_name)))
        return None


def test_resolver_registry_registration():
    """Test basic resolver registration functionality."""
    registry = LoRAResolverRegistry
    resolver = DummyLoRAResolver()

    # Register a new resolver
    registry.register_resolver("dummy", resolver)
    assert "dummy" in registry.get_supported_resolvers()

    # Get registered resolver
    retrieved_resolver = registry.get_resolver("dummy")
    assert retrieved_resolver is resolver


def test_resolver_registry_duplicate_registration():
    """Test registering a resolver with an existing name."""
    registry = LoRAResolverRegistry
    resolver1 = DummyLoRAResolver()
    resolver2 = DummyLoRAResolver()

    registry.register_resolver("dummy", resolver1)
    registry.register_resolver("dummy", resolver2)

    assert registry.get_resolver("dummy") is resolver2


def test_resolver_registry_unknown_resolver():
    """Test getting a non-existent resolver."""
    registry = LoRAResolverRegistry

    with pytest.raises(KeyError, match="not found"):
        registry.get_resolver("unknown_resolver")


@pytest.mark.asyncio
async def test_dummy_resolver_resolve():
    """Test the dummy resolver's resolve functionality."""
    dummy_resolver = DummyLoRAResolver()
    base_model_name = "base_model_test"
    lora_name = "test_lora"

    # Test successful resolution
    result = await dummy_resolver.resolve_lora(base_model_name, lora_name)
    assert isinstance(result, LoRARequest)
    assert result.lora_name == lora_name
    assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}"

    # Test failed resolution
    result = await dummy_resolver.resolve_lora(base_model_name,
                                               "nonexistent_lora")
    assert result is None
```
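The registry semantics these tests pin down (re-registration silently overwrites, unknown names raise `KeyError`) can be sketched with a plain dict-backed registry. This is a simplified stand-in for illustration, not vLLM's actual `LoRAResolverRegistry` implementation:

```python
class SimpleResolverRegistry:
    """Dict-backed registry: last registration for a name wins."""

    def __init__(self):
        self._resolvers = {}

    def register_resolver(self, name, resolver):
        self._resolvers[name] = resolver  # silently overwrites duplicates

    def get_resolver(self, name):
        if name not in self._resolvers:
            raise KeyError(f"Resolver '{name}' not found")
        return self._resolvers[name]


registry = SimpleResolverRegistry()
resolver1, resolver2 = object(), object()
registry.register_resolver("dummy", resolver1)
registry.register_resolver("dummy", resolver2)
assert registry.get_resolver("dummy") is resolver2  # second registration wins
```

Last-write-wins keeps registration idempotent for repeated imports, at the cost of making accidental name collisions silent; the duplicate-registration test above documents that trade-off explicitly.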
