diff --git a/.gitignore b/.gitignore index 9d5f51b617..ffa113f25c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ target .vs .fleet .ipynb_checkpoints/ + +# Generated IR and Burn Graph from ONNX +out diff --git a/crates/burn-import/DEVELOPMENT.md b/crates/burn-import/DEVELOPMENT.md index 687428fe1b..2d122a042b 100644 --- a/crates/burn-import/DEVELOPMENT.md +++ b/crates/burn-import/DEVELOPMENT.md @@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps: the Burn model in Rust code, and `my-model.json` includes the model data. 7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`. - Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md). + Further details can be found in the [onnx-tests README](./onnx-tests/README.md). ## Testing diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index ac780b6d0b..d010d81f1a 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -147,7 +147,7 @@ mod tests { let input = Tensor::::from_floats([[[[1., 2., 3., 4.]]]], &device); let scalar = 3.0f64; let output = model.forward(input, scalar); - let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]); + let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]); output.to_data().assert_eq(&expected, true); } @@ -162,7 +162,7 @@ mod tests { let input = Tensor::::from_ints([[[[1, 2, 3, 4]]]], &device); let scalar = 3; let output = model.forward(input, scalar); - let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]); + let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]); output.to_data().assert_eq(&expected, true); } diff --git a/crates/burn-import/onnx-tests/tests/sub/sub.onnx b/crates/burn-import/onnx-tests/tests/sub/sub.onnx index 7ffdfc8083..60d76ea1c5 100644 Binary files a/crates/burn-import/onnx-tests/tests/sub/sub.onnx and b/crates/burn-import/onnx-tests/tests/sub/sub.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/sub/sub.py b/crates/burn-import/onnx-tests/tests/sub/sub.py index f71cf4a018..2169592b90 100755 --- a/crates/burn-import/onnx-tests/tests/sub/sub.py +++ b/crates/burn-import/onnx-tests/tests/sub/sub.py @@ -26,6 +26,9 @@ def forward(self, x, k): # Sutract a scalar from a tensor x = x - d + # Sutract a tensor from a scalar + x = d - x + return x @@ -40,8 +43,9 @@ def main(): scalar = 3.0 - torch.onnx.export(model, (dummy_input, scalar), onnx_name, - verbose=False, opset_version=16) + torch.onnx.export( + model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16 + ) print("Finished exporting model to {}".format(onnx_name)) @@ -53,5 +57,5 @@ def main(): print("Test output data: {}".format(output)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx b/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx index 55309cea18..73a4ace795 100644 Binary files a/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx and b/crates/burn-import/onnx-tests/tests/sub/sub_int.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/sub/sub_int.py b/crates/burn-import/onnx-tests/tests/sub/sub_int.py index 17ace09cad..487c66b19b 100755 --- a/crates/burn-import/onnx-tests/tests/sub/sub_int.py +++ b/crates/burn-import/onnx-tests/tests/sub/sub_int.py @@ -27,6 +27,9 @@ def forward(self, x, k): # Sutract a scalar from a tensor x = x - d + # Sutract a tensor from a scalar + x = d - x + return x @@ -41,8 +44,9 @@ def main(): test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device) scalar = 3 - torch.onnx.export(model, (test_input, scalar), onnx_name, - verbose=False, opset_version=16) + torch.onnx.export( + model, (test_input, scalar), onnx_name, verbose=False, opset_version=16 + ) print("Finished exporting model to {}".format(onnx_name)) @@ -51,5 +55,5 @@ def main(): print("Test output data: {}".format(output)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs index da6b7b9303..37983f4c3d 100644 --- a/crates/burn-import/src/burn/node/binary.rs +++ b/crates/burn-import/src/burn/node/binary.rs @@ -131,6 +131,7 @@ impl BinaryNode { (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) }, (Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) }, (Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs }, + (Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) }, _ => panic!("Subtraction is supported for tensor and scalar only"), }; diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index e39ef73bdf..2e4f67676a 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::Slice => slice_update_outputs(node), NodeType::Softmax => same_as_input(node), NodeType::Sqrt => same_as_input(node), - NodeType::Sub => same_as_input(node), + NodeType::Sub => sub_update_outputs(node), NodeType::Sum => same_as_input(node), NodeType::Tanh => same_as_input(node), NodeType::Transpose => same_as_input(node), @@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) { } } +fn sub_update_outputs(node: &mut Node) { + node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) { + (ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs), + (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs), + (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs), + // Support broadcasting for lhs/rhs + (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs), + (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs), + _ => { + panic!("Only tensor-scalar inputs are valid."); + } + }; +} + /// Update the output tensor dimension based on the "axes" attribute or the second input fn unsqueeze_update_output(node: &mut Node) { let axes = if node.inputs.len() == 2 {