Skip to content

Commit

Permalink
Add subtract tensor from scalar for ONNX sub op
Browse files Browse the repository at this point in the history
Make sure sub op is more efficient by using one operator
  • Loading branch information
JC committed Jul 4, 2024
1 parent f5be04f commit af2138d
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ target
.vs
.fleet
.ipynb_checkpoints/

# Generated IR and Burn Graph from ONNX
out
2 changes: 1 addition & 1 deletion crates/burn-import/DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
the Burn model in Rust code, and `my-model.json` includes the model data.

7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md).
Further details can be found in the [onnx-tests README](./onnx-tests/README.md).

## Testing

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ mod tests {
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
let scalar = 3.0f64;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]);
let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);

output.to_data().assert_eq(&expected, true);
}
Expand All @@ -164,7 +164,7 @@ mod tests {
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
let scalar = 3;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]);
let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/sub/sub.onnx
Binary file not shown.
10 changes: 7 additions & 3 deletions crates/burn-import/onnx-tests/tests/sub/sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def forward(self, x, k):
# Sutract a scalar from a tensor
x = x - d

# Sutract a tensor from a scalar
x = d - x

return x


Expand All @@ -40,8 +43,9 @@ def main():

scalar = 3.0

torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16
)

print("Finished exporting model to {}".format(onnx_name))

Expand All @@ -53,5 +57,5 @@ def main():
print("Test output data: {}".format(output))


if __name__ == '__main__':
if __name__ == "__main__":
main()
Binary file modified crates/burn-import/onnx-tests/tests/sub/sub_int.onnx
Binary file not shown.
10 changes: 7 additions & 3 deletions crates/burn-import/onnx-tests/tests/sub/sub_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def forward(self, x, k):
# Sutract a scalar from a tensor
x = x - d

# Sutract a tensor from a scalar
x = d - x

return x


Expand All @@ -41,8 +44,9 @@ def main():
test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
scalar = 3

torch.onnx.export(model, (test_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
)

print("Finished exporting model to {}".format(onnx_name))

Expand All @@ -51,5 +55,5 @@ def main():
print("Test output data: {}".format(output))


if __name__ == '__main__':
if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) },
_ => panic!("Subtraction is supported for tensor and scalar only"),
};

Expand Down
16 changes: 15 additions & 1 deletion crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Slice => slice_update_outputs(node),
NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Sub => sub_update_outputs(node),
NodeType::Sum => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node),
Expand Down Expand Up @@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) {
}
}

fn sub_update_outputs(node: &mut Node) {
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
(ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
(ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
// Support broadcasting for lhs/rhs
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim > rhs.dim => ArgType::Tensor(lhs),
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) if lhs.dim <= rhs.dim => ArgType::Tensor(rhs),
_ => {
panic!("Only tensor-scalar inputs are valid.");
}
};
}

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down

0 comments on commit af2138d

Please sign in to comment.