Skip to content

Commit

Permalink
Add case to match lhs as tensor and rhs as scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 4, 2024
1 parent b0dfdff commit b0a5ac6
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ fn slice_update_outputs(node: &mut Node) {
fn sub_update_outputs(node: &mut Node) {
node.outputs[0].ty = match (node.inputs[0].ty.clone(), node.inputs[1].ty.clone()) {
(ArgType::Scalar(_lhs), ArgType::Tensor(rhs)) => ArgType::Tensor(rhs),
(ArgType::Tensor(lhs), ArgType::Scalar(_rhs)) => ArgType::Tensor(lhs),
(ArgType::Tensor(lhs), ArgType::Tensor(rhs)) => {
// Support broadcasting for lhs/rhs
if lhs.dim > rhs.dim {
Expand Down

0 comments on commit b0a5ac6

Please sign in to comment.