@@ -39,16 +39,20 @@ def parse_args():
39
39
parser .add_argument ("--top-k" , type = int , default = - 1 )
40
40
parser .add_argument ("--print-output" , action = "store_true" )
41
41
parser .add_argument ("--output-len" , type = int , default = 256 )
42
+ parser .add_argument ("--model-dir" , type = str , default = None )
43
+ parser .add_argument ("--eagle-dir" , type = str , default = None )
44
+ parser .add_argument ("--max-model-len" , type = int , default = 2048 )
42
45
return parser .parse_args ()
43
46
44
47
45
48
def main ():
46
49
args = parse_args ()
47
50
args .endpoint_type = "openai-chat"
48
51
49
- model_dir = "meta-llama/Llama-3.1-8B-Instruct"
52
+ model_dir = args .model_dir
53
+ if args .model_dir is None :
54
+ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
50
55
tokenizer = AutoTokenizer .from_pretrained (model_dir )
51
- max_model_len = 2048
52
56
53
57
prompts = get_samples (args , tokenizer )
54
58
# add_special_tokens is False to avoid adding bos twice when using chat templates
@@ -57,24 +61,26 @@ def main():
57
61
]
58
62
59
63
if args .method == "eagle" or args .method == "eagle3" :
60
- if args .method == "eagle" :
64
+ eagle_dir = args .eagle_dir
65
+ if args .method == "eagle" and eagle_dir is None :
61
66
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
62
- elif args .method == "eagle3" :
67
+
68
+ elif args .method == "eagle3" and eagle_dir is None :
63
69
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
64
70
speculative_config = {
65
71
"method" : args .method ,
66
72
"model" : eagle_dir ,
67
73
"num_speculative_tokens" : args .num_spec_tokens ,
68
74
"draft_tensor_parallel_size" : args .draft_tp ,
69
- "max_model_len" : max_model_len ,
75
+ "max_model_len" : args . max_model_len ,
70
76
}
71
77
elif args .method == "ngram" :
72
78
speculative_config = {
73
79
"method" : "ngram" ,
74
80
"num_speculative_tokens" : args .num_spec_tokens ,
75
81
"prompt_lookup_max" : args .prompt_lookup_max ,
76
82
"prompt_lookup_min" : args .prompt_lookup_min ,
77
- "max_model_len" : max_model_len ,
83
+ "max_model_len" : args . max_model_len ,
78
84
}
79
85
else :
80
86
raise ValueError (f"unknown method: { args .method } " )
@@ -86,7 +92,7 @@ def main():
86
92
enable_chunked_prefill = args .enable_chunked_prefill ,
87
93
max_num_batched_tokens = args .max_num_batched_tokens ,
88
94
enforce_eager = args .enforce_eager ,
89
- max_model_len = max_model_len ,
95
+ max_model_len = args . max_model_len ,
90
96
max_num_seqs = args .max_num_seqs ,
91
97
gpu_memory_utilization = 0.8 ,
92
98
speculative_config = speculative_config ,
0 commit comments