3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Scripts for development
scripts/

# version file generated by setuptools-scm
/vllm/_version.py

17 changes: 15 additions & 2 deletions examples/offline_inference/spec_decode.py
@@ -67,7 +67,11 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.8)
parser.add_argument("--request-id-prefix", type=str, default="")
parser.add_argument("--max-model-len", type=int, default=16384)
return parser.parse_args()


@@ -117,6 +121,15 @@ def main():
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
}
elif args.method.endswith("mtp"):
speculative_config = {
"method": args.method,
@@ -131,10 +144,10 @@
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.8,
gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)
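
For context, the programmatic equivalent of running this example with --method draft_model is roughly the following sketch (the model names are the ones used in the new tests in this PR; the prompt and token counts are illustrative):

# Minimal sketch of the draft_model path added above; model names are taken
# from the new test cases in this PR, other values are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-1.7B",              # target model
    speculative_config={
        "method": "draft_model",
        "model": "Qwen/Qwen3-0.6B",       # separate draft model
        "num_speculative_tokens": 3,
        "enforce_eager": False,
        "max_model_len": 16384,
    },
    gpu_memory_utilization=0.8,
    max_model_len=16384,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0, max_tokens=32))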
126 changes: 124 additions & 2 deletions tests/v1/e2e/test_spec_decode.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Union

import pytest
@@ -13,10 +14,12 @@
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.v1.spec_decode.metrics import compute_acceptance_rate


def get_test_prompts(mm_enabled: bool):
def get_test_prompts(mm_enabled: bool, quiet: bool = False):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
@@ -25,7 +28,9 @@ def get_test_prompts(mm_enabled: bool):

random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")

if not quiet:
print(f"Prompt types: {random_prompt_type_choices}")

# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
@@ -69,9 +74,17 @@

@pytest.fixture
def sampling_config():
return greedy_sampling()


def greedy_sampling() -> SamplingParams:
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


def stochastic_sampling() -> SamplingParams:
return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"
@@ -223,3 +236,112 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@dataclass
class ArgsTest:
model: str
draft_model: str
sampling_config: SamplingParams
expected_acceptance_rate: float
expected_same_output_fraction: float
# Defaults
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
gpu_memory_utilization: float = 0.5


cases = [
ArgsTest(
model="Qwen/Qwen3-0.6B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=greedy_sampling(),
expected_acceptance_rate=1.0,
expected_same_output_fraction=1.0,
),
ArgsTest(
model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=stochastic_sampling(),
expected_acceptance_rate=0.9,
expected_same_output_fraction=0.9,
),
]


@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
def test_draft_model_correctness(
args: ArgsTest,
enforce_eager: bool,
disable_padded_drafter_batch: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
monkeypatch.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False, quiet=True)

spec_llm = LLM(
model=args.model,
speculative_config={
"model": args.draft_model,
"method": "draft_model",
"num_speculative_tokens": 3,
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"tensor_parallel_size": args.draft_tensor_parallel_size,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
},
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
spec_outputs = spec_llm.chat(test_prompts, args.sampling_config)
acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics())
del spec_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert acceptance_rate >= args.expected_acceptance_rate

ref_llm = LLM(
model=args.model,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
)
ref_outputs = ref_llm.chat(test_prompts, args.sampling_config)
del ref_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert len(ref_outputs) > 0
assert len(ref_outputs) == len(spec_outputs)

match_fraction = compute_exact_matches(ref_outputs, spec_outputs)
assert match_fraction >= args.expected_same_output_fraction

print(f"spec-decode: target={args.model}, draft={args.draft_model}, "
f"temperature={args.sampling_config.temperature:.2f}, "
f"acceptance_rate={acceptance_rate:.2f}, "
f"match_fraction={match_fraction:.2f}")


def compute_exact_matches(ref_outputs: list[RequestOutput],
spec_outputs: list[RequestOutput]) -> float:
"""Compute the fraction of the prompts that match exactly"""
assert len(ref_outputs) == len(spec_outputs)
matches = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
return matches / len(ref_outputs)
32 changes: 32 additions & 0 deletions tests/v1/test_utils.py
@@ -42,6 +42,38 @@ def test_bind_kv_cache():
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']


def test_bind_kv_cache_draft_model():
from vllm.attention import Attention
ctx = {
'model.layers.0.attn': Attention(32, 128, 0.1),
'model.layers.1.attn': Attention(32, 128, 0.1),
'draft_model.layers.0.attn': Attention(32, 128, 0.1),
'draft_model.layers.1.attn': Attention(32, 128, 0.1),
}
kv_cache = {
'model.layers.0.attn': torch.zeros((1, )),
'model.layers.1.attn': torch.zeros((1, )),
'draft_model.layers.0.attn': torch.zeros((1, )),
'draft_model.layers.1.attn': torch.zeros((1, )),
}
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['model.layers.0.attn'].kv_cache[0] is kv_cache[
'model.layers.0.attn']
assert ctx['model.layers.1.attn'].kv_cache[0] is kv_cache[
'model.layers.1.attn']
assert ctx['draft_model.layers.0.attn'].kv_cache[0] is kv_cache[
'draft_model.layers.0.attn']
assert ctx['draft_model.layers.1.attn'].kv_cache[0] is kv_cache[
'draft_model.layers.1.attn']

# caches are ordered by layer_index, interleaving target and draft model
assert runner_kv_caches[0] is kv_cache['model.layers.0.attn']
assert runner_kv_caches[1] is kv_cache['draft_model.layers.0.attn']
assert runner_kv_caches[2] is kv_cache['model.layers.1.attn']
assert runner_kv_caches[3] is kv_cache['draft_model.layers.1.attn']


def test_bind_kv_cache_non_attention():
from vllm.attention import Attention

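
The interleaved order asserted in test_bind_kv_cache_draft_model is what a stable sort on the numeric layer index yields, so target and draft layers with the same index end up adjacent. An illustrative sketch of that ordering rule (bind_kv_cache's internals are not part of this diff):

# Illustrative only: stable-sorting layer names by their numeric layer index
# reproduces the interleaving checked in the test above.
import re

names = ['model.layers.0.attn', 'model.layers.1.attn',
         'draft_model.layers.0.attn', 'draft_model.layers.1.attn']
ordered = sorted(names, key=lambda n: int(re.search(r'layers\.(\d+)\.', n).group(1)))
# -> ['model.layers.0.attn', 'draft_model.layers.0.attn',
#     'model.layers.1.attn', 'draft_model.layers.1.attn']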
8 changes: 3 additions & 5 deletions vllm/config/speculative.py
@@ -336,11 +336,6 @@ def __post_init__(self):
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@@ -552,6 +547,9 @@ def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")

def uses_draft_model(self) -> bool:
return self.method == "draft_model"

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
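
With the NotImplementedError gone, "draft_model" is now an accepted method, and the new uses_draft_model() helper gives callers a single place to check for it. A hypothetical call-site sketch (not part of this diff; draft_model_config already exists on SpeculativeConfig, as __repr__ above shows):

# Hypothetical call-site sketch: branch on the helper instead of comparing
# method strings when a full draft model needs its own weights and KV cache.
spec_config = vllm_config.speculative_config
if spec_config is not None and spec_config.uses_draft_model():
    draft_model_config = spec_config.draft_model_config
    # ... load the draft model under its own prefix (see model_loader below) ...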
8 changes: 0 additions & 8 deletions vllm/engine/arg_utils.py
@@ -1534,14 +1534,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# V1 supports N-gram, Medusa, and Eagle speculative decoding.
if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"):
raise NotImplementedError(
"Speculative decoding with draft model is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")

V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
6 changes: 4 additions & 2 deletions vllm/model_executor/model_loader/__init__.py
@@ -112,12 +112,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:

def get_model(*,
vllm_config: VllmConfig,
model_config: Optional[ModelConfig] = None) -> nn.Module:
model_config: Optional[ModelConfig] = None,
prefix: str = "") -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config,
model_config=model_config)
model_config=model_config,
prefix=prefix)


__all__ = [
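
The new prefix argument threads through get_model -> BaseModelLoader.load_model -> initialize_model, so a second (draft) model can be instantiated under its own parameter namespace. A sketch of how a runner might use it; the "draft_model." prefix and the surrounding variable names are assumptions, mirroring the layer names in test_bind_kv_cache_draft_model:

# Sketch only: "draft_model." and the config variables are assumptions;
# the actual runner wiring is outside this diff.
from vllm.model_executor.model_loader import get_model

target_model = get_model(vllm_config=vllm_config)  # default prefix ""
draft_model = get_model(
    vllm_config=vllm_config,
    model_config=spec_config.draft_model_config,
    prefix="draft_model.",  # yields layer names like draft_model.layers.0.attn
)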
9 changes: 6 additions & 3 deletions vllm/model_executor/model_loader/base_loader.py
@@ -32,8 +32,10 @@ def load_weights(self, model: nn.Module,
inplace weights loading for an already-initialized model"""
raise NotImplementedError

def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
def load_model(self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "") -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
load_config = vllm_config.load_config
@@ -43,7 +45,8 @@ def load_model(self, vllm_config: VllmConfig,
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
model_config=model_config,
prefix=prefix)

logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
9 changes: 6 additions & 3 deletions vllm/model_executor/model_loader/gguf_loader.py
@@ -124,8 +124,10 @@ def load_weights(self, model: nn.Module,
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))

def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
def load_model(self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "") -> nn.Module:
device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
@@ -148,7 +150,8 @@ def load_model(self, vllm_config: VllmConfig,
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config,
prefix=prefix)
self.load_weights(model, model_config)

process_weights_after_loading(model, model_config, target_device)
13 changes: 9 additions & 4 deletions vllm/model_executor/model_loader/tensorizer_loader.py
@@ -59,6 +59,7 @@ def _get_weights_iterator(
def _load_model_serialized_cpu(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.

Expand All @@ -71,7 +72,8 @@ def _load_model_serialized_cpu(
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config,
prefix=prefix)

model.load_weights(self._get_weights_iterator())
return model.eval()
@@ -104,8 +106,10 @@ def load_weights(self, model: nn.Module,
else:
model.load_weights(self._get_weights_iterator())

def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
def load_model(self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "") -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)

Expand All @@ -126,7 +130,8 @@ def load_model(self, vllm_config: VllmConfig,
vllm_config=vllm_config)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config)
return self._load_model_serialized_cpu(vllm_config=vllm_config,
prefix=prefix)

@staticmethod
def save_model(