Skip to content

Commit

Permalink
chore: merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta committed Jun 3, 2024
2 parents bad96c3 + 92b0067 commit fadccda
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 32 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ represent the corresponding Burn Op.
| [EyeLike][55] |||
| [Flatten][56] |||
| [Floor][57] |||
| [Gather][58] | ||
| [Gather][58] | ||
| [GatherElements][59] |||
| [GatherND][60] |||
| [Gelu][61] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ fn main() {
.input("tests/exp/exp.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/gather_elements/gather_elements.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/layer_norm/layer_norm.onnx")
Expand Down
20 changes: 10 additions & 10 deletions crates/burn-import/onnx-tests/tests/gather/gather.onnx
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
pytorch2.2.2:�
a
onnx::GatherElements_0
onnx::GatherElements_12/GatherElements"GatherElements*
pytorch2.1.1:�
A
onnx::Gather_0
onnx::Gather_12/Gather"Gather*
axis�
main_graphZ(
onnx::GatherElements_0
main_graphZ
onnx::Gather_0


Z(
onnx::GatherElements_1


Z
onnx::Gather_1


b
2

Expand Down
16 changes: 8 additions & 8 deletions crates/burn-import/onnx-tests/tests/gather/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def __init__(self):
super(Model, self).__init__()

def forward(self, x, index):
x = torch.gather(x, 1, index)
return x
gathered = torch.index_select(x, 1, index)
return gathered


def main():
Expand All @@ -24,19 +24,19 @@ def main():
model.eval()
device = torch.device("cpu")
onnx_name = "gather.onnx"
dummy_input = torch.randn(2, 2, device=device)
dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)

dummy_input = torch.randn(2, 3, device=device)
dummy_index = torch.tensor([0, 2], device=device, dtype=torch.int64)

torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0],
[3.0, 4.0]])
test_index = torch.tensor([[0, 0],
[1, 0]])
test_input = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
test_index = torch.tensor([0, 2], dtype=torch.int64)

print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
pytorch2.1.1:�
a
onnx::GatherElements_0
onnx::GatherElements_12/GatherElements"GatherElements*
axis�
main_graphZ(
onnx::GatherElements_0


Z(
onnx::GatherElements_1


b
2


B
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/gather/gather_elements.onnx
# note that the ONNX specification for `GatherElements` corresponds to PyTorch's/Burn's `gather` function

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, index):
x = torch.gather(x, 1, index)
return x


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "gather_elements.onnx"
dummy_input = torch.randn(2, 2, device=device)
dummy_index = torch.randint(high=2, size=(2, 2), device=device, dtype=torch.int64)

torch.onnx.export(model, (dummy_input, dummy_index), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[1.0, 2.0],
[3.0, 4.0]])
test_index = torch.tensor([[0, 0],
[1, 0]])

print("Test input data: {}, {}".format(test_input, test_index))
output = model.forward(test_input, test_index)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
17 changes: 16 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include_models!(
expand,
flatten,
gather,
gather_elements,
gelu,
global_avr_pool,
layer_norm,
Expand Down Expand Up @@ -390,9 +391,23 @@ mod tests {

#[test]
fn gather() {
// Initialize the model with weights (loaded from the exported file)
let model: gather::Model<Backend> = gather::Model::default();

let device = Default::default();

let input = Tensor::<Backend, 2>::from_floats([[1., 2., 3.], [4., 5., 6.]], &device);
let index = Tensor::<Backend, 1, Int>::from_ints([0, 2], &device);
let output = model.forward(input, index);
let expected = Data::from([[1., 3.], [4., 6.]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn gather_elements() {
// Initialize the model with weights (loaded from the exported file)
let model: gather_elements::Model<Backend> = gather_elements::Model::default();

let device = Default::default();
// Run the model
let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.]], &device);
Expand Down
14 changes: 9 additions & 5 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use super::{
batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode,
constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode,
linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
max_pool2d::MaxPool2dNode, prelu::PReluNode, random_normal::RandomNormalNode,
random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode,
squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -92,6 +93,7 @@ pub enum Node<PS: PrecisionSettings> {
Dropout(DropoutNode),
Expand(ExpandNode),
Gather(GatherNode),
GatherElements(GatherElementsNode),
GlobalAvgPool(GlobalAvgPoolNode),
LayerNorm(LayerNormNode<PS>),
Linear(LinearNode<PS>),
Expand Down Expand Up @@ -128,6 +130,7 @@ macro_rules! match_all {
Node::Dropout(node) => $func(node),
Node::Expand(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GatherElements(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Node::LayerNorm(node) => $func(node),
Node::Linear(node) => $func(node),
Expand Down Expand Up @@ -174,6 +177,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Dropout(_) => "dropout",
Node::Expand(_) => "expand",
Node::Gather(_) => "gather",
Node::GatherElements(_) => "gather_elements",
Node::GlobalAvgPool(_) => "global_avg_pool",
Node::LayerNorm(_) => "layer_norm",
Node::Linear(_) => "linear",
Expand Down
10 changes: 5 additions & 5 deletions crates/burn-import/src/burn/node/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherNode {
let output = &self.output.name;

quote! {
let #output = #input.gather(#dim, #index);
let #output = #input.select(#dim, #index);
}
}

Expand All @@ -62,9 +62,9 @@ mod tests {

graph.register(GatherNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
TensorType::new_int("tensor2", 1),
TensorType::new_float("tensor3", 2),
1,
0,
));

graph.register_input_output(
Expand Down Expand Up @@ -98,9 +98,9 @@ mod tests {
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>
tensor2: Tensor<B, 1, Int>
) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);
let tensor3 = tensor1.select(0, tensor2);

tensor3
}
Expand Down
112 changes: 112 additions & 0 deletions crates/burn-import/src/burn/node/gather_elements.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct GatherElementsNode {
pub input: TensorType,
pub index: TensorType,
pub output: TensorType,
pub dim: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for GatherElementsNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<crate::burn::Type> {
vec![
Type::Tensor(self.input.clone()),
Type::Tensor(self.index.clone()),
]
}

fn forward(
&self,
scope: &mut crate::burn::Scope,
node_position: usize,
) -> proc_macro2::TokenStream {
let dim = self.dim.to_tokens();
let input = scope.tensor_use_owned(&self.input, node_position);
let index = scope.tensor_use_owned(&self.index, node_position);
let output = &self.output.name;

quote! {
let #output = #input.gather(#dim, #index);
}
}

fn into_node(self) -> super::Node<PS> {
Node::GatherElements(self)
}
}

#[cfg(test)]
mod tests {

use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{gather_elements::GatherElementsNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_gather_elements() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(GatherElementsNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
TensorType::new_float("tensor3", 2),
1,
));

graph.register_input_output(
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);

let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>
) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);

tensor3
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub(crate) mod conv_transpose_2d;
pub(crate) mod dropout;
pub(crate) mod expand;
pub(crate) mod gather;
pub(crate) mod gather_elements;
pub(crate) mod global_avg_pool;
pub(crate) mod layer_norm;
pub(crate) mod linear;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-import/src/burn/node/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ mod tests {
};

#[test]
fn test_codegen_concat() {
fn test_codegen_sum() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(SumNode::new(
Expand Down
Loading

0 comments on commit fadccda

Please sign in to comment.