
Infer ONNX conv output shapes #2304

Closed
hexd0t wants to merge 2 commits

Conversation

@hexd0t
Contributor

hexd0t commented Sep 25, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Should fix #2243

Changes

Calculate the output shape of convolutions in the ONNX import dim_inference step.
I followed https://arxiv.org/abs/1603.07285, since the correct output shape isn't really defined in the ONNX specs, but I am not quite sure whether this implementation is correct for cases where group != 1.

ConvTranspose has the same problem (in that it just returns the input shape as its output shape), but there I've simply set the shape to None instead of trying to implement it fully - this should at least prevent dependent nodes from relying on wrong information.
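For reference, the per-dimension output size from the linked report reduces to the usual convolution arithmetic; a minimal sketch in Rust (names are illustrative, not the actual dim_inference code):

```rust
/// Spatial output size of a convolution along one dimension, following the
/// arithmetic in arXiv:1603.07285. Illustrative sketch only, not the actual
/// onnx-ir implementation.
fn conv_output_size(
    input: usize,
    kernel: usize,
    stride: usize,
    pad_begin: usize,
    pad_end: usize,
    dilation: usize,
) -> usize {
    // A dilated kernel covers dilation * (kernel - 1) + 1 input elements.
    let effective_kernel = dilation * (kernel - 1) + 1;
    // Integer division floors, matching floor((i + p - k') / s) + 1.
    (input + pad_begin + pad_end - effective_kernel) / stride + 1
}

// Example: a 3-wide kernel, stride 2, padding 1 on a 224-wide input yields 112.
// assert_eq!(conv_output_size(224, 3, 2, 1, 1, 1), 112);
```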

codecov bot commented Sep 25, 2024

Codecov Report

Attention: Patch coverage is 88.65979% with 11 lines in your changes missing coverage. Please review.

Project coverage is 85.42%. Comparing base (a6f7a5e) to head (64783e0).
Report is 91 commits behind head on main.

Files with missing lines              Patch %   Lines
crates/onnx-ir/src/dim_inference.rs   92.13%    7 Missing ⚠️
crates/onnx-ir/src/ir.rs              50.00%    4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2304      +/-   ##
==========================================
- Coverage   85.44%   85.42%   -0.02%     
==========================================
  Files         766      766              
  Lines       97916    98019     +103     
==========================================
+ Hits        83667    83736      +69     
- Misses      14249    14283      +34     

☔ View full report in Codecov by Sentry.

@laggui
Member

laggui left a comment
Ha, the linked article brings back memories 😄 I remember having to compute the shapes for different convolution kernels back in the day - a very useful resource!

Thanks for addressing the linked issue! I have a question though: wouldn't it have worked to simply set the output shape to None instead of computing the exact shape, since we mostly rely on the rank? Sure, the broadcasting validity couldn't be fully confirmed by inspecting the shapes, but it's assumed to be valid otherwise.

It could become quite tedious (and error prone) to have to manually infer the output shapes for operations, so we have tried not to rely on the exact shapes up to this point.

@antimora
Collaborator

We don't need to track the output shapes statically because the Burn API tracks and calculates shapes dynamically. Burn needs only the output rank.

However, onnx-ir might have a feature that tracks shapes, which might be useful for other frameworks.

@hexd0t
Contributor Author

hexd0t commented Sep 29, 2024

From my point of view, this should be an either/or:

  • Either all operations track their shapes, potentially duplicating some code as well as possibly having bugs when the predicted shape doesn't match the actual output, or
  • Like in Burn, Shapes aren't tracked explicitly, but then I feel like the Shape property of ArgType::Tensors should be removed completely.

The current state, where the shape is only sometimes available, and more importantly not reliable, is IMHO bad for implementing Operators.

If onnx-ir were to go the second route, some Operators become hard (impossible?) to implement.
The example that got me started on implementing more Shape Tracking was Expand (see #2189), where the shape of the second input (also called "shape") can influence the rank of the output (see the current impl) if the 1D shape tensor has more entries than the first input's rank. I'm not sure how this could be implemented without shape tracking.
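To make the Expand case concrete: the output rank is the larger of the input's rank and the length of the 1-D shape input, so it can only be inferred if that length (i.e. the static shape of the second input) is known. A minimal illustrative sketch, not the actual onnx-ir code:

```rust
/// Output rank of Expand: broadcasting right-aligns dimensions, so the result
/// has as many dimensions as the larger of the input rank and the number of
/// entries in the 1-D `shape` input. Illustrative sketch only.
fn expand_output_rank(input_rank: usize, shape_input_len: Option<usize>) -> Option<usize> {
    // Without the static length of the `shape` input, the output rank is unknown.
    shape_input_len.map(|len| input_rank.max(len))
}

// Example: expanding a rank-2 tensor with a shape input of length 4 yields rank 4 -
// information that is lost if the second input's shape is tracked as None.
```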

@antimora
Collaborator

> From my point of view, this should be an either/or:
>
>   • Either all operations track their shapes, potentially duplicating some code as well as possibly having bugs when the predicted shape doesn't match the actual output, or
>
>   • Like in Burn, Shapes aren't tracked explicitly, but then I feel like the Shape property of ArgType::Tensors should be removed completely.
>
> The current state, where the shape is only sometimes available, and more importantly not reliable, is IMHO bad for implementing Operators.
>
> If onnx-ir were to go the second route, some Operators become hard (impossible?) to implement.
>
> The example that got me started on implementing more Shape Tracking was Expand (see #2189), where the shape of the second input (also called "shape") can influence the rank of the output (see the current impl) if the 1D shape tensor has more entries than the first input's rank. I'm not sure how this could be implemented without shape tracking.

You are right, here we have to decide. I think that, eventually, I would prefer onnx-ir to track shape information. However, I do not want to burden Burn ONNX implementers right now. Currently it is a hybrid approach; the project initially started out tracking shape information.

Maybe we should have this feature in onnx-ir, then we can add more as we go along?

@hexd0t
Contributor Author

hexd0t commented Sep 30, 2024

So far I haven't really made a distinction between onnx-ir and burn-import, so maybe I'm missing something here. I'm not sure if your statements are in favor of or against the inclusion of this PR.

I'm willing to put in some work to improve the overall situation, but it's not quite clear to me what you propose.

> Maybe we should have this feature in onnx-ir, then we can add more as we go along?

As far as I can tell from reading the ONNX IR spec, including type information (which carries the shape) is only mandatory for a top-level graph's inputs and outputs. So for intermediate values, it is possible to include it in a graph's value_info, but we cannot rely on simply reading it from the file. Hence I went ahead and started implementing shape tracking in the dim_inference part of onnx-ir. For some Ops, the Shape is already tracked correctly. Do you see a better way than adding shape tracking piece by piece, until at some point the Option<Shape> in TensorType can become a mandatory Shape?
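For illustration, the end state described here is the difference between an optional and a mandatory shape field on the tensor type; a hypothetical sketch (type and field names are assumptions, not the actual onnx-ir definitions):

```rust
/// Hypothetical element type, standing in for whatever onnx-ir actually uses.
enum ElementType {
    Float32,
    Int64,
}

/// Today: the shape is only sometimes populated, so consumers cannot rely on it.
struct TensorTypeToday {
    elem_type: ElementType,
    rank: usize,
    shape: Option<Vec<usize>>,
}

/// Proposed end state: once every op infers its output shape during
/// dim_inference, the field can become mandatory.
struct TensorTypeProposed {
    elem_type: ElementType,
    rank: usize,
    shape: Vec<usize>,
}
```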

@laggui
Member

laggui commented Oct 1, 2024

> I think that, eventually, I would prefer onnx-ir to track shape information. However, I do not want to burden Burn ONNX implementers right now.

This was also my first concern regarding the Burn implementation. But the reality is that if we want the shapes to be tracked in onnx-ir, any new node will likely require the implementation to track the shapes - and this is usually done by the same developer when adding support for a new ONNX op in burn-import.

> So far I haven't really made a distinction between onnx-ir and burn-import, so maybe I'm missing something here.

onnx-ir is separated from burn-import and has no dependency on Burn. It is simply meant to be an ONNX parsing crate that converts graphs into the defined representation (with NodeTypes etc.) so that a Rust user can use the representation as they wish. Right now it is used in burn-import to translate the intermediate graph representation into Burn compatible operations.

> I'm not sure if your statements are in favor of or against the inclusion of this PR.

After reading the comments, I think I am also in favor of tracking shape information. But this means that we need to make sure all ops track shapes properly, including already implemented ones (which is a bit more restrictive and will require more work).

I think it should improve the support of more complex models (as we've seen, the ONNX spec is quite large, and supporting a single op doesn't always mean it will cover the whole range of the spec as it might be used in a model). I might be wrong, but I think it will reduce the friction with the ONNX import support (even if it requires a bit more work...).

@laggui
Member

laggui commented Oct 4, 2024

Another issue just resurfaced that motivates the changes proposed in this PR to better track shapes.

Since most shapes are set to None (in favor of runtime shapes), there is no way to know the output rank for the Reshape node, since the second input (which specifies the output shape) will have shape: None. This results in the output dim not being set, and thus an incorrect shape for the remainder of the graph operations.
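Concretely, the output rank of Reshape is the number of entries in its 1-D shape input, so the rank can only be inferred when that input's static shape (its length) is tracked; a minimal illustrative sketch, not the actual dim_inference code:

```rust
/// Output rank of Reshape: one output dimension per entry of the 1-D `shape`
/// input, so the rank equals that input's length. When the static shape of the
/// `shape` input is not tracked (None), the output rank cannot be determined -
/// exactly the failure mode described above. Illustrative sketch only.
fn reshape_output_rank(shape_input_shape: Option<&[usize]>) -> Option<usize> {
    // The `shape` input is 1-D, so its static shape has a single entry: its length.
    shape_input_shape.and_then(|dims| dims.first().copied())
}

// Example: a `shape` input with static shape [4] means the Reshape output has rank 4;
// with shape: None, the rank (and everything downstream of it) is unknown.
```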

@laggui
Member

laggui commented Oct 7, 2024

> We don't need to track the output shapes statically because the Burn API tracks and calculates shapes dynamically. Burn needs only the output rank.

Indeed, Burn doesn't need to track that. But in order to convert the code to Burn, it seems some ONNX ops actually require the shapes to be tracked properly 😞 I think it has worked so far without tracking shapes for the simple stuff, but the limitations are starting to show.

Let me know what you think @antimora

@antimora
Collaborator

antimora commented Oct 28, 2024

> > We don't need to track the output shapes statically because the Burn API tracks and calculates shapes dynamically. Burn needs only the output rank.
>
> Indeed, Burn doesn't need to track that. But in order to convert the code to Burn, it seems some ONNX ops actually require the shapes to be tracked properly 😞 I think it has worked so far without tracking shapes for the simple stuff, but the limitations are starting to show.
>
> Let me know what you think @antimora

Sorry, I forgot to add my recommendation for this after we spoke offline.

I recommend we do not track shapes statically, primarily because generated ONNX models should be able to accept different-sized inputs at runtime. If we use shapes to statically generate code for specific shapes, then Burn's ONNX support will be limited in its functionality.

Therefore, we only accurately track input and output ranks in our graph.

If we come across a use case where we cannot implement an ONNX op without knowing an exact shape, e.g. Shape -> Reshape, then we use other tactics at our disposal to make it work. For example, we can fuse at the ONNX op level (e.g. ShapeReshape) or modify/add a Burn tensor operator.

Ticket to remove static shape tracking: #2478

antimora mentioned this pull request Oct 30, 2024
@antimora
Collaborator

Another example of why we should move away from static input shapes: #2441 (review)

@antimora
Collaborator

Closing this PR in favor of #2478

antimora closed this Nov 11, 2024