Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun31415 committed May 3, 2024
1 parent 7e05956 commit 6dc886f
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions crates/burn-import/src/burn/node/prelu.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
use burn::{
module::{Param, ParamId},
nn::{PReluConfig, PReluRecord},
Expand Down Expand Up @@ -55,11 +55,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;

let num_parameters = self.config.num_parameters.to_tokens();
let alpha = self.config.alpha.to_tokens();
let tokens = quote! {
let #name = PReluConfig::new(#num_parameters, #alpha)
let #name = PReluConfig::new()
.init(device);
};

Expand Down Expand Up @@ -101,14 +98,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{conv1d::Conv1dNode, test::assert_tokens},
TensorType,
};
use burn::{
nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data,
};
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{record::FullPrecisionSettings, tensor::Data};

#[test]
fn test_codegen() {
Expand All @@ -125,8 +116,8 @@ mod tests {
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::nn::prelu::PRelu;
use burn::nn::prelu::PReluConfig;
use burn::nn::PRelu;
use burn::nn::PReluConfig;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
Expand All @@ -140,7 +131,7 @@ mod tests {
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let prelu = PReluConfig::new(1, 0.25).init(device);
let prelu = PReluConfig::new().init(device);
Self {
prelu,
phantom: core::marker::PhantomData,
Expand Down

0 comments on commit 6dc886f

Please sign in to comment.