From b9a834c41ffec0d193d23ddd28ce1040ffdd58e7 Mon Sep 17 00:00:00 2001
From: Charles Bournhonesque
Date: Sat, 3 Aug 2024 21:43:24 -0400
Subject: [PATCH 1/6] copy the accuracy metric as a starting point for top-k
 accuracy

---
 crates/burn-train/src/metric/top_k_acc.rs | 139 ++++++++++++++++++++++
 1 file changed, 139 insertions(+)
 create mode 100644 crates/burn-train/src/metric/top_k_acc.rs

diff --git a/crates/burn-train/src/metric/top_k_acc.rs b/crates/burn-train/src/metric/top_k_acc.rs
new file mode 100644
index 0000000000..ad24fe077f
--- /dev/null
+++ b/crates/burn-train/src/metric/top_k_acc.rs
@@ -0,0 +1,139 @@
+use core::marker::PhantomData;
+
+use super::state::{FormatOptions, NumericMetricState};
+use super::{MetricEntry, MetricMetadata};
+use crate::metric::{Metric, Numeric};
+use burn_core::tensor::backend::Backend;
+use burn_core::tensor::{ElementConversion, Int, Tensor};
+
+/// The accuracy metric.
+#[derive(Default)]
+pub struct AccuracyMetric<B: Backend> {
+    state: NumericMetricState,
+    pad_token: Option<usize>,
+    _b: PhantomData<B>,
+}
+
+/// The [accuracy metric](AccuracyMetric) input type.
+#[derive(new)]
+pub struct AccuracyInput<B: Backend> {
+    outputs: Tensor<B, 2>,
+    targets: Tensor<B, 1, Int>,
+}
+
+impl<B: Backend> AccuracyMetric<B> {
+    /// Creates the metric.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Sets the pad token.
+    pub fn with_pad_token(mut self, index: usize) -> Self {
+        self.pad_token = Some(index);
+        self
+    }
+}
+
+impl<B: Backend> Metric for AccuracyMetric<B> {
+    const NAME: &'static str = "Accuracy";
+
+    type Input = AccuracyInput<B>;
+
+    fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
+        let [batch_size, _n_classes] = input.outputs.dims();
+
+        let targets = input.targets.clone().to_device(&B::Device::default());
+        let outputs = input
+            .outputs
+            .clone()
+            .argmax(1)
+            .to_device(&B::Device::default())
+            .reshape([batch_size]);
+
+        let accuracy = match self.pad_token {
+            Some(pad_token) => {
+                let mask = targets.clone().equal_elem(pad_token as i64);
+                let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0);
+                let num_pad = mask.int().sum().into_scalar().elem::<f64>();
+
+                matches.sum().into_scalar().elem::<f64>() / (batch_size as f64 - num_pad)
+            }
+            None => {
+                outputs
+                    .equal(targets)
+                    .int()
+                    .sum()
+                    .into_scalar()
+                    .elem::<f64>()
+                    / batch_size as f64
+            }
+        };
+
+        self.state.update(
+            100.0 * accuracy,
+            batch_size,
+            FormatOptions::new(Self::NAME).unit("%").precision(2),
+        )
+    }
+
+    fn clear(&mut self) {
+        self.state.reset()
+    }
+}
+
+impl<B: Backend> Numeric for AccuracyMetric<B> {
+    fn value(&self) -> f64 {
+        self.state.value()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+
+    #[test]
+    fn test_accuracy_without_padding() {
+        let device = Default::default();
+        let mut metric = AccuracyMetric::<TestBackend>::new();
+        let input = AccuracyInput::new(
+            Tensor::from_data(
+                [
+                    [0.0, 0.2, 0.8], // 2
+                    [1.0, 2.0, 0.5], // 1
+                    [0.4, 0.1, 0.2], // 0
+                    [0.6, 0.7, 0.2], // 1
+                ],
+                &device,
+            ),
+            Tensor::from_data([2, 2, 1, 1], &device),
+        );
+
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        assert_eq!(50.0, metric.value());
+    }
+
+    #[test]
+    fn test_accuracy_with_padding() {
+        let device = Default::default();
+        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);
+        let input = AccuracyInput::new(
+            Tensor::from_data(
+                [
+                    [0.0, 0.2, 0.8, 0.0], // 2
+                    [1.0, 2.0, 0.5, 0.0], // 1
+                    [0.4, 0.1, 0.2, 0.0], // 0
+                    [0.6, 0.7, 0.2, 0.0], // 1
+                    [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
+                    [0.0, 0.1, 0.2, 0.0], // Error on padding should not count
+                    [0.6, 0.0, 0.2, 0.0], // Error on padding should not count
+                ],
+                &device,
+            ),
+            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
+        );

+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        assert_eq!(50.0, metric.value());
+    }
+}
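The padded-accuracy rule above is easy to sanity-check without a backend: padded samples are dropped from both the numerator and the denominator. A minimal plain-Rust sketch, separate from the patch (the `masked_accuracy` helper name and the toy data are illustrative):

```rust
/// Accuracy over (prediction, target) pairs, skipping targets equal to `pad`.
fn masked_accuracy(preds: &[i64], targets: &[i64], pad: Option<i64>) -> f64 {
    let kept: Vec<(&i64, &i64)> = preds
        .iter()
        .zip(targets.iter())
        .filter(|&(_, &t)| pad != Some(t)) // padded samples are dropped entirely
        .collect();
    let hits = kept.iter().filter(|(p, t)| p == t).count();
    hits as f64 / kept.len() as f64
}

fn main() {
    // Mirrors test_accuracy_without_padding: row-wise argmax is [2, 1, 0, 1].
    assert_eq!(masked_accuracy(&[2, 1, 0, 1], &[2, 2, 1, 1], None), 0.5);
    // Mirrors test_accuracy_with_padding: the three pad targets (3) are ignored,
    // including the row where padding itself was predicted.
    assert_eq!(
        masked_accuracy(&[2, 1, 0, 1, 3, 2, 0], &[2, 2, 1, 1, 3, 3, 3], Some(3)),
        0.5
    );
}
```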
From 6fb4e3107470247c931510f84bb9c79cfafb7106 Mon Sep 17 00:00:00 2001
From: Charles Bournhonesque
Date: Sat, 3 Aug 2024 21:51:17 -0400
Subject: [PATCH 2/6] add top k acc

---
 crates/burn-train/src/metric/mod.rs       |  5 ++
 crates/burn-train/src/metric/top_k_acc.rs | 88 +++++++++++++----------
 2 files changed, 56 insertions(+), 37 deletions(-)

diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs
index 7d443da067..e2f5bd279e 100644
--- a/crates/burn-train/src/metric/mod.rs
+++ b/crates/burn-train/src/metric/mod.rs
@@ -15,6 +15,9 @@ mod loss;
 #[cfg(feature = "metrics")]
 mod memory_use;
 
+#[cfg(feature = "metrics")]
+mod top_k_acc;
+
 pub use acc::*;
 pub use base::*;
 #[cfg(feature = "metrics")]
@@ -28,6 +31,8 @@ pub use learning_rate::*;
 pub use loss::*;
 #[cfg(feature = "metrics")]
 pub use memory_use::*;
+#[cfg(feature = "metrics")]
+pub use top_k_acc::*;
 
 pub(crate) mod processor;
 /// Module responsible to save and exposes data collected during training.
diff --git a/crates/burn-train/src/metric/top_k_acc.rs b/crates/burn-train/src/metric/top_k_acc.rs
index ad24fe077f..5a29b78e30 100644
--- a/crates/burn-train/src/metric/top_k_acc.rs
+++ b/crates/burn-train/src/metric/top_k_acc.rs
@@ -6,25 +6,35 @@ use crate::metric::{Metric, Numeric};
 use burn_core::tensor::backend::Backend;
 use burn_core::tensor::{ElementConversion, Int, Tensor};
 
-/// The accuracy metric.
+/// The Top-K accuracy metric.
+///
+/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).
 #[derive(Default)]
-pub struct AccuracyMetric<B: Backend> {
+pub struct TopKAccuracyMetric<B: Backend> {
+    k: usize,
     state: NumericMetricState,
+    /// If specified, targets equal to this value will be considered padding and will not count
+    /// towards the metric
     pad_token: Option<usize>,
     _b: PhantomData<B>,
 }
 
-/// The [accuracy metric](AccuracyMetric) input type.
+/// The [top-k accuracy metric](TopKAccuracyMetric) input type.
 #[derive(new)]
-pub struct AccuracyInput<B: Backend> {
+pub struct TopKAccuracyInput<B: Backend> {
+    /// The outputs (batch_size, num_classes)
     outputs: Tensor<B, 2>,
+    /// The labels (batch_size)
     targets: Tensor<B, 1, Int>,
 }
 
-impl<B: Backend> AccuracyMetric<B> {
+impl<B: Backend> TopKAccuracyMetric<B> {
     /// Creates the metric.
-    pub fn new() -> Self {
-        Self::default()
+    pub fn new(k: usize) -> Self {
+        Self {
+            k,
+            ..Default::default()
+        }
     }
 
     /// Sets the pad token.
@@ -34,40 +44,44 @@ impl<B: Backend> AccuracyMetric<B> {
     }
 }
 
-impl<B: Backend> Metric for AccuracyMetric<B> {
-    const NAME: &'static str = "Accuracy";
+impl<B: Backend> Metric for TopKAccuracyMetric<B> {
+    const NAME: &'static str = "Top-K Accuracy";
 
-    type Input = AccuracyInput<B>;
+    type Input = TopKAccuracyInput<B>;
 
-    fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
+    fn update(&mut self, input: &TopKAccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
         let [batch_size, _n_classes] = input.outputs.dims();
 
         let targets = input.targets.clone().to_device(&B::Device::default());
+
         let outputs = input
             .outputs
             .clone()
-            .argmax(1)
+            .argsort_descending(1)
+            .narrow(1, 0, self.k)
             .to_device(&B::Device::default())
-            .reshape([batch_size]);
+            .reshape([batch_size, self.k]);
 
-        let accuracy = match self.pad_token {
+        let (targets, num_pad) = match self.pad_token {
             Some(pad_token) => {
+                // we ignore the samples where the target is equal to the pad token
                 let mask = targets.clone().equal_elem(pad_token as i64);
-                let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0);
-                let num_pad = mask.int().sum().into_scalar().elem::<f64>();
-
-                matches.sum().into_scalar().elem::<f64>() / (batch_size as f64 - num_pad)
+                let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
+                (targets.clone().mask_fill(mask, -1_i64), num_pad)
             }
-            None => {
-                outputs
-                    .equal(targets)
+            None => (targets.clone(), 0_f64),
+        };
+
+        let accuracy = targets
+                .reshape([batch_size, 1])
+                .repeat_dim(1, self.k)
+                .equal(outputs)
                     .int()
                     .sum()
                    .into_scalar()
                    .elem::<f64>()
-                / batch_size as f64
-            }
-        };
+            / (batch_size as f64 - num_pad);
 
         self.state.update(
             100.0 * accuracy,
@@ -81,7 +95,7 @@ impl<B: Backend> Metric for AccuracyMetric<B> {
     }
 }
 
-impl<B: Backend> Numeric for AccuracyMetric<B> {
+impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
     fn value(&self) -> f64 {
         self.state.value()
     }
@@ -95,14 +109,14 @@ mod tests {
     #[test]
     fn test_accuracy_without_padding() {
         let device = Default::default();
-        let mut metric = AccuracyMetric::<TestBackend>::new();
-        let input = AccuracyInput::new(
+        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
+        let input = TopKAccuracyInput::new(
             Tensor::from_data(
                 [
-                    [0.0, 0.2, 0.8], // 2
-                    [1.0, 2.0, 0.5], // 1
-                    [0.4, 0.1, 0.2], // 0
-                    [0.6, 0.7, 0.2], // 1
+                    [0.0, 0.2, 0.8], // 2, 1
+                    [1.0, 2.0, 0.5], // 1, 0
+                    [0.4, 0.1, 0.2], // 0, 2
+                    [0.6, 0.7, 0.2], // 1, 0
                 ],
                 &device,
             ),
@@ -116,14 +130,14 @@ mod tests {
     #[test]
     fn test_accuracy_with_padding() {
         let device = Default::default();
-        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);
-        let input = AccuracyInput::new(
+        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
+        let input = TopKAccuracyInput::new(
             Tensor::from_data(
                 [
-                    [0.0, 0.2, 0.8, 0.0], // 2
-                    [1.0, 2.0, 0.5, 0.0], // 1
-                    [0.4, 0.1, 0.2, 0.0], // 0
-                    [0.6, 0.7, 0.2, 0.0], // 1
+                    [0.0, 0.2, 0.8, 0.0], // 2, 1
+                    [1.0, 2.0, 0.5, 0.0], // 1, 0
+                    [0.4, 0.1, 0.2, 0.0], // 0, 2
+                    [0.6, 0.7, 0.2, 0.0], // 1, 0
                     [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
                     [0.0, 0.1, 0.2, 0.0], // Error on padding should not count
                     [0.6, 0.0, 0.2, 0.0], // Error on padding should not count
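Compared to plain accuracy, patch 2 changes only the membership test: a sample counts as correct if its target appears among the k highest-scoring classes, and padded targets are first rewritten to -1 so they can never match. A plain-Rust sketch of the same rule on the test data above, separate from the patch (the `top_k` helper is illustrative):

```rust
/// Indices of the `k` largest values in `row`, highest score first.
fn top_k(row: &[f64], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..row.len()).collect();
    idx.sort_by(|&a, &b| row[b].partial_cmp(&row[a]).unwrap());
    idx.truncate(k);
    idx
}

fn main() {
    // Mirrors test_accuracy_without_padding from patch 2, with k = 2.
    let logits: [[f64; 3]; 4] = [
        [0.0, 0.2, 0.8], // top-2 classes: [2, 1]
        [1.0, 2.0, 0.5], // top-2 classes: [1, 0]
        [0.4, 0.1, 0.2], // top-2 classes: [0, 2]
        [0.6, 0.7, 0.2], // top-2 classes: [1, 0]
    ];
    let targets: [usize; 4] = [2, 2, 1, 1];
    let hits = logits
        .iter()
        .zip(targets.iter())
        .filter(|&(row, &t)| top_k(row, 2).contains(&t))
        .count();
    assert_eq!(hits as f64 / targets.len() as f64, 0.5); // rows 0 and 3 hit
}
```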
From 57bcb3974369a7503a6db972c0d55766f2d255e9 Mon Sep 17 00:00:00 2001
From: cBournhonesque
Date: Tue, 6 Aug 2024 15:19:12 -0400
Subject: [PATCH 3/6] add onnx mean

---
 crates/burn-import/SUPPORTED-ONNX-OPS.md      |   2 +-
 crates/burn-import/onnx-tests/build.rs        |   1 +
 .../onnx-tests/tests/mean/mean.onnx           |  23 ++++
 .../burn-import/onnx-tests/tests/mean/mean.py |  41 +++++++
 .../onnx-tests/tests/onnx_tests.rs            |  16 +++
 crates/burn-import/src/burn/node/base.rs      |   4 +
 crates/burn-import/src/burn/node/mean.rs      | 109 ++++++++++++++++++
 crates/burn-import/src/burn/node/mod.rs       |   1 +
 crates/burn-import/src/onnx/to_burn.rs        |   9 ++
 9 files changed, 205 insertions(+), 1 deletion(-)
 create mode 100644 crates/burn-import/onnx-tests/tests/mean/mean.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/mean/mean.py
 create mode 100644 crates/burn-import/src/burn/node/mean.rs

diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md
index 571cd7b565..b880c2c085 100644
--- a/crates/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -108,7 +108,7 @@ represent the corresponding Burn Op.
 | [MaxPool2d][98]                  | ✅    | ✅    |
 | [MaxRoiPool][99]                 | ❌    | ❌    |
 | [MaxUnpool][100]                 | ❌    | ❌    |
-| [Mean][101]                      | ❌    | ✅    |
+| [Mean][101]                      | ✅    | ✅    |
 | [MeanVarianceNormalization][102] | ❌    | ❌    |
 | [MelWeightMatrix][103]           | ❌    | ❌    |
 | [Min][104]                       | ✅    | ✅    |
diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index 67e2f091a3..9ace525b1c 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -51,6 +51,7 @@ fn main() {
         .input("tests/maxpool1d/maxpool1d.onnx")
         .input("tests/maxpool2d/maxpool2d.onnx")
         .input("tests/min/min.onnx")
+        .input("tests/mean/mean.onnx")
         .input("tests/mul/mul.onnx")
         .input("tests/neg/neg.onnx")
         .input("tests/not/not.onnx")
diff --git a/crates/burn-import/onnx-tests/tests/mean/mean.onnx b/crates/burn-import/onnx-tests/tests/mean/mean.onnx
new file mode 100644
index 0000000000..39c86e7fc1
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/mean/mean.onnx
@@ -0,0 +1,23 @@
+
+
+mean-model:
+&
+input1
+input2
+input3output"Mean MeanGraphZ
+input1
+
+
+Z
+input2
+
+
+Z
+input3
+
+
+b
+output
+
+
+B
\ No newline at end of file
diff --git a/crates/burn-import/onnx-tests/tests/mean/mean.py b/crates/burn-import/onnx-tests/tests/mean/mean.py
new file mode 100644
index 0000000000..dc5b99cea8
--- /dev/null
+++ b/crates/burn-import/onnx-tests/tests/mean/mean.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+# used to generate model: onnx-tests/tests/mean/mean.onnx
+
+import onnx
+import onnx.helper
+import onnx.checker
+import numpy as np
+
+# Create input tensors
+input1 = onnx.helper.make_tensor_value_info('input1', onnx.TensorProto.FLOAT, [3])
+input2 = onnx.helper.make_tensor_value_info('input2', onnx.TensorProto.FLOAT, [3])
+input3 = onnx.helper.make_tensor_value_info('input3', onnx.TensorProto.FLOAT, [3])
+
+# Create output tensor
+output = onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [3])
+
+# Create the Mean node
+mean_node = onnx.helper.make_node(
+    'Mean',
+    inputs=['input1', 'input2', 'input3'],
+    outputs=['output']
+)
+
+# Create the graph (GraphProto)
+graph_def = onnx.helper.make_graph(
+    nodes=[mean_node],
+    name='MeanGraph',
+    inputs=[input1, input2, input3],
+    outputs=[output]
+)
+
+# Create the model (ModelProto)
+model_def = onnx.helper.make_model(graph_def, producer_name='mean-model')
+onnx.checker.check_model(model_def)
+
+# Save the ONNX model
+onnx.save(model_def, 'mean.onnx')
+
+print("ONNX model 'mean.onnx' generated successfully.")
+
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
index 2fc5ec6758..31ac73a6fd 100644
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -60,6 +60,7 @@ include_models!(
     maxpool1d,
     maxpool2d,
     min,
+    mean,
     mul,
     neg,
     not,
@@ -208,6 +209,21 @@ mod tests {
         output.to_data().assert_eq(&expected, true);
     }
 
+    #[test]
+    fn mean_tensor_and_tensor() {
+        let device = Default::default();
+        let model: mean::Model<Backend> = mean::Model::default();
+
+        let input1 = Tensor::<Backend, 1>::from_floats([1., 2., 3., 4.], &device);
+        let input2 = Tensor::<Backend, 1>::from_floats([2., 2., 4., 0.], &device);
+        let input3 = Tensor::<Backend, 1>::from_floats([3., 2., 5., -4.], &device);
+
+        let output = model.forward(input1, input2, input3);
+        let expected = TensorData::from([2.0f32, 2., 4., 0.]);
+
+        output.to_data().assert_eq(&expected, true);
+    }
+
     #[test]
     fn mul_scalar_with_tensor_and_tensor_with_tensor() {
         // Initialize the model with weights (loaded from the exported file)
diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index 751cbcb471..45e6c0461f 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -18,6 +18,7 @@ use burn::backend::NdArray;
 use burn::record::PrecisionSettings;
 use proc_macro2::TokenStream;
 use serde::Serialize;
+use crate::burn::node::mean::MeanNode;
 
 /// Backend used for serialization.
 pub type SerializationBackend = NdArray<f32>;
@@ -105,6 +106,7 @@ pub enum Node<PS: PrecisionSettings> {
     Matmul(MatmulNode),
     MaxPool1d(MaxPool1dNode),
     MaxPool2d(MaxPool2dNode),
+    Mean(MeanNode),
     Pad(PadNode),
     Range(RangeNode),
     Reshape(ReshapeNode),
@@ -151,6 +153,7 @@ macro_rules! match_all {
             Node::Matmul(node) => $func(node),
             Node::MaxPool1d(node) => $func(node),
             Node::MaxPool2d(node) => $func(node),
+            Node::Mean(node) => $func(node),
             Node::Pad(node) => $func(node),
             Node::Range(node) => $func(node),
             Node::Reshape(node) => $func(node),
@@ -205,6 +208,7 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::Matmul(_) => "matmul",
             Node::MaxPool1d(_) => "max_pool1d",
             Node::MaxPool2d(_) => "max_pool2d",
+            Node::Mean(_) => "mean",
             Node::Pad(_) => "pad",
             Node::Range(_) => "range",
             Node::Reshape(_) => "reshape",
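ONNX `Mean` is the element-wise average of its inputs, which is what the `mean_tensor_and_tensor` test above encodes. A quick plain-Rust cross-check of the expected values, independent of the generated model:

```rust
fn main() {
    let (a, b, c) = (
        [1.0f32, 2.0, 3.0, 4.0],
        [2.0f32, 2.0, 4.0, 0.0],
        [3.0f32, 2.0, 5.0, -4.0],
    );
    // Element-wise mean of the three inputs, one output value per position.
    let mean: Vec<f32> = a
        .iter()
        .zip(b.iter())
        .zip(c.iter())
        .map(|((x, y), z)| (x + y + z) / 3.0)
        .collect();
    // Matches TensorData::from([2.0f32, 2., 4., 0.]) in the test above.
    assert_eq!(mean, vec![2.0, 2.0, 4.0, 0.0]);
}
```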
diff --git a/crates/burn-import/src/burn/node/mean.rs b/crates/burn-import/src/burn/node/mean.rs
new file mode 100644
index 0000000000..1fb96c7226
--- /dev/null
+++ b/crates/burn-import/src/burn/node/mean.rs
@@ -0,0 +1,109 @@
+use super::{Node, NodeCodegen};
+use crate::burn::{Scope, TensorType, Type};
+
+use burn::record::PrecisionSettings;
+use proc_macro2::TokenStream;
+use quote::quote;
+
+#[derive(Debug, Clone, new)]
+pub struct MeanNode {
+    pub inputs: Vec<TensorType>,
+    pub output: TensorType,
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for MeanNode {
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+
+    fn input_types(&self) -> Vec<Type> {
+        self.inputs
+            .iter()
+            .map(|t| Type::Tensor(t.clone()))
+            .collect()
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let inputs = self
+            .inputs
+            .iter()
+            .map(|t| scope.tensor_use_owned(t, node_position));
+
+        let output = &self.output.name;
+        let inputs_len = self.inputs.len();
+
+        quote! {
+            let #output = (#(#inputs)+*) / #inputs_len;
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::Mean(self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use burn::record::FullPrecisionSettings;
+
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{mean::MeanNode, test::assert_tokens},
+        TensorType,
+    };
+
+    #[test]
+    fn test_codegen_sum() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(MeanNode::new(
+            vec![
+                TensorType::new_float("tensor1", 4),
+                TensorType::new_float("tensor2", 4),
+            ],
+            TensorType::new_float("tensor3", 4),
+        ));
+
+        graph.register_input_output(
+            vec!["tensor1".to_string(), "tensor2".to_string()],
+            vec!["tensor3".to_string()],
+        );
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(
+                    &self,
+                    tensor1: Tensor<B, 4>,
+                    tensor2: Tensor<B, 4>
+                ) -> Tensor<B, 4> {
+                    let tensor3 = (tensor1 + tensor2) / 2usize;
+
+                    tensor3
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs
index 9d1fdce591..875e3e5af3 100644
--- a/crates/burn-import/src/burn/node/mod.rs
+++ b/crates/burn-import/src/burn/node/mod.rs
@@ -25,6 +25,7 @@ pub(crate) mod mask_where;
 pub(crate) mod matmul;
 pub(crate) mod max_pool1d;
 pub(crate) mod max_pool2d;
+pub(crate) mod mean;
 pub(crate) mod pad;
 pub(crate) mod prelu;
 pub(crate) mod random_normal;
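The `#(#inputs)+*` repetition in `MeanNode::forward` joins the input tokens with `+` as the separator, and `#inputs_len` interpolates the count as a suffixed literal. A standalone sketch of that expansion, assuming `quote` and `proc-macro2` as direct dependencies (not part of the patch):

```rust
use proc_macro2::TokenStream;
use quote::quote;

fn main() {
    let inputs = vec![quote!(tensor1), quote!(tensor2), quote!(tensor3)];
    let inputs_len = inputs.len(); // a usize, as at this point in the series
    let tokens: TokenStream = quote! {
        let output = (#(#inputs)+*) / #inputs_len;
    };
    // Prints roughly: let output = (tensor1 + tensor2 + tensor3) / 3usize ;
    // The usize suffix on the divisor is what patch 4 below changes to u32.
    println!("{tokens}");
}
```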
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index f8b0e0ad8f..bccea8b5b6 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -78,6 +78,7 @@ use onnx_ir::{
 };
 
 pub use crate::burn::graph::RecordType;
+use crate::burn::node::mean::MeanNode;
 
 /// Generate code and states from `.onnx` files and save them to the `out_dir`.
 #[derive(Debug, Default)]
@@ -268,6 +269,7 @@ impl ParsedOnnxGraph {
                 NodeType::Max => graph.register(Self::max_conversion(node)),
                 NodeType::MaxPool1d => graph.register(Self::max_pool1d_conversion(node)),
                 NodeType::MaxPool2d => graph.register(Self::max_pool2d_conversion(node)),
+                NodeType::Mean => graph.register(Self::mean_conversion(node)),
                 NodeType::PRelu => graph.register(Self::prelu_conversion::<PS>(node)),
                 NodeType::AveragePool1d => graph.register(Self::avg_pool_1d_conversion(node)),
                 NodeType::AveragePool2d => graph.register(Self::avg_pool_2d_conversion(node)),
@@ -972,6 +974,13 @@ impl ParsedOnnxGraph {
         MaxPool2dNode::new(name, input, output, config)
     }
 
+    fn mean_conversion(node: Node) -> MeanNode {
+        let inputs = node.inputs.iter().map(TensorType::from).collect();
+        let output = TensorType::from(node.outputs.first().unwrap());
+
+        MeanNode::new(inputs, output)
+    }
+
     fn prelu_conversion<PS: PrecisionSettings>(node: Node) -> PReluNode {
         let input = TensorType::from(node.inputs.first().unwrap());
         let output = TensorType::from(node.outputs.first().unwrap());
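With the conversion registered, the new op flows through burn-import's normal build-time codegen. A minimal `build.rs` sketch of how a crate would consume `mean.onnx`, mirroring the `onnx-tests` build script touched above (the paths are illustrative):

```rust
// build.rs
use burn_import::onnx::ModelGen;

fn main() {
    // Converts the ONNX graph to Rust source (including the new MeanNode
    // codegen) under OUT_DIR, for inclusion via the generated module.
    ModelGen::new()
        .input("tests/mean/mean.onnx")
        .out_dir("model/")
        .run_from_script();
}
```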
From 976ef4fd7d0baaa717c3eae573b0c74a5d762ac7 Mon Sep 17 00:00:00 2001
From: cBournhonesque
Date: Tue, 6 Aug 2024 16:46:00 -0400
Subject: [PATCH 4/6] fix

---
 crates/burn-import/src/burn/node/base.rs  |  2 +-
 crates/burn-import/src/burn/node/mean.rs  |  4 ++--
 crates/burn-train/src/metric/top_k_acc.rs | 17 ++++++++---------
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index 45e6c0461f..3e7e202c5b 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -13,12 +13,12 @@ use super::{
     reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode,
     sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
+use crate::burn::node::mean::MeanNode;
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
 use burn::record::PrecisionSettings;
 use proc_macro2::TokenStream;
 use serde::Serialize;
-use crate::burn::node::mean::MeanNode;
 
 /// Backend used for serialization.
diff --git a/crates/burn-import/src/burn/node/mean.rs b/crates/burn-import/src/burn/node/mean.rs
index 1fb96c7226..a8c6c8b9bf 100644
--- a/crates/burn-import/src/burn/node/mean.rs
+++ b/crates/burn-import/src/burn/node/mean.rs
@@ -30,7 +30,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for MeanNode {
             .map(|t| scope.tensor_use_owned(t, node_position));
 
         let output = &self.output.name;
-        let inputs_len = self.inputs.len();
+        let inputs_len = self.inputs.len() as u32;
 
         quote! {
             let #output = (#(#inputs)+*) / #inputs_len;
@@ -97,7 +97,7 @@ mod tests {
                     tensor1: Tensor<B, 4>,
                     tensor2: Tensor<B, 4>
                 ) -> Tensor<B, 4> {
-                    let tensor3 = (tensor1 + tensor2) / 2usize;
+                    let tensor3 = (tensor1 + tensor2) / 2u32;
 
                     tensor3
                 }
diff --git a/crates/burn-train/src/metric/top_k_acc.rs b/crates/burn-train/src/metric/top_k_acc.rs
index 5a29b78e30..4c88f07371 100644
--- a/crates/burn-train/src/metric/top_k_acc.rs
+++ b/crates/burn-train/src/metric/top_k_acc.rs
@@ -54,7 +54,6 @@ impl<B: Backend> Metric for TopKAccuracyMetric<B> {
 
         let targets = input.targets.clone().to_device(&B::Device::default());
 
-
         let outputs = input
             .outputs
             .clone()
@@ -74,14 +73,14 @@ impl<B: Backend> Metric for TopKAccuracyMetric<B> {
         };
 
         let accuracy = targets
-                .reshape([batch_size, 1])
-                .repeat_dim(1, self.k)
-                .equal(outputs)
-                    .int()
-                    .sum()
-                    .into_scalar()
-                    .elem::<f64>()
-            / (batch_size as f64 - num_pad);
+            .reshape([batch_size, 1])
+            .repeat_dim(1, self.k)
+            .equal(outputs)
+            .int()
+            .sum()
+            .into_scalar()
+            .elem::<f64>()
+            / (batch_size as f64 - num_pad);
 
         self.state.update(
             100.0 * accuracy,
From 8312f4f603e776cc11cf7ce8c50851856f1b69ba Mon Sep 17 00:00:00 2001
From: cBournhonesque
Date: Wed, 7 Aug 2024 11:40:34 -0400
Subject: [PATCH 5/6] push fix

---
 crates/burn-import/src/burn/node/base.rs | 3 +--
 crates/burn-import/src/burn/node/mean.rs | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index 3e7e202c5b..cbe8b887fc 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -8,12 +8,11 @@ use super::{
     conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
     gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
     layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
-    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
+    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode,
     random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
     reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
     unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
-use crate::burn::node::mean::MeanNode;
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
 use burn::record::PrecisionSettings;
diff --git a/crates/burn-import/src/burn/node/mean.rs b/crates/burn-import/src/burn/node/mean.rs
index a8c6c8b9bf..17c6b34cb6 100644
--- a/crates/burn-import/src/burn/node/mean.rs
+++ b/crates/burn-import/src/burn/node/mean.rs
@@ -54,7 +54,7 @@ mod tests {
     };
 
     #[test]
-    fn test_codegen_sum() {
+    fn test_codegen_mean() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
         graph.register(MeanNode::new(
From ddd8051863597d2d6fee6e54359dfa2c7bdc7839 Mon Sep 17 00:00:00 2001
From: cBournhonesque
Date: Wed, 7 Aug 2024 11:51:45 -0400
Subject: [PATCH 6/6] format

---
 crates/burn-import/src/burn/node/base.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index cbe8b887fc..46e1d5e1a0 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -8,10 +8,10 @@ use super::{
     conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
     gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
     layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
-    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode,
-    random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
-    reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
-    unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
+    prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
+    range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
+    squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
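On the `as u32` cast that patch 4 introduces: burn tensors can be divided by scalar element types, and `u32` is accepted by the tensor ops' element conversion while `usize` is not, which is presumably why the `usize`-suffixed literal in the generated code failed to compile. A sketch of the difference, assuming the NdArray backend (not taken from the PR):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let sum = Tensor::<NdArray, 1>::from_floats([3.0, 6.0, 9.0], &device);
    let inputs_len: usize = 3;
    // let mean = sum / inputs_len;     // would not compile: usize is not a tensor element type
    let mean = sum / inputs_len as u32; // ok: u32 converts to the backend element type
    println!("{mean}"); // values: [1.0, 2.0, 3.0]
}
```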