11
11
12
12
from typing import ClassVar , Sequence
13
13
14
+ import numpy as np
15
+
14
16
from onnxscript import ir
15
17
from onnxscript .rewriter import _ir_utils as ir_utils
16
18
from onnxscript .rewriter ._basics import MatchResult
@@ -123,16 +125,37 @@ def pattern(self, op, x, shape_ignored, shape):
123
125
return op .Reshape (op .Reshape (x , shape_ignored ), shape )
124
126
125
127
def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
    """Collapse the matched Reshape(Reshape(x, ...), shape) into a single Reshape.

    Uses ``self._new_shape`` and ``self._allowzero``, both populated by
    ``check`` before this method runs. The replacement shape is materialized
    as an initializer reusing the original shape value's name.
    """
    shape_tensor = ir.Tensor(self._new_shape, name=shape.name)
    shape_const = op.initializer(shape_tensor)
    return op.Reshape(x, shape_const, allowzero=self._allowzero)
127
130
128
131
def check(self, context, x, shape_ignored, shape) -> MatchResult:
    """Validate that the two Reshapes can be fused and precompute the fused shape.

    Side effects (consumed by ``rewrite``):
      - ``self._new_shape``: numpy int array, the shape for the fused Reshape.
      - ``self._allowzero``: the ``allowzero`` attribute value to emit, or
        ``None`` when the attribute is no longer needed.
    """
    check_result = MatchResult()

    # The final shape must be a constant; otherwise we cannot analyze it.
    if (np_shape := ir_utils.get_numpy_value(shape)) is None:
        return check_result.fail("Shape is not a constant.")
    # Copy into a writable array so entries can be patched in place below.
    self._new_shape = np.array(np_shape, np_shape.dtype)

    # Try to replace {0,-1} placeholder values in shape if the reshape output
    # shape is known: a known positive output dim pins down what the
    # corresponding entry must resolve to.
    if (reshape_output := context.output_values[0].shape) is not None:
        for i, dim in enumerate(reshape_output):
            if isinstance(dim, int) and dim > 0:
                self._new_shape[i] = dim

    # Constraints for shape.
    # NOTE(review): assumes nodes[0] is the matched node carrying the
    # `allowzero` attribute of interest — confirm against match ordering.
    self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0)
    # With allowzero=1, any remaining 0 is a literal zero-sized dim, so the
    # shape is unambiguous as-is and the fusion is accepted unchanged.
    if self._allowzero == 1 and any(self._new_shape == 0):
        return check_result
    # With allowzero=0, 0 means "copy the corresponding *input* dim" — after
    # fusion that would refer to the wrong tensor, so unresolved combinations
    # are rejected.
    if any(self._new_shape == 0) and any(self._new_shape < 0):
        return check_result.fail("Shape cannot contain both 0 and -1 dimensions.")
    elif np.count_nonzero(self._new_shape == 0) > 1:
        return check_result.fail("Shape cannot contain more than one 0 dimension.")

    # At this point, we can safely replace '0' with '-1' (at most one such
    # placeholder remains and -1 lets Reshape infer it).
    # Note allowzero is removed since at this point it does not have any effect.
    self._allowzero = None
    self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape)
    return check_result
137
160
138
161
@@ -279,6 +302,55 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
279
302
return check_result
280
303
281
304
305
class Flatten2Reshape(RewriteRuleClassBase):
    """Convert ``Flatten(x)`` to Reshape."""

    def pattern(self, op, x: ir.Value):
        return op.Flatten(x)

    def rewrite(self, op, x: ir.Value):
        """Emit ``Reshape(x, new_shape)`` using the shape computed by ``check``."""
        new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape"))
        return op.Reshape(x, new_shape)

    def check(self, context, x: ir.Value) -> MatchResult:
        """Compute the 2-D reshape target equivalent to Flatten's axis semantics.

        Sets ``self._new_shape`` as an int64 pair where -1 means "infer" and 0
        means "copy the input dim" (Reshape's default allowzero=0 semantics).
        Fails when more than one entry would remain -1, since Reshape can infer
        at most one dimension.
        """
        check_result = MatchResult()
        self._new_shape = np.array([-1, -1], "int64")

        # Convert axis to a positive value if possible.
        axis = context.root.attributes.get_int("axis", 1)
        input_rank = None
        if (input_shape := x.shape) is not None:
            input_rank = len(input_shape)
            if axis < 0:
                axis += input_rank

        # Compute reshape shape following the axis attribute.
        if axis == 0:
            self._new_shape[0] = 1
        elif axis == 1:
            # 0 here deliberately means "copy input dim 0" (allowzero=0 default).
            self._new_shape[0] = 0
        elif axis == input_rank:
            self._new_shape[1] = 1

        # Try to update shape if output is known. Only positive dims are
        # usable: writing a literal 0 would be reinterpreted by Reshape as
        # "copy the input dim", which is wrong for zero-sized tensors
        # (same guard as ReshapeReshape.check).
        if (output_shape := context.output_values[0].shape) is not None:
            for i, dim in enumerate(output_shape):
                if isinstance(dim, int) and dim > 0:
                    self._new_shape[i] = dim

        # Try to update shape if input is known. Zero products are skipped for
        # the same reason as above; a remaining -1 is either inferred by
        # Reshape or makes the rule bail out below.
        if input_shape is not None:
            if all(isinstance(dim, int) for dim in input_shape[:axis]):
                head = int(np.prod(input_shape[:axis], dtype=np.int64))
                if head > 0:
                    self._new_shape[0] = head
            if all(isinstance(dim, int) for dim in input_shape[axis:]):
                tail = int(np.prod(input_shape[axis:], dtype=np.int64))
                if tail > 0:
                    self._new_shape[1] = tail

        # Verify if it is possible to apply the rule.
        if np.count_nonzero(self._new_shape == -1) > 1:
            return check_result.fail("Impossible to compute new shape.")
        return check_result
352
+
353
+
282
354
# Create rule instances
283
355
cast_cast_rule = CastCast .rule ()
284
356
no_op_cast_rule = CastIdentity .rule ()
@@ -289,6 +361,7 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
289
361
transpose_transpose_rule = TransposeTranspose .rule ()
290
362
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze .rule ()
291
363
squeeze_reshape_1d_rule = SqueezeReshape .rule ()
364
+ flatten_to_reshape_rule = Flatten2Reshape .rule ()
292
365
293
366
294
367
def basic_optimization_rules () -> RewriteRuleSet :
@@ -311,6 +384,8 @@ def basic_optimization_rules() -> RewriteRuleSet:
311
384
cast_cast_rule ,
312
385
no_op_cast_rule ,
313
386
no_op_expand_rule ,
387
+ # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule
388
+ flatten_to_reshape_rule ,
314
389
reshape_reshape_rule ,
315
390
slice_split_rule ,
316
391
no_op_transpose_rule ,
0 commit comments