Optimized ONNX Transform via Class Merging and Thread Pooling #546
base: main
@@ -5,10 +5,14 @@
#
# ----------------------------------------------------------------------------

import os
import warnings
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple

import numpy as np
from onnx import ModelProto, external_data_helper, numpy_helper
from onnx import ModelProto, TensorProto, external_data_helper, numpy_helper


class OnnxTransform:

@@ -17,7 +21,7 @@ class OnnxTransform:
    """

    def __init__(self):
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
        raise TypeError("Transform classes are not to be instantiated. Use the `apply` method directly.")

    @classmethod
    def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
@@ -32,70 +36,77 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
        raise NotImplementedError("Use subclasses for ONNX transform")


class FP16ClipTransform(OnnxTransform):
    """
    Clips the tensor values to be in FP16 range, but preserves -inf values.
    """

    @classmethod
    def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
        """
        :param onnx_base_dir: Base directory to load tensors
        """
        finfo = np.finfo(np.float16)
        fp16_max = finfo.max
        fp16_min = finfo.min
        transformed = False

        for tensor in external_data_helper._get_all_tensors(model):
            nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
            if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
                neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
                clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)

                # Restore -inf values
                if neg_inf_mask.any():
                    clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)

                new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
                tensor.CopyFrom(new_tensor)
                transformed = True

        return model, transformed


class SplitTensorsTransform(OnnxTransform):
    """
    Split external tensors file
    """

class ClipAndSplitTransform(OnnxTransform):
    @classmethod
    def apply(
        cls,
        model: ModelProto,
        *,
        model_name: str,
        model_name: str = "",
        onnx_base_dir: Optional[str] = None,
        file_chunk_size: int = 10 * 2**30,  # 10 GiB
        apply_clip: bool = True,
        apply_split: bool = True,
        file_chunk_size: int = 10 * 2**30,
        size_threshold: int = 1024,
        **kwargs,
    ) -> Tuple[ModelProto, bool]:
        """
        :param model_name: Used for naming external files, i.e. {model_name}_0.onnx.data
        :param onnx_base_dir: Base directory to load tensors (if not already loaded).
        :param file_chunk_size: Chunk size to split external files into.
        :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
        """
        file_num = 0
        current_file_size = 0
        transformed = False
        if not apply_clip and not apply_split:
            warnings.warn("Both apply_clip and apply_split are False. Skipping transformation.")
            return model, False

        external_data_helper.load_external_data_for_model(model, onnx_base_dir)
Review comment: Even though combining the transforms saves time in this case, it also reduces the flexibility we have over multiple transforms. In the future, if we need to add more transforms, the condition would become more complex, and a new transform would need to load the tensors again. I have added a few changes as part of #538 for memory cleanup and reducing peak memory usage. Can you check whether the same concepts can be used here?

Reply: Yeah, it's kind of a tradeoff between time and flexibility; let me check that and get back to you.
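For reference, a rough sketch of the direction the comment above suggests, keeping the transforms as separate classes while avoiding a second disk read: `load_external_data_for_model` materialises every external tensor into `raw_data` in place, so a single up-front load can be shared by later transforms. This is only an illustration of the comment, not code from this PR or from #538:

```python
# Illustrative sketch only (not part of this PR): load external data once and let
# each transform detect that tensors are already materialised before reloading.
from onnx import ModelProto, external_data_helper


def ensure_external_data_loaded(model: ModelProto, onnx_base_dir: str) -> None:
    tensors = external_data_helper._get_all_tensors(model)
    # uses_external_data() becomes False once raw_data has been populated, so a
    # second call after an earlier transform is a no-op instead of another disk read.
    if any(external_data_helper.uses_external_data(t) for t in tensors):
        external_data_helper.load_external_data_for_model(model, onnx_base_dir)
```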
        for tensor in external_data_helper._get_all_tensors(model):
            if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
                transformed = True
                current_file_size += tsize
                if current_file_size > file_chunk_size:
                    file_num += 1
                    current_file_size = tsize
                external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
        return model, transformed
        tensors = external_data_helper._get_all_tensors(model)

        TensorInfo = namedtuple("TensorInfo", ["tensor", "tsize"])
        tensor_infos = [
            TensorInfo(tensor, len(tensor.raw_data) if tensor.HasField("raw_data") else 0) for tensor in tensors
        ]
        fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
        file_num_tracker = {"num": 0, "size": 0}

        def process_tensor(info: TensorInfo) -> bool:
            tensor, tsize = info
            transformed_clip = False
            transformed_split = False

            if apply_clip:
                transformed_clip = cls._clip_tensor(tensor, onnx_base_dir, fp16_min, fp16_max)

            if apply_split and tsize > size_threshold:
                if file_num_tracker["size"] + tsize > file_chunk_size:
                    file_num_tracker["num"] += 1
                    file_num_tracker["size"] = tsize
                else:
                    file_num_tracker["size"] += tsize

                cls._split_tensor(tensor, model_name, file_num_tracker["num"])
                transformed_split = True

            if apply_clip and apply_split:
                return transformed_clip and transformed_split
            return transformed_clip or transformed_split

        with ThreadPoolExecutor(max_workers=os.cpu_count() * 4) as executor:
            transformed_flags = list(executor.map(process_tensor, tensor_infos))
        return model, any(transformed_flags)
    @staticmethod
    def _clip_tensor(tensor, onnx_base_dir, fp16_min, fp16_max) -> bool:
        if tensor.data_type != TensorProto.FLOAT:
            return False

        nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
        if np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min):
            neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
            clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
            if neg_inf_mask.any():
                clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
            new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
            tensor.CopyFrom(new_tensor)
            return True
        return False

    @staticmethod
    def _split_tensor(tensor, model_name: str, file_num: int) -> None:
        external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
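For context, a minimal usage sketch of the merged transform proposed in this PR. The import path and file names are assumptions for illustration only (the transform is expected to live alongside the existing OnnxTransform classes):

```python
# Usage sketch (illustrative; import path and file names are placeholders).
import onnx

from QEfficient.base.onnx_transforms import ClipAndSplitTransform  # assumed module path

# Load the graph without pulling external weights into memory; the transform
# loads them itself from onnx_base_dir.
model = onnx.load("model.onnx", load_external_data=False)

model, transformed = ClipAndSplitTransform.apply(
    model,
    model_name="model",   # external shards become model_0.onnx.data, model_1.onnx.data, ...
    onnx_base_dir=".",    # directory containing the original external data
    apply_clip=True,      # clip FP32 initializers into FP16 range, preserving -inf
    apply_split=True,     # re-shard external data into ~10 GiB files
)
print("transformed:", transformed)
```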