Commit f59a988

Authored by quic-agokhale, eplatero97, and ochougul
[eplatero] Add support for exporting and compiling models for SpD (#119)
* rebasing with main. previous local gen_spd_models was broken since it was not picking up latest changes from main. as such, I found common ancestor, picked up latest changes from main, and made new commit to contain all unique changes for pr 119 Signed-off-by: eplatero <[email protected]>
* add decode_seq_len to non-continuous batching case Signed-off-by: eplatero <[email protected]>
* mirror test_causal_lm_models.py from main Signed-off-by: eplatero <[email protected]>
* add more to the explanation of the model changes Signed-off-by: eplatero <[email protected]>
* lint fixing Signed-off-by: eplatero <[email protected]>
* alphabetical order imports on pytorch_transforms.py Signed-off-by: eplatero <[email protected]>
* add init to spd directory Signed-off-by: eplatero <[email protected]>
* replace modifying seq_len by letting user define proper config Signed-off-by: eplatero <[email protected]>
* resolving 1st round comments from Onkar and made fix on gather implementation Signed-off-by: eplatero <[email protected]>
* removing old unit tests Signed-off-by: eplatero <[email protected]>
* Added way to make num_logits_to_keep dynamic in ONNX and removed need to regenerate ONNX for different values of num_logits_to_keep (only qpc is recompiled); ran formatter; reorganized pytorch transforms Signed-off-by: Onkar Chougule <[email protected]>
* changed interface to be similar to CB Signed-off-by: Onkar Chougule <[email protected]>
* made unit tests work with array approach Signed-off-by: eplatero <[email protected]>
* for TLM, made specialization return 1 logit for prefill and for decode Signed-off-by: eplatero <[email protected]>
* moved from to method because this flag only has implications for compile stage, not export Signed-off-by: eplatero <[email protected]>
* fixing qpc directory naming to be backwards compatible Signed-off-by: eplatero <[email protected]>
* updating docstrings and documentation Signed-off-by: eplatero <[email protected]>
* revert changes to CLI exportation of onnx and specialization to reflect state in main branch Signed-off-by: eplatero <[email protected]>
* fixed specializations creation and ran formatter Signed-off-by: Onkar Chougule <[email protected]>
* add pytorch-level unit test Signed-off-by: eplatero <[email protected]>
* uncommented non-llama pytorch-level unit test Signed-off-by: eplatero <[email protected]>
* modified pytorch level unit test and added hf vs ort vs qaic unit test Signed-off-by: eplatero <[email protected]>
* change llama test model from jackfram to tinyllama to match other tests Signed-off-by: eplatero <[email protected]>
* fix failing tlm_dlm tests by passing is_tlm correctly in modeling_auto Signed-off-by: eplatero <[email protected]>
* rm dlm specialization Signed-off-by: eplatero <[email protected]>
* updated quick_docs Signed-off-by: eplatero <[email protected]>
* rm tlm dims test since that's already tested and generalize common code in pytorch_transforms Signed-off-by: eplatero <[email protected]>
* rm flag from non-test definition Signed-off-by: eplatero <[email protected]>
* rm unnecessary function that is not used Signed-off-by: eplatero <[email protected]>
* ran formatter and linter Signed-off-by: Onkar Chougule <[email protected]>

---------

Signed-off-by: eplatero <[email protected]>
Signed-off-by: Onkar Chougule <[email protected]>
Co-authored-by: eplatero <[email protected]>
Co-authored-by: Onkar Chougule <[email protected]>
1 parent ad1b1cf commit f59a988
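For orientation, here is a hedged sketch of how the new knobs might be driven end to end. Only `num_speculative_tokens` on `_compile` and `is_tlm` on the text-generation entry points appear in the diffs below; the model name, the placement of `is_tlm` on `from_pretrained`, and the "k + 1 kept logits" count are assumptions drawn from the commit message, not from this diff.

```python
# Hypothetical SpD target-language-model (TLM) flow. QEFFAutoModelForCausalLM,
# from_pretrained, compile, and generate are the library's existing entry points;
# the SpD-specific keyword placement shown here is an assumption.
from transformers import AutoTokenizer
from QEfficient import QEFFAutoModelForCausalLM

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative model choice
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = QEFFAutoModelForCausalLM.from_pretrained(model_name, is_tlm=True)  # assumed flag placement
qpc_path = model.compile(
    num_cores=16,
    num_speculative_tokens=3,  # assumed: decode specialization keeps 3 + 1 logits per step
)
model.generate(tokenizer=tokenizer, prompts=["My name is"])
```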

File tree: 13 files changed, +477 / -64 lines


QEfficient/base/modeling_qeff.py
Lines changed: 5 additions & 0 deletions

@@ -201,6 +201,7 @@ def _compile(
         specializations: Optional[List[Dict[str, int]]] = None,
         custom_io: Optional[Dict[str, str]] = None,
         mdp_ts_num_devices: int = 1,
+        num_speculative_tokens: Optional[int] = None,
         **compiler_options,
     ) -> str:
         """
@@ -212,6 +213,7 @@ def _compile(
         :specializations (list): List of specializations to compile for
         :custom_io (dict): Custom IO to specify the input and outputs in different formats than default
         :mdp_ts_num_devices (int): Number of devices to partition to use Multi-Device Partitioning with tensor-slicing.
+        :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
         :compiler_options: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
             - aic_num_cores=16 -> -aic-num-cores=16
             - convert_to_fp16=True -> -convert-to-fp16
@@ -244,6 +246,9 @@ def _compile(
         if mdp_ts_num_devices > 1:
             compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
 
+        if num_speculative_tokens:
+            compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
+
         # Check if already compiled
         compile_hash = compile_hash.hexdigest()[:16]
         qpc_path = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
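The extra hash update means a different `num_speculative_tokens` value lands in a different QPC directory while the exported ONNX is reused, which matches the commit message ("only qpc is recompiled"). Below is a standalone sketch of that idea, substituting `hashlib`/`json` for the repository's `to_hashable` helper, whose exact serialization is not shown in this diff.

```python
import hashlib
import json


def qpc_suffix(num_speculative_tokens=None, **other_options) -> str:
    """Mimic the 16-character compile-hash suffix appended to the QPC directory name."""
    compile_hash = hashlib.sha256()
    # Stand-in for to_hashable(); assumes a deterministic serialization of the options.
    compile_hash.update(json.dumps(other_options, sort_keys=True).encode())
    if num_speculative_tokens:  # mirrors the `if num_speculative_tokens:` guard above
        compile_hash.update(
            json.dumps({"num_speculative_tokens": num_speculative_tokens}, sort_keys=True).encode()
        )
    return compile_hash.hexdigest()[:16]


# Different speculative-token counts produce distinct QPC directories,
# while the exported ONNX stays untouched.
print(qpc_suffix(mdp_ts_num_devices=1, num_speculative_tokens=2))
print(qpc_suffix(mdp_ts_num_devices=1, num_speculative_tokens=4))
```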

QEfficient/generation/text_generation_inference.py
Lines changed: 51 additions & 10 deletions

@@ -274,6 +274,7 @@ def cloud_ai_100_exec_kv(
     write_io_dir: Optional[str] = None,
     automation=False,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
+    is_tlm: bool = False,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -319,6 +320,7 @@ def cloud_ai_100_exec_kv(
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
+        is_tlm=is_tlm,
     )
     if full_batch_size is None:
         exec_info = [
@@ -355,16 +357,19 @@ def __init__(
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
+        is_tlm: Optional[int] = None,
     ) -> None:
         self._ctx_len = ctx_len
         self._write_io_dir = write_io_dir
+        self.is_tlm = is_tlm
 
         # Load QPC
         self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)
 
         # Fetch the variables from the QPC
         self._vocab_size = self._fetch_vocab_size()  # Fetch Vocab size
         self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
+        self._decode_seq_len = self._fetch_decode_seq_len()
         self.full_batch_size = (
             full_batch_size if full_batch_size else self._fetch_full_batch_size()
         )  # Check and fetch full batch size if CB is enabled
@@ -441,6 +446,22 @@ def _fetch_batch_size_prefill_seq_len(
         batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims
         return batch_size, prefill_seq_len
 
+    def _fetch_decode_seq_len(
+        self,
+    ):
+        """
+        Fetches the decode sequence length from the session's bindings or allowed shapes.
+
+        Returns:
+            decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes.
+        """
+        decode_seq_len = None
+        if self._session.allowed_shapes:
+            decode_seq_len = min(
+                [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes]
+            )
+        return decode_seq_len
+
     def _fetch_vocab_size(
         self,
     ):
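`_fetch_decode_seq_len` relies on the compiled QPC exposing at least two specializations on `input_ids` (a long prefill shape and a short decode shape) and takes the minimum sequence length. Here is a toy sketch with mocked allowed shapes; the `[binding_index][1][1]` tuple layout is inferred from the indexing above and is an assumption about how the session reports shapes.

```python
# Toy stand-in for QAICInferenceSession.allowed_shapes: one entry per
# specialization, each a list of (dtype, dims) pairs per binding (assumed layout).
binding_index_map = {"input_ids": 0}

allowed_shapes = [
    [("int64", (1, 32))],  # prefill specialization: seq_len = 32
    [("int64", (1, 4))],   # decode specialization: seq_len = num_speculative_tokens + 1 (assumed)
]

decode_seq_len = min(
    shape[binding_index_map["input_ids"]][1][1] for shape in allowed_shapes
)
print(decode_seq_len)  # 4 -> the shorter (decode) specialization wins
```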
@@ -485,9 +506,19 @@ def prepare_decode_inputs(self):
         Returns:
             dict: The decode inputs.
         """
+        batch_size = self.full_batch_size if self.full_batch_size is not None else self.batch_size
         decode_inputs = {}
-        decode_inputs["input_ids"] = self.decode_input_ids
-        decode_inputs["position_ids"] = self.decode_pos_ids
+        if self.is_tlm:
+            position_ids = np.full((batch_size, self._decode_seq_len), -1, dtype=np.int64)
+            position_ids[:, -1] = self.decode_pos_ids.flatten()
+            input_ids = np.zeros((batch_size, self._decode_seq_len), dtype=np.int64)
+            input_ids[:, -1] = self.decode_input_ids.flatten()
+            decode_inputs["input_ids"] = input_ids
+            decode_inputs["position_ids"] = position_ids
+            decode_inputs["num_logits_to_keep"] = np.zeros((self._decode_seq_len, 1))
+        else:
+            decode_inputs["input_ids"] = self.decode_input_ids
+            decode_inputs["position_ids"] = self.decode_pos_ids
         if self.batch_index is not None:
             decode_inputs["batch_index"] = self.batch_index
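For a TLM, decode feeds `decode_seq_len` slots per sequence but only the last slot carries a real token: the padded slots get `-1` position ids and zero input ids, and the `num_logits_to_keep` placeholder's length (rather than its values) appears to drive the dynamic logit count added to the ONNX graph, per the commit message. A standalone numpy sketch with assumed sizes:

```python
import numpy as np

batch_size, decode_seq_len = 2, 4              # e.g. num_speculative_tokens = 3 (assumed)
next_token_ids = np.array([[101], [202]])      # last accepted token per sequence (made-up ids)
next_positions = np.array([[7], [11]])         # its position id per sequence (made-up)

# Same layout as the is_tlm branch above: only the last column is "real".
position_ids = np.full((batch_size, decode_seq_len), -1, dtype=np.int64)
position_ids[:, -1] = next_positions.flatten()
input_ids = np.zeros((batch_size, decode_seq_len), dtype=np.int64)
input_ids[:, -1] = next_token_ids.flatten()

decode_inputs = {
    "input_ids": input_ids,
    "position_ids": position_ids,
    # one row per kept logit; the array length encodes how many logits to keep (assumed)
    "num_logits_to_keep": np.zeros((decode_seq_len, 1)),
}
print(decode_inputs["input_ids"])
# [[  0   0   0 101]
#  [  0   0   0 202]]
```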

@@ -628,6 +659,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
 
         if decode_batch_id is not None:
             inputs["batch_index"] = decode_batch_id
+        if self.is_tlm:
+            inputs["num_logits_to_keep"] = np.zeros((1, 1))
 
         if self._prompt_to_lora_id_mapping_prefill:
             if self.full_batch_size:
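Prefill only needs the final logit to seed decode, so its placeholder is a single row, whereas decode keeps one row per speculative slot. The `decode_seq_len = num_speculative_tokens + 1` relation below is the usual SpD convention and an assumption here; it is not spelled out in this diff.

```python
import numpy as np

num_speculative_tokens = 3  # assumed compile-time setting

prefill_num_logits_to_keep = np.zeros((1, 1))                            # keep only the last prefill logit
decode_num_logits_to_keep = np.zeros((num_speculative_tokens + 1, 1))    # keep one logit per verified slot

print(prefill_num_logits_to_keep.shape, decode_num_logits_to_keep.shape)  # (1, 1) (4, 1)
```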
@@ -668,7 +701,9 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
 
         # Set logits placeholder for decode
-        logits_out_placeholder = np.zeros((self.full_batch_size, 1, self._vocab_size), dtype=np.float32)
+        logits_out_placeholder = np.zeros(
+            (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
+        )
         self._session.set_buffers({"logits": logits_out_placeholder})
         # Generate flag for tracking progress for each batch ID
         current_decode_ongoing = np.full((self.full_batch_size, 1), True)
@@ -694,7 +729,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
 
             for decode_batch_id in range(self.full_batch_size):
                 if (
-                    next_token_id[decode_batch_id] == self.tokenizer.eos_token_id
+                    next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id
                     or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id]
                 ):
                     if prompt_queue:
@@ -724,10 +759,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                         current_decode_ongoing[decode_batch_id] = False
                 else:
                     # If the generated sequence is valid and within generation len prepare for next decode
-                    decode_inputs["input_ids"][decode_batch_id] = next_token_id[decode_batch_id]
-                    decode_inputs["position_ids"][decode_batch_id] += 1
+                    decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1]
+                    decode_inputs["position_ids"][decode_batch_id, -1] += 1
                     self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
-                        next_token_id[decode_batch_id]
+                        next_token_id[decode_batch_id, -1]
                     )
 
                     generated_id_current_index[decode_batch_id] += 1
@@ -747,6 +782,11 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
         Returns:
             num_token (int): The number of tokens processed in the decoding process.
         """
+        if self.is_tlm:
+            logits_out_placeholder = np.zeros(
+                (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
+            )
+            self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
         for num_token in range(1, generation_len):
@@ -760,8 +800,8 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
 
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = outputs["logits"].argmax(2)
-            decode_inputs["position_ids"] += 1
-            self.generated_ids[:, num_token] = decode_inputs["input_ids"].squeeze(1)
+            decode_inputs["position_ids"][:, -1] += 1
+            self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
 
             if finished_sequences.all():
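With `is_tlm`, the logits buffer becomes `(batch_size, decode_seq_len, vocab_size)` instead of `(batch_size, 1, vocab_size)`, so the greedy step takes an argmax over the vocab axis and only the last column is recorded as the newly generated token. A small numpy sketch with made-up sizes:

```python
import numpy as np

batch_size, decode_seq_len, vocab_size = 2, 4, 11  # toy sizes
rng = np.random.default_rng(0)
logits = rng.standard_normal((batch_size, decode_seq_len, vocab_size)).astype(np.float32)

# Same reduction as run_decode: one token id per (sequence, kept-logit) slot...
next_input_ids = logits.argmax(2)        # shape (batch_size, decode_seq_len)
# ...but only the last column is written into generated_ids for this step.
newly_generated = next_input_ids[:, -1]  # shape (batch_size,)
print(next_input_ids.shape, newly_generated)
```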
@@ -811,9 +851,10 @@ def __init__(
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
+        is_tlm: bool = False,
     ) -> None:
         self._qaic_model = QEffTextGenerationBase(
-            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir
+            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
         )
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer

QEfficient/peft/auto.py
Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@
 from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
 from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
-from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
+from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform
 from QEfficient.utils import constants
 from QEfficient.utils._utils import get_padding_shape_from_config
 from QEfficient.utils.cache import to_hashable
