diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index f4ff56c5ff..12250c1223 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -147,7 +147,7 @@ represent the corresponding Burn Op. | [ReduceSumSquare][140] | ❌ | ❌ | | [Relu][141] | ✅ | ✅ | | [Reshape][142] | ✅ | ✅ | -| [Resize][143] | ❌ | ✅ | +| [Resize][143] | ✅ | ✅ | | [ReverseSequence][144] | ❌ | ❌ | | [RNN][145] | ❌ | ✅ | | [RoiAlign][146] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 6e984df1b0..f46d218006 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -56,6 +56,7 @@ fn main() { .input("tests/reduce_sum/reduce_sum_opset13.onnx") .input("tests/reduce_sum/reduce_sum_opset11.onnx") .input("tests/reshape/reshape.onnx") + .input("tests/resize/resize.onnx") .input("tests/shape/shape.onnx") .input("tests/sigmoid/sigmoid.onnx") .input("tests/sign/sign.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index ee71e2ee3f..f6b9a9b58a 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -67,6 +67,7 @@ include_models!( reduce_sum_opset11, relu, reshape, + resize, shape, sigmoid, sign, @@ -789,6 +790,30 @@ mod tests { assert_eq!(output.to_data(), expected); } + #[test] + fn resize() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: resize::Model = resize::Model::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[[ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + ]]], + &device, + ); + let size = Tensor::::from_ints([1, 1, 2, 3], &device); + + let output = model.forward(input, size); + let expected = Data::from([[[[0.0, 1.5, 3.0], [12.0, 13.5, 15.0]]]]); + + assert_eq!(output.to_data(), expected); + } + #[test] fn shape() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/resize/resize.onnx b/crates/burn-import/onnx-tests/tests/resize/resize.onnx new file mode 100644 index 0000000000..3067216282 Binary files /dev/null and b/crates/burn-import/onnx-tests/tests/resize/resize.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/resize/resize.py b/crates/burn-import/onnx-tests/tests/resize/resize.py new file mode 100644 index 0000000000..c873508648 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/resize/resize.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/resize/resize.onnx + +import onnx +from onnx import helper, TensorProto + +def main() -> None: + input_tensor = helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [1, 1, 4, 4]) + sizes_tensor = helper.make_tensor_value_info("sizes", TensorProto.INT64, [4]) + + resize_node = helper.make_node( + "Resize", + name="resize_node", + inputs=["input_tensor", "", "", "sizes"], + outputs=["output"], + mode="linear", + ) + + graph_def = helper.make_graph( + nodes=[resize_node], + name="ResizeGraph", + inputs=[input_tensor, sizes_tensor], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 2, 2]) + ], + ) + + model_def = helper.make_model(graph_def, producer_name="resize") + + onnx.save(model_def, "resize.onnx") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 234c60666f..bd6e80d27a 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -7,8 +7,8 @@ use super::{ layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode, - reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, - unsqueeze::UnsqueezeNode, + reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, + unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -102,6 +102,7 @@ pub enum Node { MaxPool2d(MaxPool2dNode), Range(RangeNode), Reshape(ReshapeNode), + Resize(ResizeNode), Slice(SliceNode), Squeeze(SqueezeNode), Sum(SumNode), @@ -140,6 +141,7 @@ macro_rules! match_all { Node::MaxPool2d(node) => $func(node), Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), + Node::Resize(node) => $func(node), Node::Slice(node) => $func(node), Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), @@ -188,6 +190,7 @@ impl Node { Node::MaxPool2d(_) => "max_pool2d", Node::Range(_) => "range", Node::Reshape(_) => "reshape", + Node::Resize(_) => "resize", Node::Slice(_) => "slice", Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 02922b8af9..8d59a73ce1 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -27,6 +27,7 @@ pub(crate) mod random_normal; pub(crate) mod random_uniform; pub(crate) mod range; pub(crate) mod reshape; +pub(crate) mod resize; pub(crate) mod slice; pub(crate) mod squeeze; pub(crate) mod sum; diff --git a/crates/burn-import/src/burn/node/resize.rs b/crates/burn-import/src/burn/node/resize.rs new file mode 100644 index 0000000000..d319cbf3fe --- /dev/null +++ b/crates/burn-import/src/burn/node/resize.rs @@ -0,0 +1,207 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{OtherType, Scope, TensorType, Type}; +use burn::module::Module; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Module, Debug, Clone)] +pub enum ResizeMode { + Nearest, + Linear, + Cubic, +} + +#[derive(new, Module, Debug, Clone)] +pub struct ResizeOptions { + pub mode: ResizeMode, +} + +#[derive(Debug, Clone)] +pub struct ResizeNode { + pub field: OtherType, + pub input: TensorType, + pub output: TensorType, + pub output_size: TensorType, + pub config: ResizeOptions, +} + +impl ResizeNode { + pub fn new>( + name: S, + input: TensorType, + output: TensorType, + output_size: TensorType, + config: ResizeOptions, + ) -> Self { + Self { + field: OtherType::new( + name, + quote! { + burn::module::Ignored + }, + ), + input, + output, + output_size, + config, + } + } +} + +impl NodeCodegen for ResizeNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![ + Type::Tensor(self.input.clone()), + Type::Tensor(self.output_size.clone()), + ] + } + + fn field_type(&self) -> Option { + Some(Type::Other(self.field.clone())) + } + + fn field_init(&self) -> Option { + let name = &self.field.name; + + let mode = match self.config.mode { + ResizeMode::Linear => quote! { InterpolateMode::Bilinear }, + ResizeMode::Nearest => quote! { InterpolateMode::Nearest }, + ResizeMode::Cubic => quote! { InterpolateMode::Bicubic }, + }; + + let tokens = quote! { + let #name = InterpolateOptions { + mode: #mode, + }; + let #name = burn::module::Ignored(#name); + }; + + Some(tokens) + } + + fn field_serialize(&self, serializer: S) -> Result { + S::serialize_none(serializer) + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output_size = scope.tensor_use_owned(&self.output_size, node_position); + let output = &self.output.name; + + let field = &self.field.name; + + quote! { + let output_size_raw = #output_size.to_data().value; + let mut output_size = [0usize; 2]; + + for (i, &x) in output_size_raw.iter().rev().take(2).rev().enumerate() { + output_size[i] = x.elem::() as usize; + } + + let #output = interpolate( + #input, + output_size, + self.#field.0.clone(), + ); + } + } + + fn into_node(self) -> Node { + Node::Resize(self) + } + + fn register_imports(&self, imports: &mut crate::burn::BurnImports) { + imports.register("burn::tensor::ElementConversion"); + imports.register("burn::tensor::module::interpolate"); + imports.register("burn::tensor::ops::InterpolateMode"); + imports.register("burn::tensor::ops::InterpolateOptions"); + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{resize::ResizeNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(ResizeNode::new( + "resize", + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + TensorType::new_int("output_size", 1), + ResizeOptions::new(ResizeMode::Linear), + )); + + graph.register_input_output( + vec!["tensor1".to_string(), "output_size".to_string()], + vec!["tensor2".to_string()], + ); + + let expected = quote! { + use burn::tensor::module::interpolate; + use burn::tensor::ops::InterpolateMode; + use burn::tensor::ops::InterpolateOptions; + use burn::tensor::ElementConversion; + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + resize: burn::module::Ignored, + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + let resize = InterpolateOptions { + mode: InterpolateMode::Bilinear, + }; + let resize = burn::module::Ignored(resize); + Self { + resize, + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + output_size: Tensor + ) -> Tensor { + let output_size_raw = output_size.to_data().value; + let mut output_size = [0usize; 2]; + + for (i, &x) in output_size_raw.iter().rev().take(2).rev().enumerate() { + output_size[i] = x.elem::() as usize; + } + + let tensor2 = interpolate(tensor1, output_size, self.resize.0.clone()); + + tensor2 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index ff7014c95a..0046c18f4a 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -59,6 +59,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::ReduceSum => reduce_sum_update_outputs(node), NodeType::Relu => same_as_input(node), NodeType::Reshape => reshape_update_outputs(node), + NodeType::Resize => resize_update_outputs(node), NodeType::Shape => shape_update_outputs(node), NodeType::Sigmoid => same_as_input(node), NodeType::Sign => same_as_input(node), @@ -285,6 +286,33 @@ fn reshape_update_outputs(node: &mut Node) { } } +fn resize_update_outputs(node: &mut Node) { + let input = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Resize: invalid input type"), + }; + + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Resize: invalid output type"), + }; + + let output_size = match &node.inputs[3].ty { + ArgType::Tensor(output_size) => output_size.clone(), + _ => panic!("Resize: invalid output_size type"), + }; + + if output_size.dim != 1 { + panic!("Resize: output_size must be 1D"); + } + + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: input.dim, + shape: None, // shape is calculated at runtime + ..output + }); +} + fn greater_update_outputs(node: &mut Node) { match &node.inputs[0].ty { ArgType::Tensor(tensor) => { diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs index e8b65b39d2..e6fc091225 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/burn-import/src/onnx/from_onnx.rs @@ -18,7 +18,7 @@ use super::ir::{ArgType, Argument, Node, NodeType}; use protobuf::Message; -const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [ +const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 12] = [ NodeType::BatchNormalization, NodeType::Clip, NodeType::Conv1d, @@ -26,6 +26,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [ NodeType::Dropout, NodeType::Expand, NodeType::Reshape, + NodeType::Resize, NodeType::Unsqueeze, NodeType::ReduceSum, NodeType::Slice, diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index e17de2c154..2f539776d7 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -6,6 +6,7 @@ use burn::nn::{ }; use super::ir::{ArgType, AttributeValue, Data, Node}; +use crate::burn::node::resize::ResizeMode; /// Create a Conv1dConfig from the attributes of the node pub fn conv1d_config(curr: &Node) -> Conv1dConfig { @@ -713,6 +714,28 @@ pub fn reshape_config(node: &Node) -> Vec { } } +pub fn resize_config(node: &Node) -> ResizeMode { + let mut mode: String = "".to_string(); + for (key, value) in node.attrs.iter() { + match key.as_str() { + "coordinate_transformation_mode" => {} + "cubic_coeff_a" => {} + "mode" => mode = value.clone().into_string(), + "nearest_mode" => {} + _ => {} + } + } + + let mode = match mode.as_str() { + "nearest" => ResizeMode::Nearest, + "linear" => ResizeMode::Linear, + "cubic" => ResizeMode::Cubic, + _ => panic!("Resize: invalid mode string, must be 'nearest', 'linear', or 'cubic'"), + }; + + mode +} + //Note this function should only execute if the second input is a constant //if it wasn't and the output shape was known, unsqueeze has been remapped to reshape pub fn unsqueeze_config(node: &Node) -> Vec { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index b9784bbcda..9ff18d9306 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -42,6 +42,7 @@ use crate::{ random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, + resize::{ResizeNode, ResizeOptions}, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, @@ -64,7 +65,7 @@ use super::{ ir::{self, ArgType, Argument, Data, ElementType, OnnxGraph}, op_configuration::{ avg_pool2d_config, clip_config, concat_config, dropout_config, reshape_config, - softmax_config, + resize_config, softmax_config, }, }; @@ -291,6 +292,7 @@ impl OnnxGraph { NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)), NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)), NodeType::Reshape => graph.register(Self::reshape_conversion(node)), + NodeType::Resize => graph.register(Self::resize_conversion(node)), NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)), NodeType::Shape => graph.register(Self::shape_conversion(node)), NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)), @@ -586,6 +588,19 @@ impl OnnxGraph { ReshapeNode::new(input, output, shape) } + fn resize_conversion(node: Node) -> ResizeNode { + let name = &node.name; + + let input = node.inputs[0].to_tensor_type(); + let output_size = node.inputs[3].to_tensor_type(); + + let output = node.outputs.first().unwrap().to_tensor_type(); + + let mode = resize_config(&node); + + ResizeNode::new(name, input, output, output_size, ResizeOptions { mode }) + } + fn min_conversion(node: Node) -> BinaryNode { let lhs = node.inputs.first().unwrap().to_type(); let rhs = node.inputs.get(1).unwrap().to_type();