PReLu ONNX import #1721

Merged
merged 5 commits into from May 4, 2024
Binary file added crates/burn-import/onnx-tests/tests/prelu/prelu.onnx
Binary file not shown.
49 changes: 49 additions & 0 deletions crates/burn-import/onnx-tests/tests/prelu/prelu.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: prelu.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.relu1 = nn.PReLU()

    def forward(self, x):
        x = self.relu1(x)
        return x


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)

    torch.set_printoptions(precision=8)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")

    file_name = "prelu.onnx"
    test_input = torch.randn(2, 3, device=device)
    torch.onnx.export(model, test_input, file_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(file_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    print("Test input data shape: {}".format(test_input.shape))
    output = model.forward(test_input)
    print("Test output data shape: {}".format(output.shape))

    print("Test output: {}".format(output))


if __name__ == '__main__':
    main()

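The onnx-tests harness exercises this model from a companion Rust test. A minimal sketch of what such a test looks like, assuming the crate's usual `include_models!` setup; the backend, the `Model::new` constructor, and the input values are illustrative assumptions rather than part of this diff:

```rust
// Hypothetical companion test (e.g. in onnx-tests/tests/onnx_tests.rs).
// `include_models!` compiles prelu.onnx into a `prelu` module at build time.
include_models!(prelu);

#[cfg(test)]
mod prelu_test {
    use super::*;
    use burn::tensor::Tensor;

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn prelu() {
        let device = Default::default();
        let model: prelu::Model<Backend> = prelu::Model::new(&device);

        // Mirrors the (2, 3) input shape used by prelu.py; values are placeholders.
        let input = Tensor::<Backend, 2>::from_floats(
            [[0.33, 0.13, 0.23], [0.23, -1.12, -0.19]],
            &device,
        );
        let output = model.forward(input);
        assert_eq!(output.dims(), [2, 3]);
    }
}
```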
4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,5 +1,6 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::prelu::PReluNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
    avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
@@ -85,6 +86,7 @@ pub enum Node<PS: PrecisionSettings> {
    Conv1d(Conv1dNode<PS>),
    Conv2d(Conv2dNode<PS>),
    ConvTranspose2d(ConvTranspose2dNode<PS>),
    PRelu(PReluNode<PS>),
    Dropout(DropoutNode),
    Gather(GatherNode),
    GlobalAvgPool(GlobalAvgPoolNode),
@@ -111,6 +113,7 @@ macro_rules! match_all {
            Node::Conv1d(node) => $func(node),
            Node::Conv2d(node) => $func(node),
            Node::ConvTranspose2d(node) => $func(node),
            Node::PRelu(node) => $func(node),
            Node::Dropout(node) => $func(node),
            Node::Gather(node) => $func(node),
            Node::GlobalAvgPool(node) => $func(node),
@@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Conv1d(_) => "conv1d",
            Node::Conv2d(_) => "conv2d",
            Node::ConvTranspose2d(_) => "conv_transpose2d",
            Node::PRelu(_) => "prelu",
            Node::Dropout(_) => "dropout",
            Node::Gather(_) => "gather",
            Node::GlobalAvgPool(_) => "global_avg_pool",
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -17,6 +17,7 @@ pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool2d;
pub(crate) mod prelu;
pub(crate) mod reshape;
pub(crate) mod unary;
pub(crate) mod unsqueeze;
100 changes: 100 additions & 0 deletions crates/burn-import/src/burn/node/prelu.rs
@@ -0,0 +1,100 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use burn::{
    module::{Param, ParamId},
    nn::{PReluConfig, PReluRecord},
    record::{PrecisionSettings, Record},
    tensor::{DataSerialize, Tensor},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Clone, Debug)]
pub struct PReluNode<PS: PrecisionSettings> {
    pub field: OtherType,
    pub input: TensorType,
    pub output: TensorType,
    pub alpha: DataSerialize<PS::FloatElem>,
    pub config: PReluConfig,
}

impl<PS: PrecisionSettings> PReluNode<PS> {
    pub fn new<S: AsRef<str>>(
        name: S,
        input: TensorType,
        output: TensorType,
        alpha: DataSerialize<PS::FloatElem>,
        config: PReluConfig,
    ) -> Self {
        Self {
            field: OtherType::new(
                name,
                quote! {
                    PRelu<B>
                },
            ),
            input,
            output,
            alpha,
            config,
        }
    }
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn field_type(&self) -> Option<Type> {
        Some(Type::Other(self.field.clone()))
    }

    fn field_init(&self) -> Option<TokenStream> {
        let name = &self.field.name;

        let num_parameters = self.config.num_parameters.to_tokens();
        let alpha = self.config.alpha.to_tokens();
        // `PReluConfig::new()` takes no arguments (both fields are defaulted),
        // so the generated code sets them through the builder, mirroring
        // `prelu_config` in op_configuration.rs.
        let tokens = quote! {
            let #name = PReluConfig::new()
                .with_num_parameters(#num_parameters)
                .with_alpha(#alpha)
                .init(device);
        };

        Some(tokens)
    }

    fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let device = Default::default();
        let record = PReluRecord::<SerializationBackend> {
            alpha: Param::initialized(
                ParamId::new(),
                Tensor::from_data(self.alpha.clone().convert(), &device),
            ),
        };

        let item = Record::into_item::<PS>(record);
        item.serialize(serializer)
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let field = &self.field.name;

        quote! {
            let #output = self.#field.forward(#input);
        }
    }

    fn register_imports(&self, imports: &mut BurnImports) {
        // Register each import once; `PRelu` and `PReluConfig` are both
        // re-exported at `burn::nn` (see the `use` at the top of this file).
        imports.register("burn::nn::PRelu");
        imports.register("burn::nn::PReluConfig");
    }

    fn into_node(self) -> Node<PS> {
        Node::PRelu(self)
    }
}
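Putting `field_init` and `forward` together, the code generated for a model containing a single PRelu node behaves like the following standalone sketch; the field name `prelu1`, the backend, and the config values are illustrative, not part of this diff:

```rust
use burn::backend::NdArray;
use burn::nn::{PRelu, PReluConfig};
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();

    // What `field_init` emits, with illustrative num_parameters/alpha values:
    let prelu1: PRelu<NdArray<f32>> = PReluConfig::new()
        .with_num_parameters(1)
        .with_alpha(0.25)
        .init(&device);

    // What `forward` emits for this node:
    let input = Tensor::<NdArray<f32>, 2>::from_floats([[1.0, -2.0, 3.0]], &device);
    let output = prelu1.forward(input);
    println!("{}", output); // negative inputs are scaled by the learned alpha
}
```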
17 changes: 16 additions & 1 deletion crates/burn-import/src/onnx/op_configuration.rs
@@ -1,7 +1,7 @@
use burn::nn::{
    conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
    pool::{AvgPool2dConfig, MaxPool2dConfig},
    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
    BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PReluConfig, PaddingConfig1d,
    PaddingConfig2d,
};

@@ -120,6 +120,21 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
        .with_padding(padding)
        .with_dilation([dilations[0] as usize, dilations[1] as usize])
}

pub fn prelu_config(curr: &Node) -> PReluConfig {
    // ONNX PRelu passes the slope as a second input rather than as attributes,
    // but some exporters emit `alpha`/`num_parameters` attributes as well; pick
    // them up when present and otherwise keep these defaults.
    let mut alpha = 0.01;
    let mut num_parameters = 0;
    for (key, value) in curr.attrs.iter() {
        match key.as_str() {
            "alpha" => alpha = value.clone().into_f32(),
            "num_parameters" => num_parameters = value.clone().into_i32(),
            _ => {}
        }
    }

    PReluConfig::new()
        .with_num_parameters(num_parameters as usize)
        .with_alpha(alpha as f64)
}

pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
let mut attrs = curr.attrs.clone();
10 changes: 10 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -30,6 +30,7 @@ use crate::{
            mask_where::WhereNode,
            matmul::MatmulNode,
            max_pool2d::MaxPool2dNode,
            prelu::PReluNode,
            reshape::ReshapeNode,
            unary::UnaryNode,
            unsqueeze::UnsqueezeNode,
@@ -236,6 +237,7 @@ impl OnnxGraph {
                NodeType::Conv1d => graph.register(Self::conv1d_conversion::<PS>(node)),
                NodeType::Conv2d => graph.register(Self::conv2d_conversion::<PS>(node)),
                NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
                NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
                NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
                NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
                NodeType::Neg => graph.register(Self::neg_conversion(node)),
@@ -695,6 +697,14 @@ impl OnnxGraph {
        MaxPool2dNode::new(name, input, output, config)
    }

    fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode<PS> {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();
        // Input 1 is the ONNX `slope` tensor, serialized as the PRelu weight.
        let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
        let config = prelu_config(&node);
        let name = &node.name;
        PReluNode::<PS>::new(name, input, output, weight, config)
    }

    fn conv_transpose2d_conversion<PS: PrecisionSettings>(node: Node) -> ConvTranspose2dNode<PS> {
        let input = node.inputs.first().unwrap().to_tensor_type();
        let output = node.outputs.first().unwrap().to_tensor_type();
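Once merged, a PRelu-bearing ONNX file is imported like any other model through burn-import's build-script entry point. A minimal sketch, with an illustrative path:

```rust
// build.rs of a crate that depends on burn-import.
use burn_import::onnx::ModelGen;

fn main() {
    // Generates the Rust model (including the PRelu node) into OUT_DIR.
    ModelGen::new()
        .input("src/model/prelu.onnx")
        .out_dir("model/")
        .run_from_script();
}
```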