-
Notifications
You must be signed in to change notification settings - Fork 459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add subtract tensor from scalar for ONNX sub op #1964
Add subtract tensor from scalar for ONNX sub op #1964
Conversation
785c5ed
to
cb23e98
Compare
b0a5ac6
to
cb6c1d3
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1964 +/- ##
=======================================
Coverage 85.29% 85.30%
=======================================
Files 798 798
Lines 95512 95522 +10
=======================================
+ Hits 81471 81481 +10
Misses 14041 14041 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for fixing the bug.
It looks good overall. I have minor suggestion to imporove.
@@ -131,6 +131,9 @@ impl BinaryNode { | |||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, | |||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, | |||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, | |||
(Type::Scalar(_), Type::Tensor(_)) => { | |||
move |lhs, rhs| quote! { #rhs.mul_scalar(-1).add_scalar(#lhs) } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be more efficient if we have one tensor op and rely on compiler to negate #lhs
. We can rewrite as follows: #rhs.add_scalar(-(#lhs)). So generated code might look like this:
#rhs.add_scalar(-(- 42)). And Rust compiler will precompute number literal correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks or review @antimora!
Since this is lhs (scalar) - rhs (tensor), looks like -#rhs.sub_scalar(#lhs)
produced the right result. I will make an update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than the already requested change, LGTM!
Thanks for fixing :)
Make sure sub op is more efficient by using one operator
cb6c1d3
to
af2138d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your fix! LGTM
@@ -131,6 +131,7 @@ impl BinaryNode { | |||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, | |||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, | |||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, | |||
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) }, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
FYI @nathanielsimard , @laggui , @louisfd
Another reason to support native Scalar type in Burn. See our earlier design discussion: #1689 (comment)
{ -#rhs.sub_scalar(#lhs) },
solution to scalar - tensor
will result in two operations instead of one.
I'll merge to the main once the CI passes again. |
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
Help Wanted: Implementing ONNX Ops
Changes
Previous sub and sub_int do not handle a scalar subtracting a tensor. This is a problem when implementing ONNX ops like pad.
Added support for sub and sub_int to handle that scenario
Testing
crates/burn-import/onnx-tests
cargo test
crates/burn-import
cargo test