Skip to content

Commit

Permalink
Added ONNX import tests and Burn codegen tests for PReLU
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun31415 committed May 3, 2024
1 parent e63fea4 commit 7e05956
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 2 deletions.
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ fn main() {
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reshape/reshape.onnx")
Expand Down
24 changes: 24 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include_models!(
mul,
neg,
not,
prelu,
recip,
reduce_max,
reduce_mean,
Expand Down Expand Up @@ -658,6 +659,29 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn prelu() {
    // The exported ONNX file carries no weights, so the model is built from
    // its default configuration.
    let device = Default::default();
    let model: prelu::Model<Backend> = prelu::Model::new(&device);

    // Positive entries pass through unchanged; negative entries are scaled
    // by the PReLU slope (values taken from the reference ONNX runtime run).
    let expected = Data::from([
        [0.33669037, 0.0, 0.23446237],
        [0.23033303, -0.01122_856, -0.0018632829],
    ]);

    let input = Tensor::<Backend, 2>::from_floats(
        [
            [0.33669037, 0.0, 0.23446237],
            [0.23033303, -1.122_856, -0.18632829],
        ],
        &device,
    );

    assert_eq!(model.forward(input).to_data(), expected);
}

#[test]
fn relu() {
// Initialize the model without weights (because the exported file does not contain them)
Expand Down
65 changes: 63 additions & 2 deletions crates/burn-import/src/burn/node/prelu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,72 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
}
}
// Registers the `use` statements the generated module source needs.
// NOTE(review): this diff view shows two registration pairs — the
// `burn::nn::prelu::*` paths and the `burn::nn::*` re-export paths. Only one
// pair should survive in the final source, and it must match the import lines
// expected by `test_codegen` below — confirm which pair is current.
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::prelu::PRelu");
imports.register("burn::nn::prelu::PReluConfig");
imports.register("burn::nn::PRelu");
imports.register("burn::nn::PReluConfig");
}

// Wraps this PReLU node in the graph-level `Node` enum so the code generator
// can store it alongside other node kinds.
fn into_node(self) -> Node<PS> {
Node::PRelu(self)
}
}

#[cfg(test)]
mod tests {
    use super::*;
    // Import only what the test below actually uses. CI runs clippy with
    // `-D warnings`, so the previously listed `conv1d::Conv1dNode`,
    // `nn::conv::Conv1dConfig`, and `nn::PaddingConfig1d` (all unused here)
    // fail the build and are removed.
    use crate::burn::{
        graph::BurnGraph,
        node::test::assert_tokens,
        TensorType,
    };
    use burn::{record::FullPrecisionSettings, tensor::Data};

    /// Verifies that a single `PReluNode` generates the expected Burn module
    /// source: the imports, the `Model` struct, its `new` constructor, and a
    /// `forward` that applies the PReLU layer.
    #[test]
    fn test_codegen() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        // One PReLU node with a serialized scalar weight and default config.
        graph.register(PReluNode::new(
            "prelu",
            TensorType::new_float("input", 4),
            TensorType::new_float("output", 4),
            Data::from([2.]).serialize(),
            PReluConfig::new(),
        ));

        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        // NOTE(review): the `use` paths below must agree with what
        // `register_imports` emits (`burn::nn::prelu::PRelu` vs
        // `burn::nn::PRelu`) — confirm against the non-test code above.
        let expected = quote! {
            use burn::nn::prelu::PRelu;
            use burn::nn::prelu::PReluConfig;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };
            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                prelu: PRelu<B>,
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }
            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    let prelu = PReluConfig::new(1, 0.25).init(device);
                    Self {
                        prelu,
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
                    let output = self.prelu.forward(input);
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}

0 comments on commit 7e05956

Please sign in to comment.