Merged
Support of AutoModel #192
Changes from 14 commits

Commits (17, all by quic-amitraj):
0b80dac  Docker-driven tests with latest SDKs (#180)
e83b29a  Added support for embedding models
be59287  Lint & Format
1284155  Added batch_size
3f95df7  Docstring added
74ffc16  Fix-1
2fb41ad  Comments Addressed-1
262f45e  Comments addressed-2
ba0258b  Lint and formatted
ba66c75  Comments addressed-3
38a4186  Fix-2
4401fd6  Comments addressed-4
206c81a  Minor fix-1
0f1f8bb  fix-major
6c9de4b  fix-minor-2
88e0fe6  fix-minor-3
157142a  Update ONNX_EXPORT_OPSET to 13
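This PR adds QEFFAutoModel, an AutoModel-style entry point for embedding models. The snippet below is a minimal sketch of the workflow the new test exercises (from_pretrained, export to ONNX, compile for AI 100, generate); the model name and num_cores value are taken from the test below and are illustrative, not required.

from transformers import AutoTokenizer

from QEfficient.transformers.models.modeling_auto import QEFFAutoModel

model_name = "BAAI/bge-small-en-v1.5"  # one of the models covered by the new test
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("My name is", return_tensors="pt")

qeff_model = QEFFAutoModel.from_pretrained(model_name)
qeff_model.export()               # export the transformed model to ONNX
qeff_model.compile(num_cores=14)  # compile the exported graph for the AI 100 device
embeddings = qeff_model.generate(inputs=inputs)  # run on device; runtime_ai100=False selects the PyTorch path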
@@ -0,0 +1,105 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import numpy as np
import onnxruntime as ort
import pytest
import torch
from transformers import AutoModel, AutoTokenizer

from QEfficient.transformers.models.modeling_auto import QEFFAutoModel
from QEfficient.utils import hf_download
from QEfficient.utils.constants import Constants

embed_test_models = [
    # model_name, architecture
    "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # MPNetForMaskedLM
    "BAAI/bge-reranker-v2-m3",  # XLMRobertaForSequenceClassification
    "BAAI/bge-small-en-v1.5",  # BertModel
]


def check_embed_pytorch_vs_ort_vs_ai100(
    model_name: str,
    seq_len: int = Constants.CTX_LEN,
    n_layer: int = 1,
):
    model_path = hf_download(
        repo_id=model_name,
        ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"],
    )
    # Prepare input
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("My name is", return_tensors="pt")

    # Right-pad input_ids and attention_mask to the fixed sequence length
    input_ids = torch.nn.functional.pad(inputs["input_ids"], (0, seq_len - inputs["input_ids"].size(1)), "constant", 0)
    attention_mask = torch.nn.functional.pad(
        inputs["attention_mask"], (0, seq_len - inputs["attention_mask"].size(1)), "constant", 0
    )
    inputs = dict(input_ids=input_ids, attention_mask=attention_mask)

    # Original PyTorch model
    pt_model = AutoModel.from_pretrained(
        model_path,
        num_hidden_layers=n_layer,
        attn_implementation="eager",
        trust_remote_code=True,
    )

    pt_outputs = pt_model(**inputs)
    pt_embeddings = pt_outputs[0][0].detach().numpy()

    # PyTorch transformed model
    qeff_model = QEFFAutoModel.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_hidden_layers=n_layer,
        attn_implementation="eager",
        trust_remote_code=True,
    )

    qeff_pt_outputs = qeff_model.generate(inputs=inputs, runtime_ai100=False)
    qeff_pt_embeddings = qeff_pt_outputs[0][0].detach().numpy()
    mad = np.mean(np.abs(pt_embeddings - qeff_pt_embeddings))
    print("MAD for PyTorch and transformed PyTorch qeff_model is ", mad)
    # The transformed model must match the original exactly
    assert mad <= 0, f"MAD is too high for PyTorch and transformed PyTorch: {mad}"

    onnx_model = qeff_model.export()
    ort_session = ort.InferenceSession(str(onnx_model))

    # Prepare the inputs for ONNX Runtime
    input_ids = np.array(input_ids)
    attention_mask = np.array(attention_mask)

    onnx_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
    # Run inference
    onnx_outputs = ort_session.run(None, onnx_inputs)

    # Compare transformed PyTorch and ONNX outputs
    pt_embeddings = pt_outputs[0][0].detach().numpy()
    onnx_embeddings = onnx_outputs[0]
    mad = np.mean(np.abs(pt_embeddings - onnx_embeddings))
    print("MAD for ONNX and PyTorch is ", mad)
    assert mad <= 10**-5, f"MAD is too high for ONNX and PyTorch: {mad}"

    qeff_model.compile(
        num_cores=14,
    )
    ai100_output = qeff_model.generate(inputs=inputs)

    # Compare ONNX and AI 100 outputs
    mad = np.mean(np.abs(ai100_output["output"] - onnx_outputs[0]))
    print("MAD for ONNX and AI 100 output is ", mad)
    assert mad <= 10**-3, f"MAD is too high for ONNX and AI 100: {mad}"


@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", embed_test_models)
def test_embed_model_pytorch_vs_onnx_vs_ai100(model_name):
    """
    Validate that the PyTorch, ONNX, and AI 100 runtime outputs match.
    """
    check_embed_pytorch_vs_ort_vs_ai100(model_name=model_name, seq_len=32, n_layer=1)
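All three comparisons above use the same metric, the mean of elementwise absolute differences (printed as "MAD"), at three tolerances: exact match (<= 0) between the original and transformed PyTorch models, 1e-5 for ONNX vs PyTorch, and 1e-3 for AI 100 vs ONNX. A self-contained sketch of the metric, with synthetic arrays standing in for real model outputs:

import numpy as np

def mad(a: np.ndarray, b: np.ndarray) -> float:
    # Mean of elementwise absolute differences, as computed in the test above
    return float(np.mean(np.abs(a - b)))

rng = np.random.default_rng(0)
ref = rng.normal(size=(32, 384)).astype(np.float32)  # e.g. seq_len x hidden_dim
close = ref + rng.normal(scale=1e-6, size=ref.shape).astype(np.float32)

assert mad(ref, ref) <= 0       # identical outputs give a MAD of exactly 0
assert mad(close, ref) <= 1e-5  # within the ONNX-vs-PyTorch tolerance

Since the test is decorated with @pytest.mark.on_qaic, it runs only when that marker is selected, e.g. with pytest -m on_qaic on a machine with AI 100 hardware.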