[rewriter] Unify reshape flatten ops #2518
Conversation
    self.new_shape = np.array(np_shape, np_shape.dtype)

    # Try to replace {0,-1} values in shape if reshape output is known.
    if (reshape_output := context.output_values[0].shape) is not None:
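For context, a minimal sketch of what this resolution amounts to, assuming the target shape is a 1-D integer array and the known output shape is fully static (the real code works on `context.output_values[0].shape`, which may carry symbolic dimensions); `resolve_special_dims` is a hypothetical helper, not the PR's code:

```python
import numpy as np

def resolve_special_dims(new_shape: np.ndarray, output_shape: tuple[int, ...]) -> np.ndarray:
    """Replace 0 ("keep input dim") and -1 ("infer dim") entries in a Reshape
    target shape with concrete values taken from the known output shape."""
    if len(new_shape) != len(output_shape):
        return new_shape  # rank mismatch: nothing we can safely resolve
    resolved = new_shape.copy()
    for i, dim in enumerate(new_shape):
        if dim in (0, -1):
            # The output shape is exactly what this Reshape produced, so its
            # i-th entry is the value the 0/-1 resolved to at runtime.
            resolved[i] = output_shape[i]
    return resolved

print(resolve_special_dims(np.array([0, -1, 4]), (2, 3, 4)))  # -> [2 3 4]
```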
Would this be better done in constant folding? @gramalingam
Or maybe not, if a user doesn't want to fold other constants?
I have a different question (and maybe that is your main point too): should this rule be decomposed into two simpler rules? Specifically, the transformation applied to the second Reshape in this pattern can be applied even if it does not follow a first Reshape (at least, that is my reading; please correct me otherwise). We could then keep the original rule for composing two Reshapes, improved to allow -1 but not 0 in the shape.
As to when/where this transformation is done: maybe it could even be done in the torch exporter itself?
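To illustrate the -1-but-not-0 constraint with plain numpy (a sketch of the reasoning, not the rule's actual check): fusing `Reshape(Reshape(x, s1), s2)` into `Reshape(x, s2)` is safe when `s2` contains -1, because -1 is inferred from the total element count, which the first Reshape never changes; a 0 entry, however, resolves against that Reshape's *input*, which changes under fusion.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# -1 composes: the inferred dimension depends only on the element count,
# so dropping the intermediate Reshape cannot change the result.
s1, s2 = (6, 4), (4, -1)
assert np.array_equal(x.reshape(s1).reshape(s2), x.reshape(s2))

# 0 (ONNX semantics: "copy this dim from the input") does not compose.
# Before fusion, shape (0, -1) applies to x.reshape(6, 4) -> result (6, 4).
# After fusion, the same (0, -1) applies to x itself      -> result (2, 12).
```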
LGTM. @gramalingam and @titaiwangms for more eyes.
Codecov Report

|          |   main |  #2518 |    +/- |
|----------|-------:|-------:|-------:|
| Coverage | 69.99% | 70.16% | +0.16% |
| Files    |    216 |    216 |        |
| Lines    |  26074 |  26215 |   +141 |
| Branches |   2618 |   2638 |    +20 |
| Hits     |  18250 |  18393 |   +143 |
| Misses   |   6921 |   6918 |     -3 |
| Partials |    903 |    904 |     +1 |

☔ View full report in Codecov by Sentry.
Force-pushed from 11673eb to 704d4a1.
Last force-push applies @justinchuby's suggestions.
Force-pushed from 9743c6c to 22bed90.
Force-pushed from 22bed90 to 2493299.
Looks fine to me. As discussed in the comments, this could potentially be generalized further by splitting it into two separate rules.
@Johansmm, could you merge from main to resolve the conflicts? A refactoring was done to clean up the namespace for all rewrite rules. Thanks.
Force-pushed from 144a879 to 2bf14e4.
Last force-push rebases on main.
- Rewrite tests with the ir.tape approach.
- Include new tests around the check function.
- Remove a pointless check when the shape is ignored.
- (Conditionally) support negative shape values.
- (Conditionally) support zero shape values.
- Convert Flatten to Reshape when possible (see the sketch after this list).
- Merge Flatten + Reshape or Reshape + Flatten.
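As a sketch of the Flatten-to-Reshape equivalence these commits rely on, assuming a non-empty tensor and a non-negative axis (`flatten_like_onnx` is a hypothetical helper, not the rule's code):

```python
import numpy as np

def flatten_like_onnx(x: np.ndarray, axis: int) -> np.ndarray:
    """ONNX Flatten(x, axis) yields a 2-D tensor: dims before `axis` are
    collapsed into the first output dim and the rest into the second,
    which a single Reshape can express."""
    outer = int(np.prod(x.shape[:axis], dtype=np.int64))  # prod of no dims == 1 for axis=0
    return x.reshape(outer, -1)  # remaining elements inferred via -1

x = np.arange(24).reshape(2, 3, 4)
assert flatten_like_onnx(x, 2).shape == (6, 4)
assert flatten_like_onnx(x, 0).shape == (1, 24)
```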
Force-pushed from 2bf14e4 to 0825709.
Last force-push exposes the rule.
Following #2301, a flatten_to_reshape_rule rule set is introduced to reduce chains of Reshape and Flatten operators. To support these changes:
- The ReshapeReshape rule is updated to support more cases.
- A Flatten2Reshape rule is introduced to convert Flatten ops into Reshape when possible.

A usage sketch follows below.
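A hypothetical usage sketch of the exposed rule set; the exact import path is an assumption on my part (the namespace refactoring mentioned above may place the rules elsewhere):

```python
import onnx
from onnxscript import rewriter
# Assumed import location for the rule set introduced by this PR.
from onnxscript.rewriter import flatten_to_reshape_rule  # hypothetical path

model = onnx.load("model.onnx")
# Apply the pattern-based rewrite rules to the loaded model.
rewritten = rewriter.rewrite(model, pattern_rewrite_rules=[flatten_to_reshape_rule])
onnx.save(rewritten, "model_rewritten.onnx")
```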