Commit ea89974

[Benchmark][Bug] Fix multiple bugs in bench and add args to spec_decode offline (vllm-project#20083)
Signed-off-by: Will Eaton <[email protected]>
1 parent 3ac9ff6 commit ea89974

File tree (4 files changed: +28 -11 lines)

benchmarks/benchmark_dataset.py
examples/offline_inference/spec_decode.py
vllm/benchmarks/datasets.py
vllm/benchmarks/serve.py

benchmarks/benchmark_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -349,8 +349,9 @@ def sample(
             # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
             # To avoid uncontrolled change of the prompt length,
             # the encoded sequence is truncated before being decode again.
+            total_input_len = prefix_len + int(input_lens[i])
             re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
-                : input_lens[i]
+                :total_input_len
             ]
             prompt = tokenizer.decode(re_encoded_sequence)
             total_input_len = len(re_encoded_sequence)
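The fix matters when random_prefix_len is non-zero: the re-encoded prompt used to be cut at input_lens[i], silently discarding the shared prefix tokens. Below is a minimal stand-alone sketch of the corrected truncation, assuming a Hugging Face tokenizer; the tokenizer name and the concrete lengths are only illustrative, not values from the benchmark.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer

prefix_len = 4    # tokens shared by every sampled request
input_len = 16    # per-request random portion
prompt = "benchmarking prompt text " * 10

# Truncate the re-encoded prompt to the *total* target length
# (prefix + input) instead of just input_len, so the shared prefix
# is no longer dropped from the decoded prompt.
total_input_len = prefix_len + int(input_len)
re_encoded = tokenizer.encode(prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded)

# Record the length that actually survived re-encoding; it can be
# shorter than the target if the source prompt was short.
total_input_len = len(re_encoded)
print(total_input_len, repr(prompt))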

examples/offline_inference/spec_decode.py

Lines changed: 13 additions & 7 deletions
@@ -39,16 +39,20 @@ def parse_args():
     parser.add_argument("--top-k", type=int, default=-1)
     parser.add_argument("--print-output", action="store_true")
     parser.add_argument("--output-len", type=int, default=256)
+    parser.add_argument("--model-dir", type=str, default=None)
+    parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--max-model-len", type=int, default=2048)
     return parser.parse_args()


 def main():
     args = parse_args()
     args.endpoint_type = "openai-chat"

-    model_dir = "meta-llama/Llama-3.1-8B-Instruct"
+    model_dir = args.model_dir
+    if args.model_dir is None:
+        model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
-    max_model_len = 2048

     prompts = get_samples(args, tokenizer)
     # add_special_tokens is False to avoid adding bos twice when using chat templates

@@ -57,24 +61,26 @@ def main():
     ]

     if args.method == "eagle" or args.method == "eagle3":
-        if args.method == "eagle":
+        eagle_dir = args.eagle_dir
+        if args.method == "eagle" and eagle_dir is None:
             eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-        elif args.method == "eagle3":
+
+        elif args.method == "eagle3" and eagle_dir is None:
             eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
         speculative_config = {
             "method": args.method,
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "draft_tensor_parallel_size": args.draft_tp,
-            "max_model_len": max_model_len,
+            "max_model_len": args.max_model_len,
         }
     elif args.method == "ngram":
         speculative_config = {
             "method": "ngram",
             "num_speculative_tokens": args.num_spec_tokens,
             "prompt_lookup_max": args.prompt_lookup_max,
             "prompt_lookup_min": args.prompt_lookup_min,
-            "max_model_len": max_model_len,
+            "max_model_len": args.max_model_len,
         }
     else:
         raise ValueError(f"unknown method: {args.method}")

@@ -86,7 +92,7 @@ def main():
         enable_chunked_prefill=args.enable_chunked_prefill,
         max_num_batched_tokens=args.max_num_batched_tokens,
         enforce_eager=args.enforce_eager,
-        max_model_len=max_model_len,
+        max_model_len=args.max_model_len,
         max_num_seqs=args.max_num_seqs,
         gpu_memory_utilization=0.8,
         speculative_config=speculative_config,
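With the new flags the example no longer hard-codes the target model, the draft model, or the context length. A possible invocation is sketched below; it assumes the script's existing --method and --output-len flags shown in the diff context, and leaving --model-dir or --eagle-dir unset falls back to the previous hard-coded defaults.

python examples/offline_inference/spec_decode.py \
    --method eagle3 \
    --model-dir meta-llama/Llama-3.1-8B-Instruct \
    --eagle-dir yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
    --max-model-len 4096 \
    --output-len 256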

vllm/benchmarks/datasets.py

Lines changed: 7 additions & 3 deletions
@@ -320,6 +320,8 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        random.seed(self.random_seed)
+        np.random.seed(self.random_seed)

     def sample(
         self,

@@ -376,10 +378,11 @@ def sample(
             # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
             # To avoid uncontrolled change of the prompt length,
             # the encoded sequence is truncated before being decode again.
+            total_input_len = prefix_len + int(input_lens[i])
             re_encoded_sequence = tokenizer.encode(
-                prompt, add_special_tokens=False)[:input_lens[i]]
+                prompt, add_special_tokens=False)[:total_input_len]
             prompt = tokenizer.decode(re_encoded_sequence)
-            total_input_len = prefix_len + int(input_lens[i])
+            total_input_len = len(re_encoded_sequence)
             requests.append(
                 SampleRequest(
                     prompt=prompt,

@@ -692,7 +695,8 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
                 dataset_path=args.dataset_path).
             sample(tokenizer=tokenizer, num_requests=args.num_prompts),
         "random":
-            lambda: RandomDataset(dataset_path=args.dataset_path).sample(
+            lambda: RandomDataset(random_seed=args.seed,
+                                  dataset_path=args.dataset_path).sample(
                 tokenizer=tokenizer,
                 num_requests=args.num_prompts,
                 prefix_len=args.random_prefix_len,
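Seeding both RNGs in the constructor, and threading the benchmark seed through to RandomDataset, makes repeated runs draw the same synthetic requests. The sketch below is a toy stand-in (not the real RandomDataset) that illustrates why seeding both the stdlib and NumPy generators gives that reproducibility.

import random
import numpy as np

class TinyRandomDataset:
    """Toy stand-in: seeds both RNGs the way the patched dataset does."""

    def __init__(self, random_seed: int = 0) -> None:
        self.random_seed = random_seed
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

    def sample(self, num_requests: int = 4, max_len: int = 32) -> list[list[int]]:
        # Lengths come from NumPy and token ids from the stdlib generator,
        # so both seeds matter for a deterministic run.
        lens = np.random.randint(1, max_len, size=num_requests)
        return [[random.randrange(50_000) for _ in range(n)] for n in lens]

# Same seed, same sampled requests across runs.
assert TinyRandomDataset(42).sample() == TinyRandomDataset(42).sample()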

vllm/benchmarks/serve.py

Lines changed: 6 additions & 0 deletions
@@ -631,6 +631,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="The label (prefix) of the benchmark results. If not specified, "
         "the endpoint type will be used as the label.",
     )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="vllm",
+        choices=list(ASYNC_REQUEST_FUNCS.keys()),
+    )
     parser.add_argument(
         "--base-url",
         type=str,
     )
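This adds an explicit --backend selector alongside --label and --base-url; valid values are the keys of ASYNC_REQUEST_FUNCS, with "vllm" as the default. Assuming this module backs the "vllm bench serve" entry point, a run could then name the backend explicitly (illustrative only; other required flags are omitted):

vllm bench serve --backend vllm --base-url http://localhost:8000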
