Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions vllm/v1/structured_output/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import os
from typing import TYPE_CHECKING

import numpy as np
import regex as re
import torch
from cachetools import LRUCache
from diskcache import Cache

Expand All @@ -20,9 +22,13 @@
import outlines_core as oc
import transformers.file_utils as file_utils
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
import xgrammar as xgr

from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
oc = LazyLoader("oc", globals(), "outlines_core")
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
tokenization_gpt2 = LazyLoader(
Expand All @@ -36,6 +42,80 @@
CACHE = None


def apply_grammar_bitmask(
scheduler_output: SchedulerOutput,
input_batch: InputBatch,
logits: torch.Tensor,
device: torch.device,
) -> None:
"""
Apply grammar bitmask to output logits of the model with xgrammar function.

Args:
scheduler_output (SchedulerOutput): The result of engine scheduling.
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
device (torch.device): The device that model runner running on.
"""
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return

# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
# The order of the requests in the bitmask is not guaranteed to be the
# same as the order of the requests in the gpu runner's batch. We need
# to sort the bitmask to match the order of the requests used here.

# Get the batch indices of the structured output requests.
# Keep track of the number of speculative tokens scheduled for every
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
if req_id in scheduler_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index

out_indices = []

# Reorder the bitmask to match the order of the requests in the batch.
sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]),
fill_value=-1,
dtype=grammar_bitmask.dtype)
cumulative_index = 0
seq = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in seq:
logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = \
grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask

# If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0]

# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()

xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)


class OutlinesVocabulary:
"""
Wrapper class for `outlines_core.Vocabulary`,
Expand Down
75 changes: 4 additions & 71 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
GiB_bytes, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
Expand Down Expand Up @@ -83,6 +83,7 @@
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
Expand All @@ -95,12 +96,8 @@
scatter_mm_placeholders)

if TYPE_CHECKING:
import xgrammar as xgr

from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)

Expand Down Expand Up @@ -1561,71 +1558,6 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:

return tuple(tasks)

def apply_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
):
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return

# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
# The order of the requests in the bitmask is not guaranteed to be the
# same as the order of the requests in the gpu runner's batch. We need
# to sort the bitmask to match the order of the requests used here.

# Get the batch indices of the structured output requests.
# Keep track of the number of speculative tokens scheduled for every
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
if req_id in scheduler_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index

out_indices = []

# Reorder the bitmask to match the order of the requests in the batch.
sorted_bitmask = np.full(shape=(logits.shape[0],
grammar_bitmask.shape[1]),
fill_value=-1,
dtype=grammar_bitmask.dtype)
cumulative_index = 0
seq = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in seq:
logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = \
grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask

# If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0]

# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()

xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)

def sync_and_slice_intermediate_tensors(
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
sync_self: bool) -> IntermediateTensors:
Expand Down Expand Up @@ -2128,7 +2060,8 @@ def execute_model(

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
apply_grammar_bitmask(scheduler_output, self.input_batch,
logits, self.device)

with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
Expand Down