diff --git a/ascent/models/components/utils/layers.py b/ascent/models/components/utils/layers.py index 029918f..d1c2da5 100644 --- a/ascent/models/components/utils/layers.py +++ b/ascent/models/components/utils/layers.py @@ -107,6 +107,7 @@ def get_conv( stride: Union[int, tuple[int, ...], list[int]], dim: int, conv_bias: bool = True, + padding: Optional[Union[int, tuple[int, ...], list[int]]] = None, **kwargs, ) -> nn.Module: """Get 2D or 3D convolution layer. @@ -118,6 +119,8 @@ def get_conv( stride: Stride of the convolution. dim: Dimension of convolution. conv_bias: If True, adds a learnable bias to the convolution output + padding: Padding added to input. If None, padding is computed based on kernel size and + stride. **kwargs: Keyword arguments to be passed to either `nn.Conv2d` or `nn.Conv3d`. Returns: @@ -129,8 +132,8 @@ def get_conv( if dim not in [2, 3]: raise NotImplementedError(f"{dim}D convolution is not supported right now!") conv = convolutions[f"Conv{dim}d"] - padding = get_padding(kernel_size, stride) - return conv(in_channels, out_channels, kernel_size, stride, padding, bias=conv_bias, **kwargs) + pad = get_padding(kernel_size, stride) if padding is None else padding + return conv(in_channels, out_channels, kernel_size, stride, pad, bias=conv_bias, **kwargs) def get_transp_conv(