Infer ONNX conv output shapes #2304
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@           Coverage Diff            @@
##             main    #2304    +/-   ##
==========================================
- Coverage   85.44%   85.42%   -0.02%
==========================================
  Files         766      766
  Lines       97916    98019     +103
==========================================
+ Hits        83667    83736      +69
- Misses      14249    14283      +34

View full report in Codecov by Sentry.
Ha, the linked article brings back memories 😄 I remember having to compute the shapes for different convolution kernels back in the days, a very useful resource!
Thanks for addressing the linked issue! I have a question though: wouldn't it have worked to simply set the output shape to None instead of computing the exact shape, since we mostly rely on the rank? Sure, the broadcasting validity couldn't be fully confirmed by inspecting the shapes, but it's assumed to be valid otherwise.
It could become quite tedious (and error prone) to have to manually infer the output shapes for operations, so we have tried not to rely on exact shapes up to this point.
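The rank-only alternative described here could look roughly like the following. This is a minimal Python sketch with entirely hypothetical names (`TensorType`, `infer_conv_output`); it is not the actual onnx-ir API, only an illustration of propagating rank while leaving the shape as None:

```python
from dataclasses import dataclass
from typing import Optional, List

# Hypothetical types for illustration; not the real onnx-ir structures.
@dataclass
class TensorType:
    rank: int
    shape: Optional[List[int]] = None  # None = unknown; only the rank is tracked

def infer_conv_output(input_ty: TensorType) -> TensorType:
    # Rank-only inference: Conv preserves rank (N, C, spatial...),
    # so a valid output type can be emitted without computing exact dims.
    return TensorType(rank=input_ty.rank, shape=None)

out = infer_conv_output(TensorType(rank=4, shape=[1, 3, 224, 224]))
print(out.rank, out.shape)  # 4 None
```

The trade-off discussed in the rest of this thread is exactly that the `shape=None` here is cheap to maintain but discards information some ops later need.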
We don't need to track the output shapes statically, because the Burn API tracks and calculates them dynamically; Burn only needs the output rank. However, onnx-ir could offer shape tracking as a feature, which might be useful for other frameworks.
From my point of view, this should be an either/or decision: either shapes are tracked reliably, or not at all.
The current state, where the shape is only sometimes available, and more importantly not reliable, is IMHO bad for implementing operators. If onnx-ir were to go the second route (no shape tracking), some operators would become hard (impossible?) to implement.
You are right, we have to decide here. I think eventually I would prefer that onnx-ir tracked shape information, but I do not want to burden Burn ONNX implementers right now. Currently it is a hybrid approach: the project initially started out tracking shape information. Maybe we should have this feature in onnx-ir and then add more as we go along?
So far I haven't really made a distinction between onnx-ir and burn-import, so maybe I'm missing something here. I'm not sure whether your statements are in favor of or against the inclusion of this PR. I'm willing to put in some work to improve the overall situation, but I'm not quite clear on what you propose.
As far as I can tell from reading the ONNX IR spec, including type information (which carries the shape) is only mandatory for a top-level graph's inputs and outputs. So for intermediate values, it is possible to include in a graph's
This was also my first concern regarding the Burn implementation. But the reality is that if we want the shapes to be tracked in
After reading the comments, I think I am also in favor of tracking shape information. But this means we need to make sure all ops track shapes properly, including already implemented ones (which is more restrictive and will require more work). I think it should improve support for more complex models: as we've seen, the ONNX spec is quite large, and supporting a single op doesn't always mean covering the whole range of the spec as it might be used in a model. I might be wrong, but I think it will reduce the friction with ONNX import support, even if it requires a bit more work.
Another issue just resurfaced that motivates the changes proposed in this PR to better track shapes. Since most shapes are set to
Indeed, Burn doesn't need to track that. But in order to convert the code to Burn, some ONNX ops actually do seem to require the shapes to be tracked properly 😞 It has worked so far without tracking shapes for the simple stuff, but the limitations are starting to show. Let me know what you think @antimora
Sorry, I forgot to add my recommendation for this after we spoke offline. I recommend we do not track shapes statically, primarily because generated ONNX models should be able to accept inputs of different sizes at runtime. If we use shapes to statically generate code for specific shapes, then Burn's ONNX support will be limited in its functionality. Therefore, we only accurately track input and output ranks in our graph. If we come across a use case where we cannot implement an ONNX op without knowing an exact shape, e.g. Shape -> Reshape, then we use other tactics at our disposal to make it work. For example, we can fuse at the ONNX op level (a combined ShapeReshape) or modify/add a Burn tensor operator. Ticket to remove static shape tracking: #2478
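The Shape -> Reshape fusing tactic mentioned above could be sketched as a small graph pass. This is a hypothetical Python illustration of the idea only — the `Node` representation and `fuse_shape_reshape` are invented for this sketch and are not the actual onnx-ir/burn-import code:

```python
from dataclasses import dataclass
from typing import List

# Hypothetical node representation; illustrates the Shape -> Reshape
# fusing idea only, not the real onnx-ir data model.
@dataclass
class Node:
    op: str
    inputs: List[str]
    outputs: List[str]

def fuse_shape_reshape(nodes):
    """Replace a Shape node feeding a Reshape's shape input with a single
    fused ShapeReshape node, so codegen never needs the static shape value."""
    shape_outputs = {n.outputs[0]: n for n in nodes if n.op == "Shape"}
    consumed = set()
    result = []
    for n in nodes:
        if n.op == "Reshape" and n.inputs[1] in shape_outputs:
            src = shape_outputs[n.inputs[1]]
            consumed.add(id(src))
            # Fused node takes the data input and the tensor whose shape is used.
            result.append(Node("ShapeReshape", [n.inputs[0], src.inputs[0]], n.outputs))
        else:
            result.append(n)
    # Drop the Shape nodes that were absorbed into fused nodes.
    return [n for n in result if id(n) not in consumed]

nodes = [Node("Shape", ["x"], ["s"]), Node("Reshape", ["y", "s"], ["z"])]
print(fuse_shape_reshape(nodes))  # single ShapeReshape node
```

The point of the fusion is that the shape flows through at runtime instead of being baked into generated code, which is consistent with the rank-only recommendation above.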
Another example of why we should get away from static input shapes: #2441 (review)
Closing this PR in favor of #2478
Pull Request Template
Checklist
- The run-checks all script has been executed.
Related Issues/PRs
Should fix #2243
Changes
Calculate output shape of convolutions in ONNX import dim_inference step.
I've followed https://arxiv.org/abs/1603.07285, since the correct output shape isn't really defined in the ONNX specs, but I'm not quite sure whether this implementation is correct for cases where group != 1.
ConvTranspose has the same problem (it also just returns the input shape as its output shape), but there I've set the shape to None instead of trying to implement it fully - this should at least prevent dependent nodes from relying on wrong information.
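For reference, the per-axis output size formula from the linked paper (Dumoulin & Visin, arXiv:1603.07285) can be sketched as follows. This is an illustrative Python sketch, not the PR's Rust code; it assumes symmetric padding (ONNX `pads` actually lists begin and end values separately), and note that `group` only partitions channels, so it does not change the spatial output dims — the output channel count comes straight from the weight tensor:

```python
import math

def conv_output_spatial(in_size, kernel, stride=1, pad=0, dilation=1):
    """Output size along one spatial axis of a standard convolution:
    floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1."""
    effective_kernel = dilation * (kernel - 1) + 1
    return math.floor((in_size + 2 * pad - effective_kernel) / stride) + 1

def conv_output_shape(input_shape, weight_shape, strides, pads, dilations):
    """Infer the full NCHW... output shape of a Conv node.

    Assumes symmetric padding per axis. `group` is irrelevant here:
    out_channels = weight_shape[0] regardless of grouping.
    """
    n = input_shape[0]
    out_channels = weight_shape[0]
    spatial = [
        conv_output_spatial(i, k, s, p, d)
        for i, k, s, p, d in zip(
            input_shape[2:], weight_shape[2:], strides, pads, dilations
        )
    ]
    return [n, out_channels] + spatial

# 1x3x224x224 input, 64 filters of 3x3, stride 2, padding 1:
print(conv_output_shape([1, 3, 224, 224], [64, 3, 3, 3], [2, 2], [1, 1], [1, 1]))
# [1, 64, 112, 112]
```

ConvTranspose would need the inverse relation, (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1, which is why falling back to None there is the safer interim choice.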