diff --git a/ascent/models/components/encoders/convnext.py b/ascent/models/components/encoders/convnext.py index d3a88c4..e131f29 100644 --- a/ascent/models/components/encoders/convnext.py +++ b/ascent/models/components/encoders/convnext.py @@ -86,6 +86,7 @@ def __init__( ValueError: When `len(expansion_rate)` is not equal to `num_stages`. ValueError: When `len(n_conv_per_stage)` is not equal to `num_stages`. ValueError: When `len(num_features_per_stage)` is not equal to `num_stages`. + Warning: When `return_skip` is set to True and the stem performs a pooling operation. """ super().__init__() @@ -207,11 +208,12 @@ def __init__( self.return_skip = return_skip self.strides = [[i] * dim if isinstance(i, int) else list(i) for i in strides] - if return_skip and 2 in self.strides[0]: + if return_skip and not all(s == 1 for s in self.strides[0]): warnings.warn( "The stem performs a pooling operation and `return_skip` is set to True. The " "resolution of the skip connections might be wrong and will cause an error if " - "paired with `UNetDecoder`." + "paired with `UNetDecoder`. The highest resolution output of the `UNetDecoder` " + "will be (input resolution / stride of the stem) rather than the input resolution." ) # initialize weights