Fix initialization_kwargs (#147)
HangJung97 authored Nov 16, 2023
1 parent c5a3baf · commit 204d50a
Showing 2 changed files with 10 additions and 2 deletions.
ascent/models/components/encoders/convnext.py (5 additions, 1 deletion)

@@ -217,7 +217,11 @@ def __init__(
         # initialize weights
         init_kwargs = {}
         if activation == "leakyrelu":
-            if activation_kwargs is not None and "negative_slope" in activation_kwargs:
+            if (
+                activation_kwargs is not None
+                and "negative_slope" in activation_kwargs
+                and initialization == "kaiming_normal"
+            ):
                 init_kwargs["neg_slope"] = activation_kwargs["negative_slope"]
         self.apply(get_initialization(initialization, **init_kwargs))

ascent/models/components/encoders/unet_encoder.py (5 additions, 1 deletion)

@@ -164,7 +164,11 @@ def __init__(
         # initialize weights
         init_kwargs = {}
         if activation == "leakyrelu":
-            if activation_kwargs is not None and "negative_slope" in activation_kwargs:
+            if (
+                activation_kwargs is not None
+                and "negative_slope" in activation_kwargs
+                and initialization == "kaiming_normal"
+            ):
                 init_kwargs["neg_slope"] = activation_kwargs["negative_slope"]
         self.apply(get_initialization(initialization, **init_kwargs))

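For context, here is a minimal sketch of why the added initialization == "kaiming_normal" guard matters. The dispatcher below is hypothetical (the real get_initialization lives elsewhere in ascent and is not shown in this diff), and the "xavier_uniform" branch is an assumed alternative; the point is that only the Kaiming routines in torch.nn.init accept a negative-slope argument, so forwarding neg_slope for every choice of initialization would break non-Kaiming schemes.

from torch import nn


def get_initialization(initialization: str, **init_kwargs):
    # Hypothetical dispatcher mirroring the guarded call in this commit.
    def init_fn(module: nn.Module) -> None:
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            return
        if initialization == "kaiming_normal":
            # Only the Kaiming variants take a negative slope ("a" in torch.nn.init).
            nn.init.kaiming_normal_(
                module.weight,
                a=init_kwargs.get("neg_slope", 0.0),
                nonlinearity="leaky_relu",
            )
        elif initialization == "xavier_uniform":
            # xavier_uniform_ accepts no negative-slope argument. This sketch
            # simply ignores init_kwargs here; a dispatcher that forwarded them
            # straight into the init call would raise a TypeError, which is the
            # failure the added "kaiming_normal" guard avoids.
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    return init_fn


# Guarded usage, matching the fixed condition in both encoders:
initialization = "xavier_uniform"
activation = "leakyrelu"
activation_kwargs = {"negative_slope": 0.01}

init_kwargs = {}
if activation == "leakyrelu":
    if (
        activation_kwargs is not None
        and "negative_slope" in activation_kwargs
        and initialization == "kaiming_normal"
    ):
        init_kwargs["neg_slope"] = activation_kwargs["negative_slope"]

model = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3), nn.LeakyReLU(0.01))
model.apply(get_initialization(initialization, **init_kwargs))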
