129 changes: 70 additions & 59 deletions QEfficient/base/onnx_transforms.py
@@ -5,10 +5,14 @@
#
# ----------------------------------------------------------------------------

import os
import warnings
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple

import numpy as np
from onnx import ModelProto, external_data_helper, numpy_helper
from onnx import ModelProto, TensorProto, external_data_helper, numpy_helper


class OnnxTransform:
@@ -17,7 +21,7 @@ class OnnxTransform:
"""

def __init__(self):
raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
raise TypeError("Transform classes are not to be instantiated. Use the `apply` method directly.")

@classmethod
def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
@@ -32,70 +36,77 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
raise NotImplementedError("Use subclasses for ONNX transform")


class FP16ClipTransform(OnnxTransform):
"""
Clips the tensor values to be in FP16 range, but preserves -inf values.
"""

@classmethod
def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
"""
:param onnx_base_dir: Base directory to load tensors
"""
finfo = np.finfo(np.float16)
fp16_max = finfo.max
fp16_min = finfo.min
transformed = False

for tensor in external_data_helper._get_all_tensors(model):
nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)

# Restore -inf values
if neg_inf_mask.any():
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)

new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
tensor.CopyFrom(new_tensor)
transformed = True

return model, transformed


class SplitTensorsTransform(OnnxTransform):
"""
Split external tensors file
"""

class ClipAndSplitTransform(OnnxTransform):
@classmethod
def apply(
cls,
model: ModelProto,
*,
model_name: str,
model_name: str = "",
onnx_base_dir: Optional[str] = None,
file_chunk_size: int = 10 * 2**30, # 10 GiB
apply_clip: bool = True,
apply_split: bool = True,
file_chunk_size: int = 10 * 2**30,
size_threshold: int = 1024,
**kwargs,
) -> Tuple[ModelProto, bool]:
"""
:param model_name: Used for naming external files, e.g. {model_name}_0.onnx.data
:param onnx_base_dir: Base directory to load tensors (if not already loaded).
:param file_chunk_size: Chunk size to split external files into.
:param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
"""
file_num = 0
current_file_size = 0
transformed = False
if not apply_clip and not apply_split:
warnings.warn("Both apply_clip and apply_split are False. Skipping transformation.")
return model, False

external_data_helper.load_external_data_for_model(model, onnx_base_dir)
Contributor:

Even though combining the transforms saves time in this case, it also reduces the flexibility we have over multiple transforms. In the future, if we need to add more transforms, the condition would become more complex, and a new transform would need to load the tensors again. I have added a few changes as part of #538 for the memory cleanup and for reducing peak memory usage. Can you check whether the same concepts can be used here?

Author:

Yeah, it's kind of a tradeoff between time and flexibility; let me check that and get back to you.
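
For reference, the decoupled structure the reviewer is pointing at could look roughly like the sketch below: load the external data once, then run any number of per-tensor transform callables in a single pass, so adding a new transform does not force another load of the tensors. This is only an illustration of the idea, not the actual changes in #538; apply_tensor_transforms and per_tensor_transforms are hypothetical names, not part of the QEfficient API.

# Sketch only: separate per-tensor transforms share a single load of the external data.
from typing import Callable, List, Optional, Tuple

from onnx import ModelProto, TensorProto, external_data_helper


def apply_tensor_transforms(
    model: ModelProto,
    per_tensor_transforms: List[Callable[[TensorProto], bool]],
    onnx_base_dir: Optional[str] = None,
) -> Tuple[ModelProto, bool]:
    # Load external tensor data once, up front, for all transforms.
    external_data_helper.load_external_data_for_model(model, onnx_base_dir)
    transformed = False
    for tensor in external_data_helper._get_all_tensors(model):
        for transform in per_tensor_transforms:
            # Each callable mutates the tensor in place and reports whether it changed anything.
            transformed |= transform(tensor)
    return model, transformed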

for tensor in external_data_helper._get_all_tensors(model):
if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
transformed = True
current_file_size += tsize
if current_file_size > file_chunk_size:
file_num += 1
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed
tensors = external_data_helper._get_all_tensors(model)

TensorInfo = namedtuple("TensorInfo", ["tensor", "tsize"])
tensor_infos = [
TensorInfo(tensor, len(tensor.raw_data) if tensor.HasField("raw_data") else 0) for tensor in tensors
]

fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
file_num_tracker = {"num": 0, "size": 0}

def process_tensor(info: TensorInfo) -> bool:
tensor, tsize = info
transformed_clip = False
transformed_split = False

if apply_clip:
transformed_clip = cls._clip_tensor(tensor, onnx_base_dir, fp16_min, fp16_max)

if apply_split and tsize > size_threshold:
if file_num_tracker["size"] + tsize > file_chunk_size:
file_num_tracker["num"] += 1
file_num_tracker["size"] = tsize
else:
file_num_tracker["size"] += tsize

cls._split_tensor(tensor, model_name, file_num_tracker["num"])
transformed_split = True

if apply_clip and apply_split:
return transformed_clip and transformed_split
return transformed_clip or transformed_split

with ThreadPoolExecutor(max_workers=os.cpu_count() * 4) as executor:
transformed_flags = list(executor.map(process_tensor, tensor_infos))
return model, any(transformed_flags)

@staticmethod
def _clip_tensor(tensor, onnx_base_dir, fp16_min, fp16_max) -> bool:
if tensor.data_type != TensorProto.FLOAT:
return False

nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
if np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min):
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
if neg_inf_mask.any():
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
tensor.CopyFrom(new_tensor)
return True
return False

@staticmethod
def _split_tensor(tensor, model_name: str, file_num: int) -> None:
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
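
To make the new keyword arguments concrete, a call to the combined transform might look like the following. The file name "decoder.onnx" and model_name="decoder" are placeholders for illustration, not files that exist in the repository.

# Illustrative usage of the combined transform; paths and names are placeholders.
import onnx

from QEfficient.base.onnx_transforms import ClipAndSplitTransform

model = onnx.load("decoder.onnx", load_external_data=False)
model, transformed = ClipAndSplitTransform.apply(
    model,
    model_name="decoder",  # external files become decoder_0.onnx.data, decoder_1.onnx.data, ...
    onnx_base_dir=".",     # directory containing the original external data
    apply_clip=True,       # clip FP32 initializers into FP16 range
    apply_split=True,      # re-shard external data into ~10 GiB chunks
)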
6 changes: 4 additions & 2 deletions QEfficient/exporter/export_utils.py
@@ -17,7 +17,7 @@
import torch
from onnx import external_data_helper

from QEfficient.base.onnx_transforms import FP16ClipTransform
from QEfficient.base.onnx_transforms import ClipAndSplitTransform


def export_onnx(
@@ -218,7 +218,9 @@ def fix_onnx_fp16(
:str: Updated base name of exported ONNX model.
"""
model = onnx.load(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
model, fp16_fix = FP16ClipTransform.apply(model, onnx_base_dir=gen_models_path)
model, fp16_fix = ClipAndSplitTransform.apply(
model, model_name="", onnx_base_dir=gen_models_path, apply_split=False
)

if fp16_fix:
# Save FP16 model
4 changes: 2 additions & 2 deletions QEfficient/peft/auto.py
@@ -18,7 +18,7 @@
from transformers.generation.streamers import BaseStreamer

from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
from QEfficient.base.onnx_transforms import ClipAndSplitTransform, OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
@@ -58,7 +58,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
"""

_pytorch_transforms: List[PytorchTransform] = [CustomOpsTransform, KVCacheTransform, PeftModelInputsTransform]
_onnx_transforms: List[OnnxTransform] = [FP16ClipTransform, AdapterWeightsToInputsTransform, SplitTensorsTransform]
_onnx_transforms: List[OnnxTransform] = [ClipAndSplitTransform, AdapterWeightsToInputsTransform]
_hf_auto_class = AutoPeftModelForCausalLM

def __init__(self, model: nn.Module):
14 changes: 7 additions & 7 deletions QEfficient/transformers/models/modeling_auto.py
@@ -25,7 +25,7 @@

import QEfficient
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.base.onnx_transforms import ClipAndSplitTransform
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.generation.text_generation_inference import (
@@ -159,7 +159,7 @@ class QEFFAutoModel(QEFFTransformersBase):

_hf_auto_class = AutoModel
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [ClipAndSplitTransform]

def __init__(self, model: nn.Module, pooling=None, **kwargs):
super().__init__(model, **kwargs)
@@ -426,7 +426,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
KVCacheTransform,
KVCacheExternalModuleMapperTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [ClipAndSplitTransform]

def __init__(self, model: nn.modules, **kwargs):
super().__init__(model, **kwargs)
@@ -483,7 +483,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
VlmKVOffloadTransform,
SplitGateUpWeightsTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [ClipAndSplitTransform]

def __init__(self, model, **kwargs):
super().__init__(model, **kwargs)
@@ -898,7 +898,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
VlmNoKVOffloadTransform,
SplitGateUpWeightsTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [ClipAndSplitTransform]

def __init__(
self,
@@ -1330,7 +1330,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
SplitGateUpWeightsTransform,
KVCacheExternalModuleMapperTransform,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [ClipAndSplitTransform]

def __init__(
self,
@@ -1896,7 +1896,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin

_hf_auto_class = AutoModelForSpeechSeq2Seq
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
_onnx_transforms = [ClipAndSplitTransform]

def __init__(self, model: nn.Module, **kwargs):
model_class_name = model.__class__.__name__
11 changes: 7 additions & 4 deletions tests/base/test_onnx_transforms.py
@@ -8,7 +8,7 @@
import numpy as np
import onnx

from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.base.onnx_transforms import ClipAndSplitTransform


def test_fp16clip_transform():
@@ -32,7 +32,7 @@ def test_fp16clip_transform():
}
""")
onnx.checker.check_model(test_onnx, True, True, True)
transformed_onnx, transformed = FP16ClipTransform.apply(test_onnx)
transformed_onnx, transformed = ClipAndSplitTransform.apply(test_onnx, apply_split=False)
assert transformed
assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == 65504.0
assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[1]) == 2147483647
@@ -63,7 +63,9 @@ def test_fp16clip_transform_external(tmp_path):
np.array(-1e10, dtype="float32").tofile(tmp_path / external_tensors_file)
onnx.checker.check_model(onnx_path, True, True, True)

transformed_onnx, transformed = FP16ClipTransform.apply(test_onnx, onnx_base_dir=str(tmp_path))
transformed_onnx, transformed = ClipAndSplitTransform.apply(
test_onnx, onnx_base_dir=str(tmp_path), apply_split=False
)
assert transformed
assert onnx.numpy_helper.to_array(transformed_onnx.graph.initializer[0]) == -65504.0

@@ -92,12 +94,13 @@ def test_split_tensors_transform(tmp_path):
tensors.tofile(tmp_path / external_tensors_file)
onnx.checker.check_model(onnx_path, True, True, True)

trans_onnx, transformed = SplitTensorsTransform.apply(
trans_onnx, transformed = ClipAndSplitTransform.apply(
test_onnx,
model_name="test_split",
onnx_base_dir=str(tmp_path),
file_chunk_size=32 * 4,
size_threshold=16 * 4,
apply_clip=True,
)

tensor0_ext_data = onnx.external_data_helper.ExternalDataInfo(trans_onnx.graph.initializer[0])