# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

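# These tests sanity-check the logit dimensions of QPCs compiled for
# speculative decoding: a target model (TLM, compiled with
# num_speculative_tokens) and a draft model (DLM, compiled with is_dlm=True).
# The TLM/DLM expansion reflects how the flags are used below and is an
# informed reading rather than a definition lifted from the library docs.
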
from typing import List

import numpy as np
import pytest
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession

configs = [
    pytest.param(
        # device_group, num_speculative_tokens, prompt_len, ctx_len, prefill_bsz, full_batch_size, model_name
        [0], 5, 32, 128, 1, 8, "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        id="llama",
    ),
]


@pytest.mark.parametrize(
    "device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name",
    configs,
)
def test_llama_tlm_logit_dims(
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    full_batch_size: int,
    model_name: str,
):
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = len(tokenizer)

    # export and compile the TLM (target) model
    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, num_speculative_tokens=num_speculative_tokens)
    qpc_path: str = qeff_model.export_and_compile(
        num_cores=16,
        device_group=device_group,
        batch_size=prefill_bsz,
        prompt_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6=True,
        mxint8=True,
        full_batch_size=full_batch_size,
    )
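    # The compiled QPC is expected to expose two specializations: a prefill
    # graph of shape (prefill_bsz, prompt_len) and a decode graph of shape
    # (full_batch_size, num_speculative_tokens + 1), matching the dummy
    # inputs built below. This is inferred from the test inputs, not from
    # compiler documentation.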

    # init qaic session
    session = QAICInferenceSession(qpc_path, device_ids=device_group)
    # skip inputs/outputs buffers
    session.skip_buffers(set(x for x in session.input_names if x.startswith("past_")))
    session.skip_buffers(set(x for x in session.output_names if x.endswith("_RetainedState")))
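    # Skipping the past_* inputs and *_RetainedState outputs leaves the KV
    # cache resident on the device, so the host only exchanges input_ids,
    # position_ids, batch_index, and logits with the session.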
    # prefill dummy inputs
    prefill_inputs = dict(
        input_ids=np.zeros((prefill_bsz, prompt_len), dtype=np.int64),
        position_ids=np.arange(prompt_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(),
        batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1),
    )
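    # Concrete illustration with hypothetical small values (not the ones
    # parametrized above): for prompt_len=4 and prefill_bsz=2 the
    # position_ids expression evaluates to
    #   [[0, 1, 2, 3],
    #    [0, 1, 2, 3]]
    # i.e. every prompt in the prefill batch counts positions from zero.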
    # decode dummy inputs
    decode_inputs = dict(
        input_ids=np.zeros((full_batch_size, num_speculative_tokens + 1), dtype=np.int64),
        position_ids=np.full((full_batch_size, num_speculative_tokens + 1), -1, dtype=np.int64),
        batch_index=np.arange(full_batch_size, dtype=np.int64).reshape(-1, 1),
    )
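    # position_ids of -1 appear to mark masked/not-yet-valid decode slots;
    # the decode inputs only need legal shapes here, not meaningful token
    # positions, since the test checks output dimensions only.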
    # create dummy logits
    prefill_logits = dict(logits=np.random.randn(prefill_bsz, prompt_len, vocab_size).astype(np.float32))
    decode_logits = dict(logits=np.random.randn(full_batch_size, num_speculative_tokens + 1, vocab_size).astype(np.float32))
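    # set_buffers binds these host arrays as the session's logits output, so
    # run() writes into them; the shape equality asserted at the end of the
    # test is therefore a check that the QPC emits logits of exactly these
    # dimensions.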
    # get prefill/decode logits
    session.set_buffers(prefill_logits)
    prefill_outputs = session.run(prefill_inputs)
    session.set_buffers(decode_logits)
    decode_outputs = session.run(decode_inputs)

    # assert expected logit dims
    assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape
    assert decode_logits["logits"].shape == decode_outputs["logits"].shape
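    # Spelled out, the dims under test are (prefill_bsz, prompt_len,
    # vocab_size) for prefill and (full_batch_size, num_speculative_tokens + 1,
    # vocab_size) for decode, i.e. the TLM returns logits for every
    # speculative slot in a single decode step.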


@pytest.mark.parametrize(
    "device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name",
    configs,
)
def test_llama_dlm_logit_dims(
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    full_batch_size: int,
    model_name: str,
):
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = len(tokenizer)

    # export and compile the DLM (draft) model
    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, is_dlm=True)
    qpc_path: str = qeff_model.export_and_compile(
        num_cores=16,
        device_group=device_group,
        batch_size=prefill_bsz,
        prompt_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6=True,
        mxint8=True,
        full_batch_size=full_batch_size,
    )
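    # is_dlm=True presumably compiles the draft-model variant; judging by the
    # decode-2 inputs below, it adds a two-token decode specialization on top
    # of the usual single-token decode path.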

    # init qaic session
    session = QAICInferenceSession(qpc_path, device_ids=device_group)
    # skip inputs/outputs buffers
    session.skip_buffers(set(x for x in session.input_names if x.startswith("past_")))
    session.skip_buffers(set(x for x in session.output_names if x.endswith("_RetainedState")))
    # prefill dummy inputs
    prefill_inputs = dict(
        input_ids=np.zeros((prefill_bsz, prompt_len), dtype=np.int64),
        position_ids=np.arange(prompt_len, dtype=np.int64).reshape(-1, 1).repeat(prefill_bsz, 1).transpose(),
        batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1, 1),
    )
    # decode-1 dummy inputs (single-token decode)
    decode1_inputs = dict(
        input_ids=np.zeros((full_batch_size, 1), dtype=np.int64),
        position_ids=np.full((full_batch_size, 1), -1, dtype=np.int64),
        batch_index=np.arange(full_batch_size, dtype=np.int64).reshape(-1, 1),
    )
    # decode-2 dummy inputs (two-token decode)
    decode2_inputs = dict(
        input_ids=np.zeros((full_batch_size, 2), dtype=np.int64),
        position_ids=np.full((full_batch_size, 2), -1, dtype=np.int64),
        batch_index=np.arange(full_batch_size, dtype=np.int64).reshape(-1, 1),
    )
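    # The two-token decode feeds a pair of tokens per sequence, yet the
    # logits buffer bound below keeps a sequence dim of 1, so only last-token
    # logits are expected back from this graph. Reading the pair as "draft
    # token plus the target's correction" would be an assumption; the shapes
    # themselves come straight from this test.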
    # create dummy logits
    prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32))
    decode_logits = dict(logits=np.random.randn(full_batch_size, 1, vocab_size).astype(np.float32))
    # get prefill/decode logits
    session.set_buffers(prefill_logits)
    prefill_outputs = session.run(prefill_inputs)
    session.set_buffers(decode_logits)
    decode1_outputs = session.run(decode1_inputs)
    decode2_outputs = session.run(decode2_inputs)

    # assert expected logit dims
    assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape
    assert decode_logits["logits"].shape == decode1_outputs["logits"].shape
    assert decode_logits["logits"].shape == decode2_outputs["logits"].shape