Skip to content

Commit

Permalink
feat: add sum onnx import
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed Jun 2, 2024
1 parent e407c76 commit 3950f58
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ represent the corresponding Burn Op.
| [STFT][177] |||
| [StringNormalizer][178] |||
| [Sub][179] |||
| [Sum][180] | ||
| [Sum][180] | ||
| [Tan][181] |||
| [Tanh][182] |||
| [TfIdfVectorizer][183] |||
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ fn main() {
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
Expand Down
30 changes: 30 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ include_models!(
sqrt,
sub_int,
sub,
sum,
sum_int,
tanh,
transpose,
conv_transpose2d,
Expand Down Expand Up @@ -160,6 +162,34 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn sum_tensor_and_tensor() {
let device = Default::default();
let model: sum::Model<Backend> = sum::Model::default();

let input1 = Tensor::<Backend, 2>::from_floats([[1., 2., 3., 4.]], &device);
let input2 = Tensor::<Backend, 2>::from_floats([[1., 2., 3., 4.]], &device);

let output = model.forward(input1, input2);
let expected = Data::from([[2., 4., 6., 8.]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn sum_int_tensor_and_int_tensor() {
let device = Default::default();
let model: sum_int::Model<Backend> = sum_int::Model::default();

let input1 = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
let input2 = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);

let output = model.forward(input1, input2);
let expected = Data::from([[[[2, 4, 6, 8]]]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)
Expand Down
17 changes: 17 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.3.0:�
(
onnx::Add_0
onnx::Add_12/Add"Add
main_graphZ
onnx::Add_0


Z
onnx::Add_1


b
2


B
38 changes: 38 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/sum/sum.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
return x + y

def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

onnx_name = "sum.onnx"

test_input1 = torch.randn(4, 4, device=device)
test_input2 = torch.randn(4, 4, device=device)
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {} {}".format(test_input1, test_input2))
output = model.forward(test_input1, test_input2)
print("Test output data: {}".format(output))

if __name__ == '__main__':
main()
23 changes: 23 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum_int.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
pytorch2.3.0:�
(
onnx::Add_0
onnx::Add_12/Add"Add
main_graphZ%
onnx::Add_0




Z%
onnx::Add_1




b
2




B
38 changes: 38 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/sum/sum_int.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
return x + y

def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

onnx_name = "sum_int.onnx"

test_input1 = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
test_input2 = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {} {}".format(test_input1, test_input2))
output = model.forward(test_input1, test_input2)
print("Test output data: {}".format(output))

if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub enum BinaryType {
GreaterOrEqual,
Less,
LessOrEqual,
Sum,
}

impl BinaryType {
Expand All @@ -38,6 +39,7 @@ impl BinaryType {
BinaryType::GreaterOrEqual => "greater_equal",
BinaryType::Less => "lower",
BinaryType::LessOrEqual => "lower_equal",
BinaryType::Sum => "add",
}
}
}
Expand Down

0 comments on commit 3950f58

Please sign in to comment.