diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md
index 6ee76b8bab..dacae6e8cc 100644
--- a/crates/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -63,7 +63,7 @@ represent the corresponding Burn Op.
 | [EyeLike][55] | ❌ | ❌ |
 | [Flatten][56] | ✅ | ✅ |
 | [Floor][57] | ❌ | ❌ |
-| [Gather][58] | ❌ | ✅ |
+| [Gather][58] | ✅ | ✅ |
 | [GatherElements][59] | ✅ | ✅ |
 | [GatherND][60] | ❌ | ❌ |
 | [Gelu][61] | ✅ | ✅ |
diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index f06e1da977..391a085a7d 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -27,6 +27,7 @@ fn main() {
         .input("tests/exp/exp.onnx")
         .input("tests/flatten/flatten.onnx")
         .input("tests/gather/gather.onnx")
+        .input("tests/gather_elements/gather_elements.onnx")
         .input("tests/gelu/gelu.onnx")
         .input("tests/global_avr_pool/global_avr_pool.onnx")
         .input("tests/layer_norm/layer_norm.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/gather/gather.onnx b/crates/burn-import/onnx-tests/tests/gather/gather.onnx
index 72713542ad..9589d8410e 100644
--- a/crates/burn-import/onnx-tests/tests/gather/gather.onnx
+++ b/crates/burn-import/onnx-tests/tests/gather/gather.onnx
@@ -1,16 +1,16 @@
-pytorch2.2.2:
-a
-onnx::GatherElements_0
-onnx::GatherElements_12/GatherElements"GatherElements*
+pytorch2.1.1:
+A
+onnx::Gather_0
+onnx::Gather_12/Gather"Gather*
 axis
-main_graphZ(
-onnx::GatherElements_0
+main_graphZ
+onnx::Gather_0
 
 
-Z(
-onnx::GatherElements_1
- 
-
+Z
+onnx::Gather_1
+ 
+
 b
 2
 
diff --git a/crates/burn-import/onnx-tests/tests/gather/gather.py b/crates/burn-import/onnx-tests/tests/gather/gather.py
index 0f5ffbdbf8..39688d34d6 100644
--- a/crates/burn-import/onnx-tests/tests/gather/gather.py
+++ b/crates/burn-import/onnx-tests/tests/gather/gather.py
@@ -11,8 +11,8 @@ def __init__(self):
         super(Model, self).__init__()
 
     def forward(self, x, index):
-        x = torch.gather(x, 1, index)
-        return x
+        gathered = torch.index_select(x, 1, index)
+        return gathered
 
 
 def main():
@@ -24,8 +24,9 @@ def main():
     model.eval()
     device = torch.device("cpu")
     onnx_name = "gather.onnx"
-    dummy_input = torch.randn(2, 2, device=device)
-    dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)
+
+    dummy_input = torch.randn(2, 3, device=device)
+    dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64)
 
     torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
                       verbose=False, opset_version=16)
@@ -33,10 +34,9 @@ def main():
     print("Finished exporting model to {}".format(onnx_name))
 
     # Output some test data for use in the test
-    test_input = torch.tensor([[1.0, 2.0],
-                               [3.0, 4.0]])
-    test_index = torch.tensor([[0, 0],
-                               [1, 0]])
+    test_input = torch.tensor([[1.0, 2.0, 3.0],
+                               [4.0, 5.0, 6.0]])
+    test_index = torch.tensor([0, 2], dtype=torch.int64)
 
     print("Test input data: {}, {}".format(test_input, test_index))
     output = model.forward(test_input, test_index)
diff --git a/crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.onnx b/crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.onnx
new file mode 100644
index 0000000000..ffe56710a3
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.onnx
@@ -0,0 +1,18 @@
+pytorch2.1.1:
+a
+onnx::GatherElements_0
+onnx::GatherElements_12/GatherElements"GatherElements*
+axis
+main_graphZ(
+onnx::GatherElements_0
+ 
+
+Z(
+onnx::GatherElements_1
+ 
+
+b
+2
+ 
+
+B
\ No newline at end of file
diff --git a/crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.py b/crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.py
new file mode 100644
index 0000000000..0c1c7afdc8
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/gather_elements/gather_elements.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/gather_elements/gather_elements.onnx
+# note that the ONNX specification for `GatherElements` corresponds to PyTorch's/Burn's `gather` function
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, index):
+        x = torch.gather(x, 1, index)
+        return x
+
+
+def main():
+    # Set random seed for reproducibility
+    torch.manual_seed(0)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+    onnx_name = "gather_elements.onnx"
+    dummy_input = torch.randn(2, 2, device=device)
+    dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)
+
+    torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
+                      verbose=False, opset_version=16)
+
+    print("Finished exporting model to {}".format(onnx_name))
+
+    # Output some test data for use in the test
+    test_input = torch.tensor([[1.0, 2.0],
+                               [3.0, 4.0]])
+    test_index = torch.tensor([[0, 0],
+                               [1, 0]])
+
+    print("Test input data: {}, {}".format(test_input, test_index))
+    output = model.forward(test_input, test_index)
+    print("Test output data: {}".format(output))
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index cbfb5aa4c6..7188e1c396 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -37,6 +37,7 @@ include_models!(
     expand,
     flatten,
     gather,
+    gather_elements,
     gelu,
     global_avr_pool,
     layer_norm,
@@ -390,9 +391,23 @@ mod tests {
     #[test]
     fn gather() {
-        // Initialize the model with weights (loaded from the exported file)
         let model: gather::Model<Backend> = gather::Model::default();
+
+        let device = Default::default();
+
+        let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
+        let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
+        let output = model.forward(input, index);
+        let expected = Data::from([[1., 3.], [4., 6.]]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
+    #[test]
+    fn gather_elements() {
+        // Initialize the model with weights (loaded from the exported file)
+        let model: gather_elements::Model<Backend> = gather_elements::Model::default();
 
         let device = Default::default();
 
         // Run the model
         let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]], &device);
diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index b2722410b9..027e98ee93 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -3,11 +3,12 @@ use super::{
     batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
     constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
     conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, expand::ExpandNode,
-    gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode,
-    linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
-    max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
-    random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode,
-    squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
+    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
+    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
+    random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
+    reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
+    unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
@@ -92,6 +93,7 @@ pub enum Node<PS: PrecisionSettings> {
     Dropout(DropoutNode),
     Expand(ExpandNode),
     Gather(GatherNode),
+    GatherElements(GatherElementsNode),
     GlobalAvgPool(GlobalAvgPoolNode),
     LayerNorm(LayerNormNode),
     Linear(LinearNode),
@@ -128,6 +130,7 @@ macro_rules! match_all {
             Node::Dropout(node) => $func(node),
             Node::Expand(node) => $func(node),
             Node::Gather(node) => $func(node),
+            Node::GatherElements(node) => $func(node),
             Node::GlobalAvgPool(node) => $func(node),
             Node::LayerNorm(node) => $func(node),
             Node::Linear(node) => $func(node),
@@ -174,6 +177,7 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::Dropout(_) => "dropout",
             Node::Expand(_) => "expand",
             Node::Gather(_) => "gather",
+            Node::GatherElements(_) => "gather_elements",
             Node::GlobalAvgPool(_) => "global_avg_pool",
             Node::LayerNorm(_) => "layer_norm",
             Node::Linear(_) => "linear",
diff --git a/crates/burn-import/src/burn/node/gather.rs b/crates/burn-import/src/burn/node/gather.rs
index d55b7bb6cb..7c384548d3 100644
--- a/crates/burn-import/src/burn/node/gather.rs
+++ b/crates/burn-import/src/burn/node/gather.rs
@@ -35,7 +35,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
         let output = &self.output.name;
 
         quote! {
-            let #output = #input.gather(#dim, #index);
+            let #output = #input.select(#dim, #index);
         }
     }
 
@@ -62,9 +62,9 @@ mod tests {
 
         graph.register(GatherNode::new(
             TensorType::new_float("tensor1", 2),
-            TensorType::new_int("tensor2", 2),
+            TensorType::new_int("tensor2", 1),
             TensorType::new_float("tensor3", 2),
-            1,
+            0,
         ));
 
         graph.register_input_output(
@@ -98,9 +98,9 @@ mod tests {
             pub fn forward(
                 &self,
                 tensor1: Tensor<B, 2>,
-                tensor2: Tensor<B, 2, Int>
+                tensor2: Tensor<B, 1, Int>
             ) -> Tensor<B, 2> {
-                let tensor3 = tensor1.gather(1, tensor2);
+                let tensor3 = tensor1.select(0, tensor2);
                 tensor3
             }
 
diff --git a/crates/burn-import/src/burn/node/gather_elements.rs b/crates/burn-import/src/burn/node/gather_elements.rs
new file mode 100644
index 0000000000..2d509a0f00
--- /dev/null
+++ b/crates/burn-import/src/burn/node/gather_elements.rs
@@ -0,0 +1,112 @@
+use super::{Node, NodeCodegen};
+use crate::burn::{TensorType, ToTokens, Type};
+
+use burn::record::PrecisionSettings;
+use quote::quote;
+
+#[derive(Debug, Clone, new)]
+pub struct GatherElementsNode {
+    pub input: TensorType,
+    pub index: TensorType,
+    pub output: TensorType,
+    pub dim: usize,
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherElementsNode {
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+
+    fn input_types(&self) -> Vec<Type> {
+        vec![
+            Type::Tensor(self.input.clone()),
+            Type::Tensor(self.index.clone()),
+        ]
+    }
+
+    fn forward(
+        &self,
+        scope: &mut crate::burn::Scope,
+        node_position: usize,
+    ) -> proc_macro2::TokenStream {
+        let dim = self.dim.to_tokens();
+        let input = scope.tensor_use_owned(&self.input, node_position);
+        let index = scope.tensor_use_owned(&self.index, node_position);
+        let output = &self.output.name;
+
+        quote! {
+            let #output = #input.gather(#dim, #index);
+        }
+    }
+
+    fn into_node(self) -> super::Node<PS> {
+        Node::GatherElements(self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use burn::record::FullPrecisionSettings;
+
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{gather_elements::GatherElementsNode, test::assert_tokens},
+        TensorType,
+    };
+
+    #[test]
+    fn test_codegen_gather_elements() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(GatherElementsNode::new(
+            TensorType::new_float("tensor1", 2),
+            TensorType::new_int("tensor2", 2),
+            TensorType::new_float("tensor3", 2),
+            1,
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "tensor2".to_string()],
+            vec!["tensor3".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::tensor::Int;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 2>,
+                    tensor2: Tensor<B, 2, Int>
+                ) -> Tensor<B, 2> {
+                    let tensor3 = tensor1.gather(1, tensor2);
+
+                    tensor3
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs
index 5b3b416416..b8f44750bd 100644
--- a/crates/burn-import/src/burn/node/mod.rs
+++ b/crates/burn-import/src/burn/node/mod.rs
@@ -14,6 +14,7 @@ pub(crate) mod conv_transpose_2d;
 pub(crate) mod dropout;
 pub(crate) mod expand;
 pub(crate) mod gather;
+pub(crate) mod gather_elements;
 pub(crate) mod global_avg_pool;
 pub(crate) mod layer_norm;
 pub(crate) mod linear;
diff --git a/crates/burn-import/src/burn/node/sum.rs b/crates/burn-import/src/burn/node/sum.rs
index 78b95273ed..ad0a2601f8 100644
--- a/crates/burn-import/src/burn/node/sum.rs
+++ b/crates/burn-import/src/burn/node/sum.rs
@@ -53,7 +53,7 @@ mod tests {
     };
 
     #[test]
-    fn test_codegen_concat() {
+    fn test_codegen_sum() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
         graph.register(SumNode::new(
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 4e20941454..a70502f9ca 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -33,6 +33,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Expand => expand_update_outputs(node),
         NodeType::Flatten => flatten_update_outputs(node),
         NodeType::Gelu => same_as_input(node),
+        NodeType::Gather => gather_update_outputs(node),
         NodeType::GatherElements => same_as_input(node),
         NodeType::GlobalAveragePool => same_as_input(node),
         NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
@@ -717,3 +718,32 @@ fn where_update_outputs(node: &mut Node) {
         _ => panic!("Only tensor input is valid"),
     }
 }
+
+fn gather_update_outputs(node: &mut Node) {
+    if node.inputs.len() != 2 {
+        panic!("Gather requires two inputs: data and indices");
+    }
+
+    let input_tensor = match &node.inputs[0].ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor input is valid"),
+    };
+
+    let indices_tensor = match &node.inputs[1].ty {
+        ArgType::Tensor(tensor) => tensor,
+        _ => panic!("Only tensor indices is valid"),
+    };
+
+    if indices_tensor.dim != 1 {
panic!("Gather: indices tensor rank above 1 not supported") + } + + // Output of rank q+(r-1), where q is rank of indices tensor and r is rank of input + let output_rank = indices_tensor.dim + input_tensor.dim - 1; + + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: output_rank, + shape: None, + elem_type: input_tensor.elem_type.clone(), + }); +} diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 8b0b214ed1..aceac8f641 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -29,6 +29,7 @@ use crate::{ dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, + gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, @@ -274,7 +275,8 @@ impl OnnxGraph { NodeType::Relu => graph.register(Self::relu_conversion(node)), NodeType::Gelu => graph.register(Self::gelu_conversion(node)), NodeType::Flatten => graph.register(Self::flatten_conversion(node)), - NodeType::GatherElements => graph.register(Self::gather_conversion(node)), + NodeType::Gather => graph.register(Self::gather_conversion(node)), + NodeType::GatherElements => graph.register(Self::gather_elements_conversion(node)), NodeType::Log => graph.register(Self::log_conversion(node)), NodeType::LeakyRelu => graph.register(Self::leaky_relu_conversion(node)), NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)), @@ -550,6 +552,15 @@ impl OnnxGraph { GatherNode::new(input, index, output, dim) } + fn gather_elements_conversion(node: Node) -> GatherElementsNode { + let input = node.inputs.first().unwrap().to_tensor_type(); + let index = node.inputs.get(1).unwrap().to_tensor_type(); + let output = node.outputs.first().unwrap().to_tensor_type(); + let dim = gather_config(&node); + + GatherElementsNode::new(input, index, output, dim) + } + fn transpose_conversion(node: Node) -> UnaryNode { let input = node.inputs.first().unwrap().to_type(); let output = node.outputs.first().unwrap().to_type();