Skip to content

Commit

Permalink
Add onnx mean (#2119)
Browse files Browse the repository at this point in the history
* make contacts deterministic across Worlds

* add top k acc

* add onnx mean

* fix

* push fix

* format

---------

Co-authored-by: Charles Bournhonesque <[email protected]>
  • Loading branch information
cBournhonesque and cbournhonesque-sc authored Aug 7, 2024
1 parent cd848b1 commit dad85e0
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ represent the corresponding Burn Op.
| [MaxPool2d][98] |||
| [MaxRoiPool][99] |||
| [MaxUnpool][100] |||
| [Mean][101] | ||
| [Mean][101] | ||
| [MeanVarianceNormalization][102] |||
| [MelWeightMatrix][103] |||
| [Min][104] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ fn main() {
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/min/min.onnx")
.input("tests/mean/mean.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-import/onnx-tests/tests/mean/mean.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@


mean-model:�
&
input1
input2
input3output"Mean MeanGraphZ
input1


Z
input2


Z
input3


b
output


B
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/mean/mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/mean/mean.onnx

import onnx
import onnx.helper
import onnx.checker
import numpy as np

# Create input tensors
input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])

# Create output tensor
output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])

# Create the Mean node
mean_node = onnx.helper.make_node(
'Mean',
inputs=['input1', 'input2', 'input3'],
outputs=['output']
)

# Create the graph (GraphProto)
graph_def = onnx.helper.make_graph(
nodes=[mean_node],
name='MeanGraph',
inputs=[input1, input2, input3],
outputs=[output]
)

# Create the model (ModelProto)
model_def = onnx.helper.make_model(graph_def, producer_name='mean-model')
onnx.checker.check_model(model_def)

# Save the ONNX model
onnx.save(model_def, 'mean.onnx')

print("ONNX model 'mean.onnx' generated successfully.")

16 changes: 16 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ include_models!(
maxpool1d,
maxpool2d,
min,
mean,
mul,
neg,
not,
Expand Down Expand Up @@ -208,6 +209,21 @@ mod tests {
output.to_data().assert_eq(&expected, true);
}

#[test]
fn mean_tensor_and_tensor() {
let device = Default::default();
let model: mean::Model<Backend> = mean::Model::default();

let input1 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
let input2 = Tensor::<Backend, 1>::from_floats([2., 2., 4., 0.], &device);
let input3 = Tensor::<Backend, 1>::from_floats([3., 2., 5., -4.], &device);

let output = model.forward(input1, input2, input3);
let expected = TensorData::from([2.0f32, 2., 4., 0.]);

output.to_data().assert_eq(&expected, true);
}

#[test]
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
// Initialize the model with weights (loaded from the exported file)
Expand Down
11 changes: 7 additions & 4 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use super::{
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Mean(MeanNode),
Pad(PadNode),
Range(RangeNode),
Reshape(ReshapeNode),
Expand Down Expand Up @@ -151,6 +152,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Mean(node) => $func(node),
Node::Pad(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Expand Down Expand Up @@ -205,6 +207,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Mean(_) => "mean",
Node::Pad(_) => "pad",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Expand Down
109 changes: 109 additions & 0 deletions crates/burn-import/src/burn/node/mean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};

use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct MeanNode {
pub inputs: Vec<TensorType>,
pub output: TensorType,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for MeanNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<Type> {
self.inputs
.iter()
.map(|t| Type::Tensor(t.clone()))
.collect()
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let inputs = self
.inputs
.iter()
.map(|t| scope.tensor_use_owned(t, node_position));

let output = &self.output.name;
let inputs_len = self.inputs.len() as u32;

quote! {
let #output = (#(#inputs)+*) / #inputs_len;
}
}

fn into_node(self) -> Node<PS> {
Node::Mean(self)
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{mean::MeanNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_mean() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(MeanNode::new(
vec![
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
],
TensorType::new_float("tensor3", 4),
));

graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = (tensor1 + tensor2) / 2u32;

tensor3
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod mean;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
Expand Down
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ use onnx_ir::{
};

pub use crate::burn::graph::RecordType;
use crate::burn::node::mean::MeanNode;

/// Generate code and states from `.onnx` files and save them to the `out_dir`.
#[derive(Debug, Default)]
Expand Down Expand Up @@ -268,6 +269,7 @@ impl ParsedOnnxGraph {
NodeType::Max => graph.register(Self::max_conversion(node)),
NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
NodeType::Mean => graph.register(Self::mean_conversion(node)),
NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)),
NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
Expand Down Expand Up @@ -972,6 +974,13 @@ impl ParsedOnnxGraph {
MaxPool2dNode::new(name, input, output, config)
}

fn mean_conversion(node: Node) -> MeanNode {
let inputs = node.inputs.iter().map(TensorType::from).collect();
let output = TensorType::from(node.outputs.first().unwrap());

MeanNode::new(inputs, output)
}

fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
Expand Down

0 comments on commit dad85e0

Please sign in to comment.