Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sum onnx import #1846

Merged
merged 2 commits into from
Jun 3, 2024
Merged

Conversation

JachymPutta
Copy link
Contributor

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Helping with #1714

Changes

Added the ONNX models for sum and sum_int, the node seems to by covered by add

Testing

Manual

cargo test -p burn-import --color=always -- --color=always
cargo test -p onnx-tests --color=always

Copy link

codecov bot commented Jun 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.28%. Comparing base (99e1ba4) to head (c146776).
Report is 5 commits behind head on main.

Current head c146776 differs from pull request most recent head fadccda

Please upload reports for the commit fadccda to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1846      +/-   ##
==========================================
+ Coverage   86.24%   86.28%   +0.03%     
==========================================
  Files         767      769       +2     
  Lines       88575    88833     +258     
==========================================
+ Hits        76394    76648     +254     
- Misses      12181    12185       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@antimora
Copy link
Collaborator

antimora commented Jun 2, 2024

Yeah I think Add and Sum are similar but not the same. The python script generates Add node types. Probably it's difficult or not possible to achieve with PyTorch itself. So you might have to resort to onnx python utility. Here is the code that can generate one:

import onnx
import onnx.helper
import onnx.checker
import numpy as np

# Create input tensors
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])

# Create output tensor
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])

# Create the Sum node
sum_node = onnx.helper.make_node(
    'Sum',
    inputs=['input1', 'input2', 'input3'],
    outputs=['output']
)

# Create the graph (GraphProto)
graph_def = onnx.helper.make_graph(
    nodes=[sum_node],
    name='SumGraph',
    inputs=[input1, input2, input3],
    outputs=[output]
)

# Create the model (ModelProto)
model_def = onnx.helper.make_model(graph_def, producer_name='sum-model')
onnx.checker.check_model(model_def)

# Save the ONNX model
onnx.save(model_def, 'sum_model.onnx')

print("ONNX model 'sum_model.onnx' generated successfully.")

However, for the underlying operator you'll have to use Burn's add operator multiple times. This is very similar to concat operation, see: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/burn/node/concat.rs#L36-L38

Basically, a generated code might look like input1 + input2 + input3

@JachymPutta
Copy link
Contributor Author

Ahh, I see, let me fix that!
Thank you!

@JachymPutta JachymPutta marked this pull request as draft June 2, 2024 21:07
@JachymPutta JachymPutta force-pushed the jp/sum_onnx_import branch from 3950f58 to bad96c3 Compare June 2, 2024 22:15
@JachymPutta JachymPutta marked this pull request as ready for review June 2, 2024 22:40
Copy link
Collaborator

@antimora antimora left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is one little rename required but overall looks great!

Thank you for your contribution! You are on a roll!

Approving in advance.

};

#[test]
fn test_codegen_concat() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_codegen_concat => test_codegen_sum

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I missed that! That's what I get for the mindless copying:)

@JachymPutta JachymPutta force-pushed the jp/sum_onnx_import branch from c146776 to fadccda Compare June 3, 2024 17:07
@JachymPutta
Copy link
Contributor Author

Should be good!

@antimora antimora merged commit a5af19b into tracel-ai:main Jun 3, 2024
12 checks passed
LilDojd pushed a commit to LilDojd/burn that referenced this pull request Jun 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants