Infer ONNX conv output shapes #2304

Closed · wants to merge 2 commits
crates/onnx-ir/src/dim_inference.rs (106 additions, 44 deletions)
@@ -1,11 +1,11 @@
 use core::cmp::max;
 use core::panic;

-use log::debug;
+use log::{debug, warn};
 use protobuf::Enum;

 use crate::{
-    ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
+    ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, Shape, TensorType},
     protos::tensor_proto::DataType,
     util::{flatten_config, shape_config},
 };
@@ -23,8 +23,9 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Concat => concat_update_outputs(node),
         NodeType::Constant => constant_update_outputs(node),
         NodeType::ConstantOfShape => constant_of_shape_update_output(node),
-        NodeType::Conv1d => conv1d_update_outputs(node),
-        NodeType::Conv2d => conv2d_update_outputs(node),
+        NodeType::Conv1d => conv_update_outputs(node),
+        NodeType::Conv2d => conv_update_outputs(node),
+        NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
         NodeType::Cos => same_as_input(node),
         NodeType::Div => same_as_input_broadcast(node),
         NodeType::Dropout => same_as_input(node),
@@ -33,14 +34,13 @@ pub fn dim_inference(node: &mut Node) {
         NodeType::Exp => same_as_input(node),
         NodeType::Expand => expand_update_outputs(node),
         NodeType::Flatten => flatten_update_outputs(node),
-        NodeType::Gelu => same_as_input(node),
         NodeType::Gather => gather_update_outputs(node),
         NodeType::GatherElements => same_as_input(node),
+        NodeType::Gelu => same_as_input(node),
+        NodeType::GlobalAveragePool => same_as_input(node),
         NodeType::Greater => elementwise_comparison_outputs(node),
         NodeType::GreaterOrEqual => elementwise_comparison_outputs(node),
         NodeType::HardSigmoid => same_as_input(node),
-        NodeType::GlobalAveragePool => same_as_input(node),
-        NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
         NodeType::LayerNormalization => same_as_input(node),
         NodeType::LeakyRelu => same_as_input(node),
         NodeType::Less => elementwise_comparison_outputs(node),
@@ -210,25 +210,18 @@ fn linear_update_outputs(node: &mut Node) {

     // Calculate the output shape. Usually we do not use shapes, but since the input shape is
     // known, we can calculate the output shape.
-    if let ArgType::Tensor(tensor) = node_input.clone().ty {
-        let mut tensor = tensor.clone();
-
-        // Update the shape of the output tensor if it's known
-        if let Some(mut shape) = tensor.shape.clone() {
-            if let ArgType::Tensor(weight_tensor) = weight.clone().ty {
-                let last = shape.last_mut().unwrap();
-                *last = *weight_tensor.shape.unwrap().first().unwrap();
-            } else {
-                panic!("Weight must be a tensor");
-            }
-            tensor.shape = Some(shape);
-        }
-
-        // Update the output tensor
-        node.outputs[0].ty = ArgType::Tensor(tensor);
-    } else {
-        panic!("Only tensor input is valid");
-    }
+    let mut tensor = node_input.clone().ty.get_tensor_type().clone();
+
+    // Update the shape of the output tensor if it's known
+    if let Some(mut shape) = tensor.shape.clone() {
+        let weight_tensor = weight.ty.get_tensor_type().clone();
+        let last = shape.last_mut().unwrap();
+        *last = *weight_tensor.shape.unwrap().first().unwrap();
+        tensor.shape = Some(shape);
+    }
+
+    // Update the output tensor
+    node.outputs[0].ty = ArgType::Tensor(tensor);
 }

 /// Update the output type using "to" attribute
@@ -579,34 +572,103 @@ fn flatten_update_outputs(node: &mut Node) {
     });
 }

-/// Infers the shape of a Conv1d node and replaces the shape of the output tensor.
-fn conv1d_update_outputs(node: &mut Node) {
-    // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
-    if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty {
-        node.outputs[0].ty = ArgType::Tensor(tensor);
-    } else {
-        panic!("Only tensor input is valid");
-    }
-}
+/// Calculates the output shape of a convolution operator.
+fn calculate_conv_output_shape(
+    x_shape: &Shape,
+    w_shape: &Shape,
+    strides: &[i64],
+    dilations: &[i64],
+    pads: &[i64],
+    group: usize,
+) -> Option<Shape> {
+    let batch_size = x_shape[0];
+    let input_channels = x_shape[1];
+    let mut input_data_shape = x_shape[2..].to_vec();
+    let axis_count = input_data_shape.len();
+    let feature_maps = w_shape[0];
+    let filter_channels = w_shape[1];
+    let mut kernel_shape = w_shape[2..].to_vec();
+
+    // Sanity checks: if one of the assumptions doesn't hold, return None (the output
+    // becomes a shapeless TensorType) instead of calculating possibly wrong values.
+    if axis_count != kernel_shape.len() {
+        warn!(
+            "Axis count mismatch between input ({axis_count}) and kernel ({})",
+            kernel_shape.len()
+        );
+        return None;
+    }
+    if input_channels % filter_channels != 0 {
+        warn!("Conv input channel count ({input_channels}) isn't divisible by filter channel count ({filter_channels})");
+        return None;
+    }
+    if feature_maps % group != 0 {
+        warn!("Conv feature map count ({feature_maps}) isn't divisible by group count ({group})");
+        return None;
+    }
+    let channel_ratio = input_channels / filter_channels;
+    if group != channel_ratio {
+        warn!("Conv group attribute ({group}) doesn't match the channel ratio ({input_channels} / {filter_channels} = {channel_ratio})");
+        return None;
+    }
+
+    // Apply modifiers to the shapes, calculating the effective shapes:
+    for (axis, axis_shape) in input_data_shape.iter_mut().enumerate() {
+        // add begin and end padding to the axis
+        *axis_shape += pads[axis] as usize + pads[axis + axis_count] as usize;
+    }
+    for (axis, axis_shape) in kernel_shape.iter_mut().enumerate() {
+        // e.g., a kernel axis with 3 values and dilation 2 is effectively 5 = ((3 - 1) * 2) + 1 wide
+        *axis_shape = ((*axis_shape - 1) * (dilations[axis] as usize)) + 1;
+    }
+
+    let mut out_shape = Vec::with_capacity(2 + axis_count);
+    out_shape.push(batch_size);
+    out_shape.push(feature_maps);
+    for axis in 0..axis_count {
+        out_shape
+            .push(((input_data_shape[axis] - kernel_shape[axis]) / strides[axis] as usize) + 1);
+    }
+
+    Some(out_shape)
+}

-/// Infers the shape of a Conv2d node and replaces the shape of the output tensor.
-fn conv2d_update_outputs(node: &mut Node) {
-    // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
-    if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty {
-        node.outputs[0].ty = ArgType::Tensor(tensor);
-    } else {
-        panic!("Only tensor input is valid");
-    }
+/// Infers the shape of a Conv1d or Conv2d node and replaces the shape of the output tensor.
+fn conv_update_outputs(node: &mut Node) {
+    let input_ty = node.inputs[0].ty.get_tensor_type();
+    let weights_ty = node.inputs[1].ty.get_tensor_type();
+
+    let mut output_ty = input_ty.clone();
+
+    output_ty.shape = if let (Some(x_shape), Some(w_shape)) = (&input_ty.shape, &weights_ty.shape) {
+        let axis_count = x_shape.len() - 2;
+        let mut strides = vec![1; axis_count];
+        let mut dilations = vec![1; axis_count];
+        let mut pads = vec![0; axis_count * 2];
+        let mut group = 1;
+        for (key, value) in node.attrs.iter() {
+            match key.as_str() {
+                "strides" => strides = value.clone().into_i64s(),
+                "pads" => pads = value.clone().into_i64s(),
+                "dilations" => dilations = value.clone().into_i64s(),
+                "group" => group = value.clone().into_i64() as usize,
+                _ => {}
+            }
+        }
+
+        calculate_conv_output_shape(x_shape, w_shape, &strides, &dilations, &pads, group)
+    } else {
+        None
+    };
+    node.outputs[0].ty = ArgType::Tensor(output_ty);
 }

 /// Infers the shape of a ConvTranspose2d node and replaces the shape of the output tensor.
 fn conv_transpose2d_update_outputs(node: &mut Node) {
-    // extract the channels from the weight tensor's shape [out_channels, in_channels, ...]
-    if let ArgType::Tensor(tensor) = node.inputs[0].clone().ty {
-        node.outputs[0].ty = ArgType::Tensor(tensor);
-    } else {
-        panic!("Only tensor input is valid");
-    }
+    let mut tensor = node.inputs[0].ty.get_tensor_type().clone();
+    tensor.shape = None; // TODO: calculate the shape
+    node.outputs[0].ty = ArgType::Tensor(tensor);
 }

 fn matmul_update_outputs(node: &mut Node) {
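
As a sanity check on the arithmetic above (not part of the diff), the standalone sketch below mirrors `calculate_conv_output_shape` on a typical 7x7, stride-2, pad-3 convolution: the effective kernel width is ((k - 1) * dilation) + 1, the input is widened by begin and end padding, and each output axis is ((padded - kernel) / stride) + 1. The helper name and the `main` harness here are illustrative, not the crate's API.

// Local mirror of the PR's shape arithmetic; assumes Shape = Vec<usize> as in onnx-ir.
type Shape = Vec<usize>;

fn conv_out_shape(x: &Shape, w: &Shape, strides: &[usize], dilations: &[usize], pads: &[usize]) -> Shape {
    let axes = x.len() - 2;
    let mut out = vec![x[0], w[0]]; // [batch_size, feature_maps]
    for a in 0..axes {
        let padded = x[2 + a] + pads[a] + pads[a + axes]; // begin + end padding
        let kernel = (w[2 + a] - 1) * dilations[a] + 1; // effective (dilated) kernel size
        out.push((padded - kernel) / strides[a] + 1);
    }
    out
}

fn main() {
    // ResNet-style stem: 1x3x224x224 input, 64 filters of 3x7x7, stride 2, pad 3.
    let out = conv_out_shape(
        &vec![1, 3, 224, 224],
        &vec![64, 3, 7, 7],
        &[2, 2],
        &[1, 1],
        &[3, 3, 3, 3],
    );
    assert_eq!(out, vec![1, 64, 112, 112]); // (224 + 3 + 3 - 7) / 2 + 1 = 112
    println!("{out:?}");
}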
crates/onnx-ir/src/ir.rs (12 additions, 0 deletions)
@@ -135,6 +135,9 @@ impl ArgType {
     pub fn is_tensor(&self) -> bool {
         matches!(self, Self::Tensor(_))
     }
+    pub fn is_shape(&self) -> bool {
+        matches!(self, Self::Shape(_))
+    }

     /// returns the rank (dimension) of the Arg
     pub fn rank(&self) -> usize {
@@ -153,6 +156,15 @@
             ArgType::Tensor(t) => &t.elem_type,
         }
     }
+
+    /// returns the contained [`TensorType`] if this `ArgType` is a `Tensor`, else panics
+    pub fn get_tensor_type(&self) -> &TensorType {
+        if let Self::Tensor(tensor) = self {
+            tensor
+        } else {
+            panic!("ArgType is not a Tensor!");
+        }
+    }
 }

 impl Argument {
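
For illustration (also not part of the diff), here is a minimal sketch of the call-site pattern `get_tensor_type` enables, using simplified stand-ins for the real `ArgType`/`TensorType` from crates/onnx-ir/src/ir.rs:

// Simplified stand-ins; the real types carry more fields (e.g. elem_type).
#[derive(Clone, Debug)]
struct TensorType {
    shape: Option<Vec<usize>>,
}

#[allow(dead_code)]
#[derive(Debug)]
enum ArgType {
    Tensor(TensorType),
    Shape(usize),
}

impl ArgType {
    // Same contract as the accessor added above: panics unless the variant is Tensor.
    fn get_tensor_type(&self) -> &TensorType {
        if let Self::Tensor(tensor) = self {
            tensor
        } else {
            panic!("ArgType is not a Tensor!");
        }
    }
}

fn main() {
    let arg = ArgType::Tensor(TensorType { shape: Some(vec![1, 3, 224, 224]) });
    // Before: every caller repeated `if let ArgType::Tensor(t) = ... else { panic!(...) }`.
    // After: extraction is a single expression.
    println!("{:?}", arg.get_tensor_type().shape);
}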