Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subtract tensor from scalar for ONNX sub op #1964

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ target
.vs
.fleet
.ipynb_checkpoints/

# Generated IR and Burn Graph from ONNX
out
2 changes: 1 addition & 1 deletion crates/burn-import/DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
the Burn model in Rust code, and `my-model.json` includes the model data.

7. **Add End-to-End Test**: Include the test in `./burn-import/onnx-tests/tests/onnx_tests.rs`.
Further details can be found in the [onnx-tests README](./burn-import/onnx-tests/README.md).
Further details can be found in the [onnx-tests README](./onnx-tests/README.md).

## Testing

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ mod tests {
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]], &device);
let scalar = 3.0f64;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6f32, 7., 8., 9.]]]]);
let expected = TensorData::from([[[[-12f32, -13., -14., -15.]]]]);

output.to_data().assert_eq(&expected, true);
}
Expand All @@ -162,7 +162,7 @@ mod tests {
let input = Tensor::<Backend, 4, Int>::from_ints([[[[1, 2, 3, 4]]]], &device);
let scalar = 3;
let output = model.forward(input, scalar);
let expected = TensorData::from([[[[6i64, 6, 6, 6]]]]);
let expected = TensorData::from([[[[-12i64, -12, -12, -12]]]]);

output.to_data().assert_eq(&expected, true);
}
Expand Down
Binary file modified crates/burn-import/onnx-tests/tests/sub/sub.onnx
Binary file not shown.
10 changes: 7 additions & 3 deletions crates/burn-import/onnx-tests/tests/sub/sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def forward(self, x, k):
# Subtract a scalar from a tensor
x = x - d

# Subtract a tensor from a scalar
x = d - x

return x


Expand All @@ -40,8 +43,9 @@ def main():

scalar = 3.0

torch.onnx.export(model, (dummy_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (dummy_input, scalar), onnx_name, verbose=False, opset_version=16
)

print("Finished exporting model to {}".format(onnx_name))

Expand All @@ -53,5 +57,5 @@ def main():
print("Test output data: {}".format(output))


if __name__ == '__main__':
if __name__ == "__main__":
main()
Binary file modified crates/burn-import/onnx-tests/tests/sub/sub_int.onnx
Binary file not shown.
10 changes: 7 additions & 3 deletions crates/burn-import/onnx-tests/tests/sub/sub_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def forward(self, x, k):
# Subtract a scalar from a tensor
x = x - d

# Subtract a tensor from a scalar
x = d - x

return x


Expand All @@ -41,8 +44,9 @@ def main():
test_input = torch.tensor([[[[1, 2, 3, 4]]]], device=device)
scalar = 3

torch.onnx.export(model, (test_input, scalar), onnx_name,
verbose=False, opset_version=16)
torch.onnx.export(
model, (test_input, scalar), onnx_name, verbose=False, opset_version=16
)

print("Finished exporting model to {}".format(onnx_name))

Expand All @@ -51,5 +55,5 @@ def main():
print("Test output data: {}".format(output))


if __name__ == '__main__':
if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ impl BinaryNode {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { -#rhs.sub_scalar(#lhs) },
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

FYI @nathanielsimard , @laggui , @louisfd

Another reason to support native Scalar type in Burn. See our earlier design discussion: #1689 (comment)

{ -#rhs.sub_scalar(#lhs) }, solution to scalar - tensor will result in two operations instead of one.

_ => panic!("Subtraction is supported for tensor and scalar only"),
};

Expand Down
16 changes: 15 additions & 1 deletion crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::Slice => slice_update_outputs(node),
NodeType::Softmax => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Sub => same_as_input(node),
NodeType::Sub => sub_update_outputs(node),
NodeType::Sum => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Transpose => same_as_input(node),
Expand Down Expand Up @@ -481,6 +481,20 @@ fn slice_update_outputs(node: &mut Node) {
}
}

/// Infer the output type of a `Sub` node from its two input types.
///
/// Scalar - scalar yields a scalar; any combination involving a tensor
/// yields a tensor. For tensor - tensor subtraction, broadcasting means
/// the output rank follows the higher-rank operand.
///
/// # Panics
/// Panics if either input is neither a scalar nor a tensor.
fn sub_update_outputs(node: &mut Node) {
    node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
        (ArgType::Scalar(_lhs), ArgType::Scalar(rhs)) => ArgType::Scalar(rhs),
        (ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
        (ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
        // Support broadcasting for lhs/rhs: keep the higher-rank tensor type.
        (ArgType::Tensor(lhs), ArgType::Tensor(rhs)) => {
            if lhs.dim >= rhs.dim {
                ArgType::Tensor(lhs)
            } else {
                ArgType::Tensor(rhs)
            }
        }
        _ => {
            // The previous message claimed only tensor-scalar was valid,
            // which contradicted the tensor-tensor arms above.
            panic!("Sub is only supported for scalar and tensor inputs.");
        }
    };
}

/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
let axes = if node.inputs.len() == 2 {
Expand Down
Loading