Skip to content

Commit

Permalink
ConstrainedConv1d now has causal padding
Browse files Browse the repository at this point in the history
  • Loading branch information
fedepup committed Apr 10, 2024
1 parent 1d2d421 commit 31869b7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
28 changes: 25 additions & 3 deletions selfeeg/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ class ConstrainedConv1d(nn.Conv1d):
Default = 1
padding: int, tuple or str, optional
Padding added to all four sides of the input.
Padding added to all four sides of the input. This class also accepts the
string 'causal', which triggers causal convolution like in Wavenet.
Default = 0
dilation: int or tuple, optional
Expand Down Expand Up @@ -268,6 +269,15 @@ class ConstrainedConv1d(nn.Conv1d):
constraint, set both min_norm and max_norm. To apply a UnitNorm constraint,
set both min_norm and max_norm to 1.0.
Note
----
When setting ``padding`` to ``"causal"``, padding will be internally changed
to an integer equal to ``(kernel_size - 1) * dilation``. Then, during forward,
the extra features are removed. This is preferable over F.pad, which can
lead to memory allocation or even non-deterministic operations during the
backboard pass. Additional information can be found at the following link:
https://github.com/pytorch/pytorch/issues/1333
Example
-------
>>> from import selfeeg.models import ConstrainedConv1d
Expand Down Expand Up @@ -301,12 +311,21 @@ def __init__(
axis_norm=[1,2],
minmax_rate=1.0
):

# Check causal Padding
self.pad = padding
self.causal_pad = False
if isinstance(padding, str):
if padding.casefold() == "causal":
self.causal_pad = True
self.pad = (kernel_size - 1) * dilation

super(ConstrainedConv1d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
self.pad,
dilation,
groups,
bias,
Expand Down Expand Up @@ -410,7 +429,10 @@ def forward(self, input):
"""
if self.constraint_type != 0:
self.scale_norm()
return self._conv_forward(input, self.weight, self.bias)
if self.causal_pad:
return self._conv_forward(input, self.weight, self.bias)[:,:,:-self.pad]
else:
return self._conv_forward(input, self.weight, self.bias)


class ConstrainedConv2d(nn.Conv2d):
Expand Down
2 changes: 1 addition & 1 deletion test/EEGself/models/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_ConstrainedConv1d(self):
"bias": [True, False],
"max_norm": [None, 1, 2],
"min_norm": [None, 1],
"padding": ["valid"],
"padding": ["valid", "causal"],
}
Conv_args = self.makeGrid(Conv_args)
for i in Conv_args:
Expand Down

0 comments on commit 31869b7

Please sign in to comment.