
Commit d0fb218

[rewriter] Unify reshape flatten ops (#2518)
Following #2301, the `flatten_to_reshape_rule` rewrite rule is introduced to reduce the following operator compositions:
- Reshape ∘ Flatten -> Reshape
- Flatten ∘ Reshape -> Reshape

Notes to support this change:
- The `ReshapeReshape` rule is updated to support more cases.
- A `Flatten2Reshape` rule is introduced to convert Flatten ops into Reshape when possible.
1 parent f5f9e6a commit d0fb218
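For context (not part of the commit): these rules rely on the fact that Flatten with a given axis is equivalent to a single 2-D Reshape. A minimal numpy sketch of that equivalence:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
axis = 1  # Flatten(x, axis=1)

# Flatten collapses dims[:axis] into the first output dim and dims[axis:] into the second.
flat_shape = (int(np.prod(x.shape[:axis])), int(np.prod(x.shape[axis:])))
assert x.reshape(flat_shape).shape == (2, 12)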

File tree

3 files changed: +288 -65 lines changed


onnxscript/rewriter/rules/common/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     "div_by_1_rule",
     "dropout_inference_rule",
     "dropout_zero_rule",
+    "flatten_to_reshape_rule",
     "fuse_batchnorm_into_conv_rule",
     "fuse_batchnorm_into_conv_transpose_rule",
     "fuse_batchnorm_into_gemm_rule",
@@ -48,6 +49,7 @@
 
 from onnxscript.rewriter.rules.common._basic_rules import (
     cast_cast_rule,
+    flatten_to_reshape_rule,
     no_op_cast_rule,
     no_op_expand_rule,
     no_op_transpose_rule,

onnxscript/rewriter/rules/common/_basic_rules.py

Lines changed: 81 additions & 6 deletions
@@ -11,6 +11,8 @@
 
 from typing import ClassVar, Sequence
 
+import numpy as np
+
 from onnxscript import ir
 from onnxscript.rewriter import _ir_utils as ir_utils
 from onnxscript.rewriter._basics import MatchResult
@@ -123,16 +125,37 @@ def pattern(self, op, x, shape_ignored, shape):
         return op.Reshape(op.Reshape(x, shape_ignored), shape)
 
     def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
-        return op.Reshape(x, shape)
+        new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name))
+        return op.Reshape(x, new_shape, allowzero=self._allowzero)
 
     def check(self, context, x, shape_ignored, shape) -> MatchResult:
         check_result = MatchResult()
-        if shape_ignored.const_value is None:
-            return check_result.fail("Shape ignored is not a constant.")
-        if shape.const_value is None:
+
+        # Shape must be a constant.
+        if (np_shape := ir_utils.get_numpy_value(shape)) is None:
             return check_result.fail("Shape is not a constant.")
-        if shape.const_value.numpy().min() <= 0:
-            return check_result.fail("Shape has non-positive values.")
+        # Convert to array to support assignment destination.
+        self._new_shape = np.array(np_shape, np_shape.dtype)
+
+        # Try to replace {0,-1} values in shape if reshape output is known.
+        if (reshape_output := context.output_values[0].shape) is not None:
+            for i, dim in enumerate(reshape_output):
+                if isinstance(dim, int) and dim > 0:
+                    self._new_shape[i] = dim
+
+        # Constraints for shape.
+        self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0)
+        if self._allowzero == 1 and any(self._new_shape == 0):
+            return check_result
+        if any(self._new_shape == 0) and any(self._new_shape < 0):
+            return check_result.fail("Shape cannot contain both 0 and -1 dimensions.")
+        elif np.count_nonzero(self._new_shape == 0) > 1:
+            return check_result.fail("Shape cannot contain more than one 0 dimension.")
+
+        # At this point, we can safely replace '0' with '-1'.
+        # Note allowzero is removed since at this point it does not have any effect.
+        self._allowzero = None
+        self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape)
         return check_result
 
 
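Illustration (not part of the diff) of the shape clean-up the updated check() performs: 0/-1 entries in the requested shape are replaced with concrete dimensions taken from the statically known Reshape output, and any remaining 0 is turned into -1 so the allowzero attribute can be dropped. A small numpy sketch of that logic:

import numpy as np

requested = np.array([0, -1, 4], dtype="int64")  # shape input of the second Reshape
known_output = (2, 3, 4)                          # statically inferred output shape

new_shape = requested.copy()
for i, dim in enumerate(known_output):
    if isinstance(dim, int) and dim > 0:
        new_shape[i] = dim                        # resolve 0 / -1 to a concrete dim

new_shape = np.where(new_shape == 0, -1, new_shape)
print(new_shape)  # [2 3 4]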

@@ -279,6 +302,55 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
         return check_result
 
 
+class Flatten2Reshape(RewriteRuleClassBase):
+    """Convert ``Flatten(x)`` to Reshape."""
+
+    def pattern(self, op, x: ir.Value):
+        return op.Flatten(x)
+
+    def rewrite(self, op, x: ir.Value):
+        new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape"))
+        return op.Reshape(x, new_shape)
+
+    def check(self, context, x: ir.Value) -> MatchResult:
+        check_result = MatchResult()
+        self._new_shape = np.array([-1, -1], "int64")
+
+        # Convert axis in a positive value if possible.
+        axis = context.root.attributes.get_int("axis", 1)
+        input_rank = None
+        if (input_shape := x.shape) is not None:
+            input_rank = len(input_shape)
+            if axis < 0:
+                axis += input_rank
+
+        # Compute reshape shape following axis attribute.
+        if axis == 0:
+            self._new_shape[0] = 1
+        elif axis == 1:
+            self._new_shape[0] = 0
+        elif axis == input_rank:
+            self._new_shape[1] = 1
+
+        # Try to update shape if output is known.
+        if (output_shape := context.output_values[0].shape) is not None:
+            for i, dim in enumerate(output_shape):
+                if isinstance(dim, int):
+                    self._new_shape[i] = dim
+
+        # Try to update shape if input is known.
+        if input_shape is not None:
+            if all(isinstance(dim, int) for dim in input_shape[:axis]):
+                self._new_shape[0] = np.prod(input_shape[:axis])
+            if all(isinstance(dim, int) for dim in input_shape[axis:]):
+                self._new_shape[1] = np.prod(input_shape[axis:])
+
+        # Verify if it is possible to apply rule.
+        if np.count_nonzero(self._new_shape == -1) > 1:
+            return check_result.fail("Impossible to compute new shape.")
+        return check_result
+
+
 # Create rule instances
 cast_cast_rule = CastCast.rule()
 no_op_cast_rule = CastIdentity.rule()
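To illustrate the new Flatten2Reshape rule (flatten_target_shape below is a hypothetical helper, not part of the commit): when the input shape is fully static, the 2-D Reshape target can be resolved directly from Flatten's axis attribute:

import numpy as np

def flatten_target_shape(input_shape, axis):
    if axis < 0:
        axis += len(input_shape)
    # Flatten output is always 2-D: (prod(dims[:axis]), prod(dims[axis:])).
    return np.array(
        [int(np.prod(input_shape[:axis])), int(np.prod(input_shape[axis:]))],
        dtype="int64",
    )

print(flatten_target_shape((2, 3, 4), axis=0))  # [ 1 24]
print(flatten_target_shape((2, 3, 4), axis=2))  # [6 4]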
@@ -289,6 +361,7 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
 transpose_transpose_rule = TransposeTranspose.rule()
 unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
 squeeze_reshape_1d_rule = SqueezeReshape.rule()
+flatten_to_reshape_rule = Flatten2Reshape.rule()
 
 
 def basic_optimization_rules() -> RewriteRuleSet:
@@ -311,6 +384,8 @@ def basic_optimization_rules() -> RewriteRuleSet:
             cast_cast_rule,
             no_op_cast_rule,
             no_op_expand_rule,
+            # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule
+            flatten_to_reshape_rule,
             reshape_reshape_rule,
             slice_split_rule,
             no_op_transpose_rule,
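The new rule is registered ahead of reshape_reshape_rule so that a Flatten adjacent to a Reshape is first rewritten into a Reshape pair, which reshape_reshape_rule can then collapse into a single Reshape. A hedged usage sketch (API names such as RewriteRuleSet.apply_to_model and ir.load / ir.save are assumed from onnxscript and should be checked against the installed version; the file paths are hypothetical):

from onnxscript import ir
from onnxscript.rewriter.rules.common._basic_rules import basic_optimization_rules

model = ir.load("model.onnx")                     # hypothetical input path
basic_optimization_rules().apply_to_model(model)  # assumed RewriteRuleSet API
ir.save(model, "model_optimized.onnx")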
