From 7e0595680d2afa0e3a80d48b0be9399de1827ff3 Mon Sep 17 00:00:00 2001
From: Arjun31415
Date: Fri, 3 May 2024 21:59:00 +0530
Subject: [PATCH] added onnx tests and burn codegen tests

---
 crates/burn-import/onnx-tests/build.rs    |  1 +
 .../onnx-tests/tests/onnx_tests.rs        | 24 +++++++
 crates/burn-import/src/burn/node/prelu.rs | 65 ++++++++++++++++++-
 3 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index 94b32209e9..3a2c9b685d 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -39,6 +39,7 @@ fn main() {
         .input("tests/recip/recip.onnx")
         .input("tests/relu/relu.onnx")
         .input("tests/leaky_relu/leaky_relu.onnx")
+        .input("tests/prelu/prelu.onnx")
         .input("tests/reduce_max/reduce_max.onnx")
         .input("tests/reduce_mean/reduce_mean.onnx")
         .input("tests/reshape/reshape.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index 34ddfa5f87..d6f259e682 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -47,6 +47,7 @@ include_models!(
     mul,
     neg,
     not,
+    prelu,
     recip,
     reduce_max,
     reduce_mean,
@@ -658,6 +659,29 @@ mod tests {
         assert_eq!(output.to_data(), expected);
     }

+    #[test]
+    fn prelu() {
+        // Initialize the model without weights (because the exported file does not contain them)
+        let device = Default::default();
+        let model: prelu::Model<Backend> = prelu::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 2>::from_floats(
+            [
+                [0.33669037, 0.0, 0.23446237],
+                [0.23033303, -1.122_856, -0.18632829],
+            ],
+            &device,
+        );
+        let output = model.forward(input);
+        let expected = Data::from([
+            [0.33669037, 0.0, 0.23446237],
+            [0.23033303, -0.01122_856, -0.0018632829],
+        ]);
+
+        assert_eq!(output.to_data(), expected);
+    }
+
     #[test]
     fn relu() {
         // Initialize the model without weights (because the exported file does not contain them)
diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs
index 32e3d3b52a..2ac5f8b85e 100644
--- a/crates/burn-import/src/burn/node/prelu.rs
+++ b/crates/burn-import/src/burn/node/prelu.rs
@@ -89,11 +89,72 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode {
         }
     }
     fn register_imports(&self, imports: &mut BurnImports) {
-        imports.register("burn::nn::prelu::PRelu");
-        imports.register("burn::nn::prelu::PReluConfig");
+        imports.register("burn::nn::PRelu");
+        imports.register("burn::nn::PReluConfig");
     }

     fn into_node(self) -> Node {
         Node::PRelu(self)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{conv1d::Conv1dNode, test::assert_tokens},
+        TensorType,
+    };
+    use burn::{
+        nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data,
+    };
+
+    #[test]
+    fn test_codegen() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(PReluNode::new(
+            "prelu",
+            TensorType::new_float("input", 4),
+            TensorType::new_float("output", 4),
+            Data::from([2.]).serialize(),
+            PReluConfig::new(),
+        ));
+
+        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
+
+        let expected = quote! {
+            use burn::nn::PRelu;
+            use burn::nn::PReluConfig;
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                prelu: PRelu<B>,
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    let prelu = PReluConfig::new(1, 0.25).init(device);
+                    Self {
+                        prelu,
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let output = self.prelu.forward(input);
+                    output
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
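
For reference, the expected values in the prelu() ONNX test above are consistent with the PReLU definition f(x) = x for x >= 0 and f(x) = slope * x for x < 0, with a negative slope of 0.01. The 0.01 is inferred from the expected outputs (for example, -1.122_856 * 0.01 = -0.01122_856); it is not stated anywhere in this patch, so treat it as an assumption about the weights baked into the exported prelu.onnx. A minimal standalone Rust sketch of that arithmetic, independent of burn:

// Standalone check of the PReLU arithmetic behind the expected test values.
// Assumes a single negative slope of 0.01 (inferred, not taken from the patch).
fn prelu(x: f32, slope: f32) -> f32 {
    if x >= 0.0 {
        x
    } else {
        slope * x
    }
}

fn main() {
    let slope = 0.01;
    // Same input rows as the prelu() test in onnx_tests.rs.
    let rows = [
        [0.33669037_f32, 0.0, 0.23446237],
        [0.23033303, -1.122_856, -0.18632829],
    ];
    for row in rows {
        let out: Vec<f32> = row.iter().map(|&x| prelu(x, slope)).collect();
        println!("{out:?}");
    }
}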