
Commit 75719e9

committed: added SpD unit tests
1 parent bce560e commit 75719e9

3 files changed, +164 -2 lines changed

QEfficient/compile/compile_helper.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def compile(
         ctx_len=ctx_len,
         path=specialization_json_path,
         full_batch_size=full_batch_size,
-        is_dlm=kwargs.get("is_dlm", None),
+        is_dlm=kwargs.get("is_dlm", False),
         num_speculative_tokens=kwargs.get("num_speculative_tokens", None),
     )

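A note on this change: is_dlm now falls back to False rather than None when the caller omits it, so the flag reaching the specialization logic is always a plain boolean. A minimal illustration of the difference (not part of the commit):

    # illustrative only: what downstream code sees when the caller omits is_dlm
    kwargs = {}
    kwargs.get("is_dlm", None)   # old default -> None
    kwargs.get("is_dlm", False)  # new default -> False, an explicit boolean
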
QEfficient/transformers/models/modeling_auto.py

Lines changed: 8 additions & 1 deletion
@@ -103,6 +103,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
 
         full_batch_size = kwargs.pop("full_batch_size", None)
 
+        num_speculative_tokens = kwargs.pop("num_speculative_tokens", None)
+        is_dlm = kwargs.pop("is_dlm", False)
+
         attn_implementation = kwargs.get("attn_implementation", None)
         if attn_implementation != "eager":
             logger.warning(f"Updating attn_implementation to be 'eager', got {attn_implementation}")

@@ -120,6 +123,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             model_card_name=model_card_name,
             full_batch_size=full_batch_size,
+            num_speculative_tokens=num_speculative_tokens,
+            is_dlm=is_dlm,
             **kwargs,
         )

@@ -192,7 +197,7 @@ def transform(self, **kwargs):
         assert (
             not isinstance(num_speculative_tokens, int)
         ) or not is_dlm, "number of speculative tokens are only to be specified for Target LM"
-        if num_speculative_tokens:
+        if num_speculative_tokens is not None:
             assert isinstance(num_speculative_tokens, int) and num_speculative_tokens > 0, (
                 "argument num_speculative_tokens" " should be of type integer and" " be positive if specified"
             )
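
A note on the transform() change above: with the old truthiness test, an explicitly passed num_speculative_tokens of 0 evaluated as falsy and skipped the validation block entirely, whereas the new check only treats an omitted value (None) as "not specified" and lets the assertion reject anything non-positive. Illustrative sketch, not part of the commit:

    num_speculative_tokens = 0
    if num_speculative_tokens:                 # old check: False, invalid 0 slips through
        pass
    if num_speculative_tokens is not None:     # new check: True, the assert below rejects 0
        assert isinstance(num_speculative_tokens, int) and num_speculative_tokens > 0
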
@@ -365,6 +370,8 @@ def export_and_compile(
             mxfp6=mxfp6,
             mxint8=mxint8,
             full_batch_size=full_batch_size,
+            num_speculative_tokens=getattr(self.model, "num_speculative_tokens", None),
+            is_dlm=getattr(self.model, "is_dlm", False),
         )
         return self.qpc_path

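Taken together, these changes thread the two speculative-decoding settings from from_pretrained, through transform(), and into the compile call. A minimal usage sketch based on the calls exercised by the new tests (the model name and values are the ones used there; everything else is assumed to follow the existing QEfficient flow):

    from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM

    # Target LM: request 5 speculative tokens; is_dlm keeps its False default
    tlm = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_speculative_tokens=5
    )

    # Draft LM: flag the model as a DLM; num_speculative_tokens must stay unset
    dlm = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", is_dlm=True
    )
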
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession
from transformers import AutoTokenizer
from typing import List
import numpy as np
import pytest

configs = [
    pytest.param(
        # device_group, num_speculative_tokens, prompt_len, ctx_len, prefill_bsz, full_batch_size, model_name, id
        [0], 5, 32, 128, 1, 8, "TinyLlama/TinyLlama-1.1B-Chat-v1.0", id="llama"
    ),
]

@pytest.mark.parametrize("device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name", configs)
def test_llama_tlm_logit_dims(
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    full_batch_size: int,
    model_name: str,
):
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = len(tokenizer)

    # export_and_compile tlm model
    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, num_speculative_tokens=num_speculative_tokens)
    qpc_path: str = qeff_model.export_and_compile(
        num_cores=16,
        device_group=device_group,
        batch_size=prefill_bsz,
        prompt_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6=True,
        mxint8=True,
        full_batch_size=full_batch_size
    )

    # init qaic session
    session = QAICInferenceSession(qpc_path, device_ids=device_group)
    # skip inputs/outputs buffers
    session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")]))
    session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")]))
    # prefill dummy inputs
    prefill_inputs = dict(
        input_ids=np.zeros((prefill_bsz, prompt_len), dtype=np.int64),
        position_ids=np.arange(prompt_len, dtype=np.int64).reshape(-1,1).repeat(prefill_bsz,1).transpose(),
        batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(prefill_bsz,1)
    )
    # decode dummy inputs
    decode_inputs = dict(
        input_ids=np.zeros((full_batch_size, num_speculative_tokens+1), dtype=np.int64),
        position_ids=np.full((full_batch_size, num_speculative_tokens+1), -1, dtype=np.int64),
        batch_index=np.arange(full_batch_size, dtype=np.int64).reshape(-1,1)
    )
    # create dummy logits
    prefill_logits = dict(logits=np.random.randn(prefill_bsz, prompt_len, vocab_size).astype(np.float32))
    decode_logits = dict(logits=np.random.randn(full_batch_size, num_speculative_tokens+1, vocab_size).astype(np.float32))
    # get prefill/decode logits
    session.set_buffers(prefill_logits)
    prefill_outputs = session.run(prefill_inputs)
    session.set_buffers(decode_logits)
    decode_outputs = session.run(decode_inputs)

    # assert expected logit dims
    assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape
    assert decode_logits["logits"].shape == decode_outputs["logits"].shape


@pytest.mark.parametrize("device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name", configs)
def test_llama_dlm_logit_dims(
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    full_batch_size: int,
    model_name: str,
):
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    vocab_size = len(tokenizer)

    # export_and_compile dlm model
    qeff_model = AutoModelForCausalLM.from_pretrained(model_name, is_dlm=True)
    qpc_path: str = qeff_model.export_and_compile(
        num_cores=16,
        device_group=device_group,
        batch_size=prefill_bsz,
        prompt_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6=True,
        mxint8=True,
        full_batch_size=full_batch_size
    )

    # init qaic session
    session = QAICInferenceSession(qpc_path, device_ids=device_group)
    # skip inputs/outputs buffers
    session.skip_buffers(set([x for x in session.input_names if x.startswith("past_")]))
    session.skip_buffers(set([x for x in session.output_names if x.endswith("_RetainedState")]))
    # prefill dummy inputs
    prefill_inputs = dict(
        input_ids=np.zeros((prefill_bsz, prompt_len), dtype=np.int64),
        position_ids=np.arange(prompt_len, dtype=np.int64).reshape(-1,1).repeat(prefill_bsz,1).transpose(),
        batch_index=np.arange(prefill_bsz, dtype=np.int64).reshape(-1,1)
    )
    # decode-1 dummy inputs
    decode1_inputs = dict(
        input_ids=np.zeros((full_batch_size, 1), dtype=np.int64),
        position_ids=np.full((full_batch_size, 1), -1, dtype=np.int64),
        batch_index=np.arange(full_batch_size, dtype=np.int64).reshape(-1,1)
    )
    # decode-2 dummy inputs
    decode2_inputs = dict(
        input_ids=np.zeros((full_batch_size, 2), dtype=np.int64),
        position_ids=np.full((full_batch_size, 2), -1, dtype=np.int64),
        batch_index=np.arange(full_batch_size, dtype=np.int64).reshape(-1,1)
    )
    # create dummy logits
    prefill_logits = dict(logits=np.random.randn(prefill_bsz, 1, vocab_size).astype(np.float32))
    decode_logits = dict(logits=np.random.randn(full_batch_size, 1, vocab_size).astype(np.float32))
    # get prefill/decode logits
    session.set_buffers(prefill_logits)
    prefill_outputs = session.run(prefill_inputs)
    session.set_buffers(decode_logits)
    decode1_outputs = session.run(decode1_inputs)
    decode2_outputs = session.run(decode2_inputs)

    # assert expected logit dims
    assert prefill_logits["logits"].shape == prefill_outputs["logits"].shape
    assert decode_logits["logits"].shape == decode1_outputs["logits"].shape
    assert decode_logits["logits"].shape == decode2_outputs["logits"].shape

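Both tests only validate output-buffer shapes: the TLM case expects decode logits of shape (full_batch_size, num_speculative_tokens + 1, vocab_size), while the DLM case compiles with is_dlm=True and runs decode steps of length 1 and 2 against (full_batch_size, 1, vocab_size) buffers. Assuming the file lands under the repository's test tree (its path is not shown in this view) and a Cloud AI 100 device is available, the two cases can be selected by name, for example:

    pytest -v -k "llama_tlm_logit_dims or llama_dlm_logit_dims"
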