From 283e0d87cfa3d1e3932af16cd293dfeb8ea16524 Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Sat, 15 Jun 2024 17:49:44 -0400 Subject: [PATCH] feat: added reduce min onnx import --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/onnx_tests.rs | 17 +++++ .../tests/reduce_min/reduce_min.onnx | Bin 0 -> 384 bytes .../onnx-tests/tests/reduce_min/reduce_min.py | 47 ++++++++++++ crates/burn-import/src/burn/node/unary.rs | 68 ++++++++++++++++++ crates/burn-import/src/onnx/dim_inference.rs | 25 +++++++ .../burn-import/src/onnx/op_configuration.rs | 42 +++++++++++ crates/burn-import/src/onnx/to_burn.rs | 9 +++ 9 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.onnx create mode 100644 crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.py diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 12250c1223..49723acc13 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -141,7 +141,7 @@ represent the corresponding Burn Op. | [ReduceLogSumExp][134] | ❌ | ❌ | | [ReduceMax][135] | ✅ | ✅ | | [ReduceMean][136] | ✅ | ✅ | -| [ReduceMin][137] | ❌ | ✅ | +| [ReduceMin][137] | ✅ | ✅ | | [ReduceProd][138] | ❌ | ✅ | | [ReduceSum][139] | ✅ | ✅ | | [ReduceSumSquare][140] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index f46d218006..8d1c155c09 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -52,6 +52,7 @@ fn main() { .input("tests/leaky_relu/leaky_relu.onnx") .input("tests/prelu/prelu.onnx") .input("tests/reduce_max/reduce_max.onnx") + .input("tests/reduce_min/reduce_min.onnx") .input("tests/reduce_mean/reduce_mean.onnx") .input("tests/reduce_sum/reduce_sum_opset13.onnx") .input("tests/reduce_sum/reduce_sum_opset11.onnx") diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index f6b9a9b58a..a001d39150 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -62,6 +62,7 @@ include_models!( range, recip, reduce_max, + reduce_min, reduce_mean, reduce_sum_opset13, reduce_sum_opset11, @@ -728,6 +729,22 @@ mod tests { assert_eq!(output_value.to_data(), expected); } + #[test] + fn reduce_min() { + let device = Default::default(); + let model: reduce_min::Model = reduce_min::Model::new(&device); + + // Run the models + let input = Tensor::::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device); + let (output_scalar, output_tensor, output_value) = model.forward(input.clone()); + let expected_scalar = Data::from([1.]); + let expected = Data::from([[[[1.]]]]); + + assert_eq!(output_scalar.to_data(), expected_scalar); + assert_eq!(output_tensor.to_data(), input.to_data()); + assert_eq!(output_value.to_data(), expected); + } + #[test] fn reduce_mean() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.onnx b/crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b860f2ee07c5c6cbdc78a83d431ed77ad732a884 GIT binary patch literal 384 zcmdvQtwFQZjRkB^VYkGI9B0)o&!lgU}yuh-?Bk7fWJAYOw?30!DTew f8Va#-u>d6*lBBs9jf8Ni1WFp?kYsdX5)c3Y1}j~| literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.py b/crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.py new file mode 100644 index 0000000000..c6cc205f95 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/reduce_min/reduce_min.py @@ -0,0 +1,47 @@ + +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/reduce_min/reduce_min.onnx + +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x): + return ( + # ReduceMin, keepdims=0, axes=None + torch.min(x), + # ReduceMin, keepdims=1, axes=[1] + torch.min(x, dim=1, keepdim=True).values, + # ReduceMin, keepdims=1, axes=[-1] + torch.min(x, dim=-1, keepdim=True).values, + ) + + +def main(): + # Set random seed for reproducibility + torch.manual_seed(0) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + onnx_name = "reduce_min.onnx" + test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device) + + torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=16) + + print(f"Finished exporting model to {onnx_name}") + + # Output some test data for use in the test + print(f"Test input data: {test_input}") + output = model.forward(*test_input) + print(f"Test output data: {output}") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/src/burn/node/unary.rs b/crates/burn-import/src/burn/node/unary.rs index d4433d09b1..3d05bf064f 100644 --- a/crates/burn-import/src/burn/node/unary.rs +++ b/crates/burn-import/src/burn/node/unary.rs @@ -33,6 +33,7 @@ pub enum UnaryNodeKind { Neg, Not, ReduceMax, + ReduceMin, ReduceMean, ReduceSum, Reciprocal, @@ -62,6 +63,7 @@ impl UnaryNodeKind { Self::Neg => "neg", Self::Not => "not", Self::ReduceMax => "reduce_max", + Self::ReduceMin => "reduce_min", Self::ReduceMean => "reduce_mean", Self::ReduceSum => "reduce_sum", Self::Reciprocal => "reciprocal", @@ -331,6 +333,35 @@ impl UnaryNode { } } + pub(crate) fn reduce_min(input: Type, output: Type, dim: Option) -> Self { + if let Type::Tensor(ref tensor) = output { + if let Some(dim) = dim { + if tensor.kind == TensorKind::Bool { + // Min is only implemented on numeric tensors + panic!("ReduceMin is not supported for boolean"); + } + // ReduceMin, keepdims=1, axes=[dim] + let dim = dim.to_tokens(); + Self::new( + input, + output, + UnaryNodeKind::ReduceMin, + Rc::new(move |input| quote! { #input.min_dim(#dim) }), + ) + } else { + // ReduceMin, keepdims=0, axes=None + Self::new( + input, + output, + UnaryNodeKind::ReduceMin, + Rc::new(move |input| quote! { #input.min() }), + ) + } + } else { + panic!("ReduceMin only supports tensor output"); + } + } + pub(crate) fn reduce_mean(input: Type, output: Type, dim: Option) -> Self { // ReduceMean is constrained to numeric tensors, so no need to check for bool. if let Type::Tensor(_) = output { @@ -629,6 +660,43 @@ mod tests { ); } + #[test] + fn test_unary_codegen_reduce_min() { + one_node_graph( + UnaryNode::reduce_min( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 4)), + Some(1), + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.min_dim(1); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + + one_node_graph( + UnaryNode::reduce_min( + Type::Tensor(TensorType::new_float("tensor1", 4)), + Type::Tensor(TensorType::new_float("tensor2", 1)), + None, + ), + quote! { + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.min(); + + tensor2 + } + }, + vec!["tensor1".to_string()], + vec!["tensor2".to_string()], + ); + } + #[test] fn test_unary_codegen_reduce_mean() { one_node_graph( diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index 0046c18f4a..94d66d22f6 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -55,6 +55,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Range => range_update_outputs(node), NodeType::Reciprocal => same_as_input(node), NodeType::ReduceMax => reduce_max_update_outputs(node), + NodeType::ReduceMin => reduce_min_update_outputs(node), NodeType::ReduceMean => reduce_mean_update_outputs(node), NodeType::ReduceSum => reduce_sum_update_outputs(node), NodeType::Relu => same_as_input(node), @@ -716,6 +717,30 @@ fn reduce_max_update_outputs(node: &mut Node) { } } +fn reduce_min_update_outputs(node: &mut Node) { + if node.inputs.len() != 1 { + panic!("ReduceMin: multiple inputs are not supported"); + } + let node_input = &mut node.inputs[0]; + let tensor = match node_input.clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + let dim_only = match node.attrs.get("axes") { + Some(value) => match &value { + AttributeValue::Int64(_) => true, + AttributeValue::Int64s(ints) => ints.len() == 1, + _ => false, + }, + None => false, + }; + if dim_only { + node.outputs[0].ty = ArgType::Tensor(tensor); + } else { + node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor }); + } +} + /// Infers the shape of a ReduceSum node and replaces the shape of the output tensor. fn reduce_sum_update_outputs(node: &mut Node) { let node_input = &mut node.inputs[0]; diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 2f539776d7..58fc2de9a1 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -902,6 +902,48 @@ pub fn reduce_max_config(node: &Node) -> Option { } } +pub fn reduce_min_config(node: &Node) -> Option { + let mut axes = Vec::new(); + let mut keepdims = 1; + + let tensor = match node.inputs.first().unwrap().clone().ty { + ArgType::Tensor(tensor) => tensor, + _ => panic!("Only tensor input is valid"), + }; + + // Extract the attributes + for (key, value) in node.attrs.iter() { + match key.as_str() { + "axes" => axes = value.clone().into_i64s(), + "keepdims" => keepdims = value.clone().into_i64(), + _ => {} + } + } + + if axes.len() > 1 { + panic!("ReduceMin: reducing on multiple dimensions is not supported") + } + + if axes.is_empty() && keepdims == 1 { + panic!("ReduceMin: axes must be provided with keepdims") + } + + if !axes.is_empty() && keepdims == 0 { + panic!("ReduceMin: the reduce operation must preserve the reduced dimension") + } + + if axes.is_empty() { + None + } else { + let mut dim = axes[0]; + + if dim < 0 { + dim += tensor.dim as i64; + } + Some(dim as usize) + } +} + pub fn reduce_mean_config(node: &Node) -> Option { let mut axes = Vec::new(); let mut keepdims = 1; diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 9ff18d9306..72137e0a94 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -289,6 +289,7 @@ impl OnnxGraph { NodeType::Min => graph.register(Self::min_conversion(node)), NodeType::Range => graph.register(Self::range_conversion(node)), NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)), + NodeType::ReduceMin => graph.register(Self::reduce_min_conversion(node)), NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)), NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)), NodeType::Reshape => graph.register(Self::reshape_conversion(node)), @@ -640,6 +641,14 @@ impl OnnxGraph { UnaryNode::reduce_max(input, output, dim) } + fn reduce_min_conversion(node: Node) -> UnaryNode { + let input = node.inputs.first().unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + let dim = reduce_min_config(&node); + + UnaryNode::reduce_min(input, output, dim) + } + fn reduce_mean_conversion(node: Node) -> UnaryNode { let input = node.inputs.first().unwrap().to_type(); let output = node.outputs.first().unwrap().to_type();