
Commit

added unit tests
Arjun31415 committed May 3, 2024
1 parent e63fea4 commit 93c7171
Showing 1 changed file with 61 additions and 0 deletions.
crates/burn-import/src/burn/node/prelu.rs
@@ -97,3 +97,64 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
        Node::PRelu(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::test::assert_tokens,
        TensorType,
    };
    use burn::{record::FullPrecisionSettings, tensor::Data};

    #[test]
    fn test_codegen() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        // Register a single PRelu node whose weight (alpha) is the serialized scalar 2.0.
        graph.register(PReluNode::new(
            "prelu",
            TensorType::new_float("input", 4),
            TensorType::new_float("output", 4),
            Data::from([2.]).serialize(),
            PReluConfig::new(),
        ));

        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::nn::prelu::PRelu;
            use burn::nn::prelu::PReluConfig;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };
            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                prelu: PRelu<B>,
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }
            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    let prelu = PReluConfig::new(1, 0.25).init(device);
                    Self {
                        prelu,
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
                    let output = self.prelu.forward(input);
                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
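To run just this test locally, one option (assuming the crate is named burn-import, as the crates/burn-import path suggests) is cargo's package flag plus a test-name filter:

    cargo test -p burn-import prelu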
