diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index e0c5cd16e2c..528a017730a 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -120,22 +120,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig { .with_padding(padding) .with_dilation([dilations[0] as usize, dilations[1] as usize]) } -pub fn prelu_config(curr: &Node) -> PReluConfig { - let mut alpha = 0.01; - let mut num_parameters = 0; - for (key, value) in curr.attrs.iter() { - match key.as_str() { - "alpha" => alpha = value.clone().into_f32(), - "num_parameters" => num_parameters = value.clone().into_i32(), - _ => {} - } - } - - PReluConfig::new() - .with_num_parameters(num_parameters as usize) - .with_alpha(alpha as f64) -} - pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig { let mut attrs = curr.attrs.clone(); let kernel_shape = attrs diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index ba7d8fbe154..3a28c2ecf4e 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -5,6 +5,7 @@ use std::{ }; use burn::{ + nn::PReluConfig, record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings}, tensor::{DataSerialize, Element}, }; @@ -701,7 +702,7 @@ impl OnnxGraph { let input = node.inputs.first().unwrap().to_tensor_type(); let output = node.outputs.first().unwrap().to_tensor_type(); let weight = extract_data_serialize::(1, &node).unwrap(); - let config = prelu_config(&node); + let config = PReluConfig::new(); let name = &node.name; PReluNode::::new(name, input, output, weight, config) }