-
Notifications
You must be signed in to change notification settings - Fork 459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add sum onnx import #1846
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1846 +/- ##
==========================================
+ Coverage 86.24% 86.28% +0.03%
==========================================
Files 767 769 +2
Lines 88575 88833 +258
==========================================
+ Hits 76394 76648 +254
- Misses 12181 12185 +4 ☔ View full report in Codecov by Sentry. |
Yeah I think Add and Sum are similar but not the same. The python script generates Add node types. Probably it's difficult or not possible to achieve with PyTorch itself. So you might have to resort to import onnx
import onnx.helper
import onnx.checker
import numpy as np
# Create input tensors
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])
# Create output tensor
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])
# Create the Sum node
sum_node = onnx.helper.make_node(
'Sum',
inputs=['input1', 'input2', 'input3'],
outputs=['output']
)
# Create the graph (GraphProto)
graph_def = onnx.helper.make_graph(
nodes=[sum_node],
name='SumGraph',
inputs=[input1, input2, input3],
outputs=[output]
)
# Create the model (ModelProto)
model_def = onnx.helper.make_model(graph_def, producer_name='sum-model')
onnx.checker.check_model(model_def)
# Save the ONNX model
onnx.save(model_def, 'sum_model.onnx')
print("ONNX model 'sum_model.onnx' generated successfully.") However, for the underlying operator you'll have to use Burn's add operator multiple times. This is very similar to concat operation, see: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/burn/node/concat.rs#L36-L38 Basically, a generated code might look like |
Ahh, I see, let me fix that! |
3950f58
to
bad96c3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is one little rename required but overall looks great!
Thank you for your contribution! You are on a roll!
Approving in advance.
}; | ||
|
||
#[test] | ||
fn test_codegen_concat() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_codegen_concat => test_codegen_sum
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I missed that! That's what I get for the mindless copying:)
c146776
to
fadccda
Compare
Should be good! |
Checklist
run-checks all
script has been executed.Related Issues/PRs
Helping with #1714
Changes
Added the ONNX models for sum and sum_int, the node seems to by covered by
add
Testing
Manual
cargo test -p burn-import --color=always -- --color=always
cargo test -p onnx-tests --color=always