Skip to content

Commit

Permalink
feat: add sum onnx import (tracel-ai#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta authored and LilDojd committed Jun 5, 2024
1 parent 6ffeef0 commit ec04031
Show file tree
Hide file tree
Showing 12 changed files with 290 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ represent the corresponding Burn Op.
| [STFT][177] |||
| [StringNormalizer][178] |||
| [Sub][179] |||
| [Sum][180] | ||
| [Sum][180] | ||
| [Tan][181] |||
| [Tanh][182] |||
| [TfIdfVectorizer][183] |||
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ fn main() {
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
Expand Down
32 changes: 32 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ include_models!(
sqrt,
sub_int,
sub,
sum,
sum_int,
tanh,
transpose,
conv_transpose2d,
Expand Down Expand Up @@ -161,6 +163,36 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn sum_tensor_and_tensor() {
let device = Default::default();
let model: sum::Model<Backend> = sum::Model::default();

let input1 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
let input2 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
let input3 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);

let output = model.forward(input1, input2, input3);
let expected = Data::from([3., 6., 9., 12.]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn sum_int_tensor_and_int_tensor() {
let device = Default::default();
let model: sum_int::Model<Backend> = sum_int::Model::default();

let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
let input2 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
let input3 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);

let output = model.forward(input1, input2, input3);
let expected = Data::from([3, 6, 9, 12]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)
Expand Down
22 changes: 22 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

 sum-model:�
%
input1
input2
input3output"SumSumGraphZ
input1


Z
input2


Z
input3


b
output


B
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/sum/sum.onnx

import onnx
import onnx.helper
import onnx.checker
import numpy as np

# Create input tensors
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])

# Create output tensor
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])

# Create the Sum node
sum_node = onnx.helper.make_node(
'Sum',
inputs=['input1', 'input2', 'input3'],
outputs=['output']
)

# Create the graph (GraphProto)
graph_def = onnx.helper.make_graph(
nodes=[sum_node],
name='SumGraph',
inputs=[input1, input2, input3],
outputs=[output]
)

# Create the model (ModelProto)
model_def = onnx.helper.make_model(graph_def, producer_name='sum-model')
onnx.checker.check_model(model_def)

# Save the ONNX model
onnx.save(model_def, 'sum.onnx')

print("ONNX model 'sum.onnx' generated successfully.")

22 changes: 22 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum_int.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

 sum-model:�
%
input1
input2
input3output"SumSumGraphZ
input1


Z
input2


Z
input3


b
output


B
40 changes: 40 additions & 0 deletions crates/burn-import/onnx-tests/tests/sum/sum_int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/sum/sum.onnx

import onnx
import onnx.helper
import onnx.checker
import numpy as np

# Create input tensors
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.INT64, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.INT64, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.INT64, [3])

# Create output tensor
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.INT64, [3])

# Create the Sum node
sum_node = onnx.helper.make_node(
'Sum',
inputs=['input1', 'input2', 'input3'],
outputs=['output']
)

# Create the graph (GraphProto)
graph_def = onnx.helper.make_graph(
nodes=[sum_node],
name='SumGraph',
inputs=[input1, input2, input3],
outputs=[output]
)

# Create the model (ModelProto)
model_def = onnx.helper.make_model(graph_def, producer_name='sum-model')
onnx.checker.check_model(model_def)

# Save the ONNX model
onnx.save(model_def, 'sum_int.onnx')

print("ONNX model 'sum_int.onnx' generated successfully.")
11 changes: 7 additions & 4 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use super::expand::ExpandNode;
use super::{
argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -103,6 +103,7 @@ pub enum Node<PS: PrecisionSettings> {
Range(RangeNode),
Reshape(ReshapeNode),
Squeeze(SqueezeNode),
Sum(SumNode),
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
Expand Down Expand Up @@ -139,6 +140,7 @@ macro_rules! match_all {
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Node::Unary(node) => $func(node),
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
Expand Down Expand Up @@ -185,6 +187,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Node::Unary(unary) => unary.kind.as_str(),
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod squeeze;
pub(crate) mod sum;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
pub(crate) use base::*;
Expand Down
108 changes: 108 additions & 0 deletions crates/burn-import/src/burn/node/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};

use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct SumNode {
pub inputs: Vec<TensorType>,
pub output: TensorType,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SumNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<Type> {
self.inputs
.iter()
.map(|t| Type::Tensor(t.clone()))
.collect()
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let inputs = self
.inputs
.iter()
.map(|t| scope.tensor_use_owned(t, node_position));

let output = &self.output.name;

quote! {
let #output = #(#inputs)+*;
}
}

fn into_node(self) -> Node<PS> {
Node::Sum(self)
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{sum::SumNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_sum() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(SumNode::new(
vec![
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
],
TensorType::new_float("tensor3", 4),
));

graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = tensor1 + tensor2;

tensor3
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Sum => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node),
NodeType::Unsqueeze => unsqueeze_update_output(node),
Expand Down
13 changes: 13 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::{
range::RangeNode,
reshape::ReshapeNode,
squeeze::SqueezeNode,
sum::SumNode,
unary::UnaryNode,
unsqueeze::UnsqueezeNode,
},
Expand Down Expand Up @@ -293,6 +294,7 @@ impl OnnxGraph {
NodeType::Shape => graph.register(Self::shape_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Sin => graph.register(Self::sin_conversion(node)),
NodeType::Sum => graph.register(Self::sum_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
NodeType::Concat => graph.register(Self::concat_conversion(node)),
NodeType::Cast => graph.register(Self::cast_conversion(node)),
Expand Down Expand Up @@ -684,6 +686,17 @@ impl OnnxGraph {
UnaryNode::sin(input, output)
}

fn sum_conversion(node: Node) -> SumNode {
let inputs = node
.inputs
.iter()
.map(|input| input.to_tensor_type())
.collect();
let output = node.outputs.first().unwrap().to_tensor_type();

SumNode::new(inputs, output)
}

fn reciprocal_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down

0 comments on commit ec04031

Please sign in to comment.