Skip to content

Commit

Permalink
Add subtract tensor from scalar for ONNX sub op
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 4, 2024
1 parent f5be04f commit 785c5ed
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 10 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
the Burn model in Rust code, and `my-model.json` includes the model data.

7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md).
Further details can be found in the [onnx-tests README](./onnx-tests/README.md).

## Testing

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ mod tests {
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
let scalar = 3.0f64;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]);
let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);

output.to_data().assert_eq(&expected, true);
}
Expand All @@ -164,7 +164,7 @@ mod tests {
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
let scalar = 3;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]);
let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/sub/sub.onnx
Binary file not shown.
10 changes: 7 additions & 3 deletions crates/burn-import/onnx-tests/tests/sub/sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def forward(self, x, k):
# Sutract a scalar from a tensor
x = x - d

# Sutract a tensor from a scalar
x = d - x

return x


Expand All @@ -40,8 +43,9 @@ def main():

scalar = 3.0

torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16
)

print("Finished exporting model to {}".format(onnx_name))

Expand All @@ -53,5 +57,5 @@ def main():
print("Test output data: {}".format(output))


if __name__ == '__main__':
if __name__ == "__main__":
main()
Binary file modified crates/burn-import/onnx-tests/tests/sub/sub_int.onnx
Binary file not shown.
10 changes: 7 additions & 3 deletions crates/burn-import/onnx-tests/tests/sub/sub_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def forward(self, x, k):
# Sutract a scalar from a tensor
x = x - d

# Sutract a tensor from a scalar
x = d - x

return x


Expand All @@ -41,8 +44,9 @@ def main():
test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
scalar = 3

torch.onnx.export(model, (test_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
)

print("Finished exporting model to {}".format(onnx_name))

Expand All @@ -51,5 +55,5 @@ def main():
print("Test output data: {}".format(output))


if __name__ == '__main__':
if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => {
move |lhs, rhs| quote! { #rhs.mul_scalar(-1).add_scalar(#lhs) }
}
_ => panic!("Subtraction is supported for tensor and scalar only"),
};

Expand Down
22 changes: 21 additions & 1 deletion crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Slice => slice_update_outputs(node),
NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Sub => sub_update_outputs(node),
NodeType::Sum => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node),
Expand Down Expand Up @@ -481,6 +481,26 @@ fn slice_update_outputs(node: &mut Node) {
}
}

fn sub_update_outputs(node: &mut Node) {
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
// TODO: remove after debugging
// (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) => {
// Support broadcasting for lhs/rhs
if lhs.dim > rhs.dim {
ArgType::Tensor(lhs)
} else {
ArgType::Tensor(rhs)
}
}
_ => {
println!("{0:?},{1:?}", node.inputs, node.inputs);
panic!("Only tensor-scalar inputs are valid.");
}
};
}

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down

0 comments on commit 785c5ed

Please sign in to comment.