@@ -18,13 +18,13 @@ class QEffInternVLModel(nn.Module):
     def get_specializations(
         self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
     ):
-        # TODO: check if this should be named num_patches or something else
-        num_patches = compiler_options.get("num_patches", None)
-        if num_patches is None:
+        # TODO: check if this should be named num_crops or something else
+        num_crops = compiler_options.get("num_crops", None)
+        if num_crops is None:
             logger.warning(
-                "User should pass `num_patches` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 13"
+                "User should pass `num_crops` to the compile API to fix the dynamic axes of `pixel_values`; you can get more info by calling the get_inputs_info function. Since it was not passed, setting its value to 13."
             )
-            num_patches = 13
+            num_crops = 13
 
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 3840  # 4096-256
         ctx_len = ctx_len if ctx_len else 4096
@@ -39,14 +39,14 @@ def get_specializations(
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
-                "num_patches": num_patches,
+                "num_crops": num_crops,
                 "img_size": img_size,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
-                "num_patches": num_patches,
+                "num_crops": num_crops,
                 "img_size": img_size,
             },
         ]
@@ -58,7 +58,7 @@ def get_onnx_dynamic_axes(
         dynamic_axes = {}
         dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        dynamic_axes["pixel_values"] = {0: "num_crops", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -79,12 +79,12 @@ def get_output_names(
     def get_dummy_inputs(self, kv_offload: bool = False):
         if kv_offload:
             raise ValueError("kv_offload method not supported for InternVL yet!")
-        num_patches = 13
+        NUM_CROPS = 13
         C = 3
         if vis_cfg := getattr(self.config, "vision_config", None):
-            img_size = getattr(vis_cfg, "image_size", 448)
+            img_size = getattr(vis_cfg, "image_size", 336)
         else:
-            img_size = 448
+            img_size = 336
 
         # Define shapes
         inputs_shapes = {}
@@ -93,7 +93,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
-        inputs_shapes["pixel_values"] = (num_patches, C, img_size, img_size)
+        inputs_shapes["pixel_values"] = (NUM_CROPS, C, img_size, img_size)
 
         # Define inputs
         inputs = {}
@@ -143,7 +143,7 @@ def get_inputs_info(self):
         return [
             IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
             IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
-            IOInfo(name="pixel_values", datatype=torch.float32, shape=("num_patches", 3, "img_size", "img_size")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("num_crops", 3, "img_size", "img_size")),
         ]
 
 
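Usage note: after this rename, callers pass `num_crops` (not `num_patches`) through `**compiler_options`. Below is a minimal sketch of how that could look; the entry-point class, checkpoint name, and the other compile kwargs are assumptions for illustration — the only behavior this diff establishes is that `num_crops` is read from the compiler options and falls back to 13 with a warning when omitted.

```python
# Hypothetical usage sketch, not taken from this diff: the QEfficient entry
# point and checkpoint name are assumptions for illustration.
from QEfficient import QEFFAutoModelForImageTextToText  # assumed API

model = QEFFAutoModelForImageTextToText.from_pretrained("OpenGVLab/InternVL2_5-1B")
model.compile(
    batch_size=1,
    prefill_seq_len=3840,  # matches the default in get_specializations (4096 - 256)
    ctx_len=4096,
    img_size=448,
    num_crops=13,  # fixes dynamic axis 0 of `pixel_values`; if omitted, a warning fires and 13 is used
)
```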