
Commit 22ddbb2

[V1][Core] min_p sampling support
Signed-off-by: Aoyu <[email protected]>
1 parent 5e5c8e0 commit 22ddbb2

File tree

4 files changed: +93, -0 lines changed

tests/v1/sample/test_sampler.py

Lines changed: 40 additions & 0 deletions
@@ -81,6 +81,8 @@ def _create_default_sampling_metadata(
         top_k=torch.empty(batch_size, ),
         no_top_p=True,
         no_top_k=True,
+        min_p=torch.empty(batch_size, ),
+        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -338,6 +340,44 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,

 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("min_p", [0.0, 0.1])
+def test_sampler_min_p(device: str, batch_size: int, min_p: float):
+    """
+    Tests that when min_p is applied, tokens with probability below
+    min_p * max_prob are masked with -inf.
+    """
+    torch.set_default_device(device)
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+
+    # Create one dominant token per batch
+    for i in range(batch_size):
+        fake_logits[i, 0] = 10.0  # High logit for first token
+        fake_logits[i, 1:] = 1e-2  # Others remain low
+
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+
+    # Configure min_p parameters
+    sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
+
+    sampler = Sampler()
+    logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
+    logits = logits.cpu()
+
+    for batch_idx in range(batch_size):
+        for token_id in range(VOCAB_SIZE):
+            if token_id == 0:
+                # Dominant token should always be unmasked
+                assert logits[batch_idx][token_id] != -float("inf")
+            else:
+                if min_p > 0.0:
+                    # Non-dominant tokens should be masked when min_p > 0
+                    assert logits[batch_idx][token_id] == -float("inf")
+                else:
+                    # No masking when min_p is 0
+                    assert logits[batch_idx][token_id] != -float("inf")
+
+
 @pytest.mark.parametrize("bias_value", [-0.1, 1.2])
 def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
     """

vllm/v1/sample/metadata.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ class SamplingMetadata:
     top_k: torch.Tensor
     no_top_p: bool
     no_top_k: bool
+    min_p: torch.Tensor
+    no_min_p: bool

     generators: Dict[int, torch.Generator]
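
The two new fields mirror the existing top_p/top_k convention: min_p carries one value per request in the batch, and no_min_p is a batch-level flag that lets the sampler skip the extra softmax when no request asked for min_p. A small standalone sketch of that convention (not code from this commit):

    import torch

    # One min_p value per request; 0.0 means the filter is disabled.
    min_p = torch.tensor([0.0, 0.1, 0.0, 0.05])

    # Batch-level fast path mirroring no_min_p: only run apply_min_p when
    # at least one request actually uses min_p.
    no_min_p = bool((min_p == 0.0).all())
    print(no_min_p)  # False here, so the min_p pass would run for this batch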

vllm/v1/sample/sampler.py

Lines changed: 25 additions & 0 deletions
@@ -93,6 +93,10 @@ def sample(
             sampling_metadata.no_top_p,
             sampling_metadata.top_p,
         )
+
+        if not sampling_metadata.no_min_p:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
         if sampling_metadata.all_random:
             return random_sampled

@@ -169,6 +173,27 @@ def apply_penalties(
                 sampling_metadata.output_token_ids)
         return logits

+    def apply_min_p(
+        self,
+        logits: torch.Tensor,
+        min_p: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Filters logits using adaptive probability thresholding.
+        """
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Reshape min_p for broadcasting
+        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+        # Identify valid tokens using threshold comparison
+        valid_token_mask = probability_values >= adjusted_min_p
+        # Apply mask using boolean indexing
+        logits[~valid_token_mask] = -float('inf')
+        return logits

     def apply_logits_bias(
         self,
         logits: torch.Tensor,
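
To make the adaptive threshold concrete, here is a standalone sketch of the same computation on a toy batch, written against plain torch rather than the Sampler class (the helper name min_p_mask is made up for illustration, and it uses an out-of-place masked_fill instead of the in-place indexing above):

    import torch

    def min_p_mask(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
        """Return logits with tokens below min_p * max_prob masked to -inf."""
        probs = torch.softmax(logits, dim=-1)
        max_probs = probs.amax(dim=-1, keepdim=True)  # per-row peak probability
        threshold = min_p.unsqueeze(1) * max_probs    # adaptive, per-row cutoff
        return logits.masked_fill(probs < threshold, float("-inf"))

    logits = torch.tensor([[2.0, 1.0, -1.0, -3.0]])
    print(min_p_mask(logits, torch.tensor([0.2])))
    # Only tokens whose probability is at least 20% of the top token's
    # probability survive; the rest become -inf: [2.0, 1.0, -inf, -inf].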

vllm/v1/worker/gpu_input_batch.py

Lines changed: 26 additions & 0 deletions
@@ -14,6 +14,8 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.block_table import BlockTable

+_SAMPLING_EPS = 1e-5
+
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange

@@ -120,6 +122,16 @@ def __init__(
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()

+        self.min_p = torch.empty((max_num_reqs, ),
+                                 dtype=torch.float32,
+                                 device=device)
+        self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
+                                            dtype=torch.float32,
+                                            device="cpu",
+                                            pin_memory=pin_memory)
+        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+        self.min_p_reqs: Set[str] = set()
+
         # Frequency penalty related data structures
         self.frequency_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
@@ -223,8 +235,11 @@ def add_request(
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
             self.top_k_reqs.add(req_id)
+        self.min_p_cpu[req_index] = sampling_params.min_p
         self.frequency_penalties_cpu[
             req_index] = sampling_params.frequency_penalty
+        if sampling_params.min_p > _SAMPLING_EPS:
+            self.min_p_reqs.add(req_id)
         if sampling_params.frequency_penalty != 0.0:
             self.frequency_penalties_reqs.add(req_id)
         self.presence_penalties_cpu[
@@ -273,6 +288,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
+        self.min_p_reqs.discard(req_id)
         self.frequency_penalties_reqs.discard(req_id)
         self.presence_penalties_reqs.discard(req_id)
         self.repetition_penalties_reqs.discard(req_id)
@@ -299,6 +315,7 @@ def clear(self) -> None:
         self.random_reqs.clear()
         self.top_p_reqs.clear()
         self.top_k_reqs.clear()
+        self.min_p_reqs.clear()
         self.frequency_penalties_reqs.clear()
         self.presence_penalties_reqs.clear()
         self.repetition_penalties_reqs.clear()
@@ -354,6 +371,7 @@ def condense(self, empty_req_indices: List[int]) -> None:
                 empty_index] = self.presence_penalties_cpu[last_req_index]
             self.repetition_penalties_cpu[
                 empty_index] = self.repetition_penalties_cpu[last_req_index]
+            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
             self.min_tokens[empty_index] = self.min_tokens[last_req_index]
             self.stop_token_ids[empty_index] = self.stop_token_ids[
                 last_req_index]
@@ -381,6 +399,8 @@ def make_sampling_metadata(
             self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
         self.top_k[:self.num_reqs].copy_(
             self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
+        self.min_p[:self.num_reqs].copy_(
+            self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
         if not self.no_penalties:
             # Since syncing these tensors is expensive only copy them
             # if necessary i.e. if there are requests which require
@@ -421,6 +441,8 @@ def make_sampling_metadata(
             all_random=self.all_random,
             top_p=self.top_p[:self.num_reqs],
             top_k=self.top_k[:self.num_reqs],
+            min_p=self.min_p[:self.num_reqs],
+            no_min_p=self.no_min_p,
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
             generators=self.generators,
@@ -497,6 +519,10 @@ def no_top_p(self) -> bool:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0

+    @property
+    def no_min_p(self) -> bool:
+        return len(self.min_p_reqs) == 0
+
     @property
     def no_penalties(self) -> bool:
         return (len(self.presence_penalties_reqs) == 0
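
The input-batch changes above all follow the persistent-batch pattern already used for top_p/top_k: a CPU staging buffer plus a device tensor, a set of request ids that actually use the parameter, and a no_min_p property as the batch-level fast path. A simplified standalone sketch of that bookkeeping (names and sizes are illustrative; the real buffer is also pinned for faster host-to-device copies, and the device copy is shown only as a comment so the snippet runs without a GPU):

    import torch

    max_num_reqs = 8
    min_p_cpu_tensor = torch.zeros((max_num_reqs, ), dtype=torch.float32)
    min_p_cpu = min_p_cpu_tensor.numpy()  # numpy view sharing the same memory
    min_p_reqs = set()
    _SAMPLING_EPS = 1e-5

    def add_request(req_index: int, req_id: str, min_p: float) -> None:
        # Stage the per-request value on the CPU and only track the request
        # when min_p is effectively non-zero.
        min_p_cpu[req_index] = min_p
        if min_p > _SAMPLING_EPS:
            min_p_reqs.add(req_id)

    add_request(0, "req-a", 0.0)
    add_request(1, "req-b", 0.1)
    no_min_p = len(min_p_reqs) == 0  # False: one request uses min_p
    # At make_sampling_metadata time the live prefix would be copied to GPU:
    # min_p[:num_reqs].copy_(min_p_cpu_tensor[:num_reqs], non_blocking=True)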
