Skip to content

Commit

Permalink
Improve warning raising condition and warning message in convnext (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
HangJung97 authored Dec 12, 2023
1 parent 7889db4 commit 855631e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions ascent/models/components/encoders/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
ValueError: When `len(expansion_rate)` is not equal to `num_stages`.
ValueError: When `len(n_conv_per_stage)` is not equal to `num_stages`.
ValueError: When `len(num_features_per_stage)` is not equal to `num_stages`.
Warning: When `return_skip` is set to True and the stem performs a pooling operation.
"""
super().__init__()

Expand Down Expand Up @@ -207,11 +208,12 @@ def __init__(
self.return_skip = return_skip
self.strides = [[i] * dim if isinstance(i, int) else list(i) for i in strides]

if return_skip and 2 in self.strides[0]:
if return_skip and not all(s == 1 for s in self.strides[0]):
warnings.warn(
"The stem performs a pooling operation and `return_skip` is set to True. The "
"resolution of the skip connections might be wrong and will cause an error if "
"paired with `UNetDecoder`."
"paired with `UNetDecoder`. The highest resolution output of the `UNetDecoder` "
"will be (input resolution / stride of the stem) rather than the input resolution."
)

# initialize weights
Expand Down

0 comments on commit 855631e

Please sign in to comment.