From 1d2d421bb224f204fd3f27759448a9a85cce00d1 Mon Sep 17 00:00:00 2001
From: fedepup
Date: Tue, 9 Apr 2024 19:25:55 +0200
Subject: [PATCH] updated constraint on custom layers

---
 RELEASE.md                                    |  13 +-
 ...elfeeg.models.layers.ConstrainedConv1d.rst |  20 +
 ...lfeeg.models.layers.ConstrainedConv2d.rst} |   2 +-
 ...elfeeg.models.layers.ConstrainedDense.rst} |   2 +-
 ...selfeeg.models.layers.DepthwiseConv2d.rst} |   2 +-
 ...selfeeg.models.layers.SeparableConv2d.rst} |   2 +-
 docs/selfeeg.models.rst                       |  17 +-
 selfeeg/dataloading/load.py                   |   2 +
 selfeeg/models/__init__.py                    |   9 +
 selfeeg/models/layers.py                      | 970 ++++++++++++++++++
 selfeeg/models/zoo.py                         | 427 +-------
 test/EEGself/models/layers_test.py            | 240 +++++
 test/EEGself/models/zoo_test.py               | 157 +--
 13 files changed, 1282 insertions(+), 581 deletions(-)
 create mode 100644 docs/api/selfeeg.models.layers.ConstrainedConv1d.rst
 rename docs/api/{selfeeg.models.zoo.ConstrainedConv2d.rst => selfeeg.models.layers.ConstrainedConv2d.rst} (88%)
 rename docs/api/{selfeeg.models.zoo.ConstrainedDense.rst => selfeeg.models.layers.ConstrainedDense.rst} (88%)
 rename docs/api/{selfeeg.models.zoo.DepthwiseConv2d.rst => selfeeg.models.layers.DepthwiseConv2d.rst} (88%)
 rename docs/api/{selfeeg.models.zoo.SeparableConv2d.rst => selfeeg.models.layers.SeparableConv2d.rst} (85%)
 create mode 100644 selfeeg/models/layers.py
 create mode 100644 test/EEGself/models/layers_test.py

diff --git a/RELEASE.md b/RELEASE.md
index a07a17b..4bab2b1 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -2,7 +2,18 @@

 **Functionality**

-* EEGdataset can preload the entire dataset.
+- **dataloading module**:
+  - EEGdataset can preload the entire dataset.
+- **models module**:
+  - Custom layers were moved to a new models.layers submodule.
+  - Layer constraints now include MaxNorm, MinMaxNorm, and UnitNorm, with axis selection as in Keras.
+  - Added a Conv1d layer with norm constraints.
+
+**Maintenance**
+
+* Fixed typos in the models module unit tests.
+* Added new tests for the novel functionalities.
+

 # Version 0.1.1 (latest)

diff --git a/docs/api/selfeeg.models.layers.ConstrainedConv1d.rst b/docs/api/selfeeg.models.layers.ConstrainedConv1d.rst
new file mode 100644
index 0000000..35a0f66
--- /dev/null
+++ b/docs/api/selfeeg.models.layers.ConstrainedConv1d.rst
@@ -0,0 +1,20 @@
+ConstrainedConv1d
+=================
+
+.. currentmodule:: selfeeg.models.layers
+
+.. autoclass:: ConstrainedConv1d
+    :show-inheritance:
+    :noindex:
+
+    .. rubric:: Methods Summary
+
+    .. autosummary::
+
+        ~ConstrainedConv1d.forward
+        ~ConstrainedConv1d.scale_norm
+
+    .. rubric:: Methods Documentation
+
+    .. automethod:: forward
+    .. automethod:: scale_norm
diff --git a/docs/api/selfeeg.models.zoo.ConstrainedConv2d.rst b/docs/api/selfeeg.models.layers.ConstrainedConv2d.rst
similarity index 88%
rename from docs/api/selfeeg.models.zoo.ConstrainedConv2d.rst
rename to docs/api/selfeeg.models.layers.ConstrainedConv2d.rst
index a90e513..c277b06 100644
--- a/docs/api/selfeeg.models.zoo.ConstrainedConv2d.rst
+++ b/docs/api/selfeeg.models.layers.ConstrainedConv2d.rst
@@ -1,7 +1,7 @@
 ConstrainedConv2d
 =================

-.. currentmodule:: selfeeg.models.zoo
+.. currentmodule:: selfeeg.models.layers

 .. autoclass:: ConstrainedConv2d
     :show-inheritance:
diff --git a/docs/api/selfeeg.models.zoo.ConstrainedDense.rst b/docs/api/selfeeg.models.layers.ConstrainedDense.rst
similarity index 88%
rename from docs/api/selfeeg.models.zoo.ConstrainedDense.rst
rename to docs/api/selfeeg.models.layers.ConstrainedDense.rst
index e27a447..0bcc078 100644
--- a/docs/api/selfeeg.models.zoo.ConstrainedDense.rst
+++ b/docs/api/selfeeg.models.layers.ConstrainedDense.rst
@@ -1,7 +1,7 @@
 ConstrainedDense
 ================

-.. currentmodule:: selfeeg.models.zoo
+.. currentmodule:: selfeeg.models.layers

 .. autoclass:: ConstrainedDense
     :show-inheritance:
diff --git a/docs/api/selfeeg.models.zoo.DepthwiseConv2d.rst b/docs/api/selfeeg.models.layers.DepthwiseConv2d.rst
similarity index 88%
rename from docs/api/selfeeg.models.zoo.DepthwiseConv2d.rst
rename to docs/api/selfeeg.models.layers.DepthwiseConv2d.rst
index 1a517b8..504f072 100644
--- a/docs/api/selfeeg.models.zoo.DepthwiseConv2d.rst
+++ b/docs/api/selfeeg.models.layers.DepthwiseConv2d.rst
@@ -1,7 +1,7 @@
 DepthwiseConv2d
 ===============

-.. currentmodule:: selfeeg.models.zoo
+.. currentmodule:: selfeeg.models.layers

 .. autoclass:: DepthwiseConv2d
     :show-inheritance:
diff --git a/docs/api/selfeeg.models.zoo.SeparableConv2d.rst b/docs/api/selfeeg.models.layers.SeparableConv2d.rst
similarity index 85%
rename from docs/api/selfeeg.models.zoo.SeparableConv2d.rst
rename to docs/api/selfeeg.models.layers.SeparableConv2d.rst
index e16dc73..12880fe 100644
--- a/docs/api/selfeeg.models.zoo.SeparableConv2d.rst
+++ b/docs/api/selfeeg.models.layers.SeparableConv2d.rst
@@ -1,7 +1,7 @@
 SeparableConv2d
 ===============

-.. currentmodule:: selfeeg.models.zoo
+.. currentmodule:: selfeeg.models.layers

 .. autoclass:: SeparableConv2d
     :show-inheritance:
diff --git a/docs/selfeeg.models.rst b/docs/selfeeg.models.rst
index 8687da0..ea84df9 100644
--- a/docs/selfeeg.models.rst
+++ b/docs/selfeeg.models.rst
@@ -1,9 +1,20 @@
 selfeeg.models
 ==============

-This module collects various Deep Learning models proposed for EEG applications.
-In addition, it implements some layers not directly available in the PyTorch nn.Module, such as:
-a Depthwise Conv2d layer, a Separable Conv2d layer, Conv2d with max norm constraint, Linear layer with max norm constraint.
+This module collects various Deep Learning models and custom layers.
+It is divided into two submodules:
+
+- **layers**: a collection of custom layers with the possibility to add norm constraints.
+- **zoo**: a collection of deep learning models proposed for EEG applications.
+
+models.layers module
+--------------------
+
+.. automodapi:: selfeeg.models.layers
+    :no-inheritance-diagram:
+    :no-main-docstr:
+    :noindex:
+    :no-heading:

 models.zoo module
 -----------------
diff --git a/selfeeg/dataloading/load.py b/selfeeg/dataloading/load.py
index a3c20ac..dda4628 100644
--- a/selfeeg/dataloading/load.py
+++ b/selfeeg/dataloading/load.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import glob
 import math
 import os
diff --git a/selfeeg/models/__init__.py b/selfeeg/models/__init__.py
index b4f23c1..0d8f0fe 100644
--- a/selfeeg/models/__init__.py
+++ b/selfeeg/models/__init__.py
@@ -1,5 +1,6 @@
 from .zoo import (
     BasicBlock1,
+    ConstrainedConv1d,
     ConstrainedConv2d,
     ConstrainedDense,
     DeepConvNet,
@@ -23,3 +24,11 @@
     TinySleepNet,
     TinySleepNetEncoder,
 )
+
+from .layers import (
+    ConstrainedConv1d,
+    ConstrainedConv2d,
+    ConstrainedDense,
+    DepthwiseConv2d,
+    SeparableConv2d,
+)
diff --git a/selfeeg/models/layers.py b/selfeeg/models/layers.py
new file mode 100644
index 0000000..1034367
--- /dev/null
+++ b/selfeeg/models/layers.py
@@ -0,0 +1,970 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+    "ConstrainedConv1d",
+    "ConstrainedConv2d",
+    "ConstrainedDense",
+    "DepthwiseConv2d",
+    "SeparableConv2d",
+]
+
+class ConstrainedDense(nn.Linear):
+    """
+    Pytorch implementation of the Dense layer with the possibility of adding a
+    MaxNorm, MinMaxNorm, or a UnitNorm constraint. Most of the parameters are
+    the same as described in the torch.nn.Linear help.
+
+    Parameters
+    ----------
+    in_features: int
+        Number of input features.
+    out_features: int
+        Number of output features.
+    bias: bool, optional
+        If True, adds a learnable bias to the output.
+
+        Default = True
+    device: torch.device or str, optional
+        The torch device.
+    dtype: torch dtype, optional
+        Layer dtype, i.e., the data type of the torch.Tensor defining the layer weights.
+    max_norm: float, optional
+        The maximum norm each hidden unit can have.
+        If None, no constraint will be added.
+
+        Default = 2.0
+    min_norm: float, optional
+        The minimum norm each hidden unit can have. Must be a float
+        lower than max_norm. If given, MinMaxNorm will be applied in the case
+        max_norm is also given. Otherwise, it will be ignored.
+
+        Default = None
+    axis_norm: Union[int, list, tuple], optional
+        The axis along which weights are constrained. It behaves like the Keras
+        `axis` argument. So, considering that a Dense layer has weights of shape
+        (output_features, input_features), setting axis_norm to 1 will constrain
+        each weight vector of size (input_features,).
+
+        Default = 1
+    minmax_rate: float, optional
+        A constraint for MinMaxNorm setting how weights will be rescaled at each step.
+        It behaves like the Keras `rate` argument of the MinMaxNorm constraint. So,
+        minmax_rate = 1.0 strictly enforces the constraint, while rate < 1.0 slowly
+        rescales the layer's hidden units at each step.
+
+        Default = 1.0
+
+    Note
+    ----
+    To apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm
+    constraint, set both min_norm and max_norm. To apply a UnitNorm constraint,
+    set both min_norm and max_norm to 1.0.
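+
+    For instance, here is a minimal sketch of the three configurations
+    (the argument values are illustrative):
+
+    >>> dense_max    = ConstrainedDense(64, 32, max_norm=2.0)                # MaxNorm
+    >>> dense_minmax = ConstrainedDense(64, 32, min_norm=0.5, max_norm=2.0)  # MinMaxNorm
+    >>> dense_unit   = ConstrainedDense(64, 32, min_norm=1.0, max_norm=1.0)  # UnitNorm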
+
+    Example
+    -------
+    >>> from selfeeg.models import ConstrainedDense
+    >>> import torch
+    >>> x = torch.randn(4,64)
+    >>> mdl = ConstrainedDense(64, 32, max_norm = 1.4)
+    >>> mdl.weight = torch.nn.Parameter(mdl.weight*10)
+    >>> out = mdl(x)
+    >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=1))
+    >>> print(out.shape) # should return torch.Size([4, 32])
+    >>> print(torch.isnan(out).sum()) # should return 0
+    >>> print(torch.sum(norms>(1.4+1e-3)).item() == 0) # should return True
+
+    """
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=True,
+        device=None,
+        dtype=None,
+        max_norm=2.0,
+        min_norm=None,
+        axis_norm=1,
+        minmax_rate=1.0
+    ):
+        super(ConstrainedDense, self).__init__(
+            in_features,
+            out_features,
+            bias,
+            device,
+            dtype
+        )
+
+        # Check that max_norm is valid
+        if max_norm is not None:
+            if max_norm <= 0:
+                raise ValueError("max_norm can't be lower than or equal to 0")
+            else:
+                self.max_norm = max_norm
+        else:
+            self.max_norm = max_norm
+
+        # Check that min_norm is valid
+        if min_norm is not None:
+            if min_norm <= 0:
+                raise ValueError("min_norm can't be lower than or equal to 0")
+            else:
+                self.min_norm = min_norm
+        else:
+            self.min_norm = min_norm
+
+        # If both are given, check that max_norm is bigger than min_norm
+        if (self.min_norm is not None) and (self.max_norm is not None):
+            if self.min_norm > self.max_norm:
+                raise ValueError("max_norm can't be lower than min_norm")
+
+        # Check that minmax_rate is in (0, 1]
+        if minmax_rate <= 0.0 or minmax_rate > 1.0:
+            raise ValueError("minmax_rate must be in (0, 1]")
+        self.minmax_rate = minmax_rate
+
+        # Check that axis_norm is valid
+        if type(axis_norm) not in [tuple, list, int]:
+            raise TypeError("axis_norm must be a tuple, list, or int")
+        else:
+            if type(axis_norm) == int:
+                if axis_norm < 0 or axis_norm > 1:
+                    raise ValueError("Linear has 2 axes. Values must be 0 or 1")
+            else:
+                for i in axis_norm:
+                    if i < 0 or i > 1:
+                        raise ValueError("Axis values must be 0 or 1")
+        self.axis_norm = axis_norm
+
+        # Set the constraint case:
+        # 0 --> no constraint
+        # 1 --> MaxNorm
+        # 2 --> MinMaxNorm
+        # 3 --> UnitNorm
+        # The cases are kept separate for computational reasons:
+        # MinMaxNorm takes almost twice as long to execute as the other operations.
+        if self.max_norm is not None:
+            if self.min_norm is not None:
+                if self.min_norm == 1 and self.max_norm == 1:
+                    self.constraint_type = 3
+                else:
+                    self.constraint_type = 2
+            else:
+                self.constraint_type = 1
+        else:
+            self.constraint_type = 0
+
+    def scale_norm(self, eps=1e-9):
+        """
+        ``scale_norm`` applies the desired constraint on the layer.
+        It is highly based on the Keras implementation, but here
+        MaxNorm, MinMaxNorm, and UnitNorm are all implemented inside
+        this function.
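+
+        Parameters
+        ----------
+        eps: float, optional
+            A small value added to the norm denominator to avoid
+            division by zero during the weight rescaling.
+
+            Default = 1e-9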
+ + """ + if self.constraint_type == 1: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = torch.clamp(norms, 0, self.max_norm) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 2: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = (self.minmax_rate \ + * torch.clamp(norms, self.min_norm, self.max_norm) \ + + (1 - self.minmax_rate) * norms) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 3: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + self.weight = torch.nn.Parameter(self.weight / (eps + norms)) + + def forward(self, input): + """ + :meta private: + """ + if self.constraint_type != 0: + self.scale_norm() + return F.linear(input, self.weight, self.bias) + + +class ConstrainedConv1d(nn.Conv1d): + """ + Pytorch implementation of the 1D Convolutional layer with the possibility + to add a MaxNorm, MinMaxNorm, or UnitNorm constraint along the given axis. + Most of the parameters are the same as described in pytorch Conv2D help. + + Parameters + ---------- + in_channels: int + Number of input channels. + out_channels: int + Number of output channels. + kernel_size: int or tuple + Size of the convolving kernel. + stride: int or tuple, optional + Stride of the convolution. + + Default = 1 + padding: int, tuple or str, optional + Padding added to all four sides of the input. + + Default = 0 + dilation: int or tuple, optional + Spacing between kernel elements. + + Default = 1 + groups: int, optional + Number of blocked connections from input channels to output channels. + + Default = 1 + bias: bool, optional + If True, adds a learnable bias to the output. + + Default = True + padding_mode: str, optional + Any of 'zeros', 'reflect', 'replicate' or 'circular'. + + Default = 'zeros' + device: torch.device or str, optional + The torch device. + dtype: torch.dtype, optional + Layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. + max_norm: float, optional + The maximum norm each hidden unit can have. + If None no constraint will be added. + + Default = 2.0 + min_norm: float, optional + The minimum norm each hidden unit can have. Must be a float + lower than max_norm. If given, MinMaxNorm will be applied in the case + max_norm is also given. Otherwise, it will be ignored. + + Default = None + axis_norm: Union[int, list, tuple], optional + The axis along weights are constrained. It behaves like Keras. So, considering + that a Conv2D layer has shape (output_depth, input_depth, length), set axis + to [1, 2] will constrain the weights of each filter tensor of size + (input_depth, length). + + Default = [1,2] + minmax_rate: float, optional + A constraint for MinMaxNorm setting how weights will be rescaled at each step. + It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using + minmax_rate = 1 will set a strict enforcement of the constraint, + while rate<1.0 will slowly rescale layer's hidden units at each step. + + Default = 1.0 + + Note + ---- + To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm + constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, + set both min_norm and max_norm to 1.0. 
+
+    Example
+    -------
+    >>> from selfeeg.models import ConstrainedConv1d
+    >>> import torch
+    >>> x = torch.randn(4, 8, 64)
+    >>> mdl = ConstrainedConv1d(8, 16, 15, max_norm = 1.4, min_norm = 0.3)
+    >>> mdl.weight = torch.nn.Parameter(mdl.weight*10)
+    >>> out = mdl(x)
+    >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=[1,2]))
+    >>> print(out.shape) # should return torch.Size([4, 16, 64])
+    >>> print(torch.isnan(out).sum()) # should return 0
+    >>> print(torch.sum(norms>(1.4+1e-3)).item() == 0) # should return True
+
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding="same",
+        dilation=1,
+        groups=1,
+        bias=True,
+        padding_mode="zeros",
+        device=None,
+        dtype=None,
+        max_norm=2.0,
+        min_norm=None,
+        axis_norm=[1,2],
+        minmax_rate=1.0
+    ):
+        super(ConstrainedConv1d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+            device,
+            dtype
+        )
+
+        # Check that max_norm is valid
+        if max_norm is not None:
+            if max_norm <= 0:
+                raise ValueError("max_norm can't be lower than or equal to 0")
+            else:
+                self.max_norm = max_norm
+        else:
+            self.max_norm = max_norm
+
+        # Check that min_norm is valid
+        if min_norm is not None:
+            if min_norm <= 0:
+                raise ValueError("min_norm can't be lower than or equal to 0")
+            else:
+                self.min_norm = min_norm
+        else:
+            self.min_norm = min_norm
+
+        # If both are given, check that max_norm is bigger than min_norm
+        if (self.min_norm is not None) and (self.max_norm is not None):
+            if self.min_norm > self.max_norm:
+                raise ValueError("max_norm can't be lower than min_norm")
+
+        # Check that minmax_rate is in (0, 1]
+        if minmax_rate <= 0.0 or minmax_rate > 1.0:
+            raise ValueError("minmax_rate must be in (0, 1]")
+        self.minmax_rate = minmax_rate
+
+        # Check that axis_norm is valid
+        if type(axis_norm) not in [tuple, list, int]:
+            raise TypeError("axis_norm must be a tuple, list, or int")
+        else:
+            if type(axis_norm) == int:
+                if axis_norm < 0 or axis_norm > 2:
+                    raise ValueError("Conv1d has 3 axes. Values must be in [0, 2]")
+            else:
+                for i in axis_norm:
+                    if i < 0 or i > 2:
+                        raise ValueError("Axis values must be in [0, 2]")
+        self.axis_norm = axis_norm
+
+        # Set the constraint case:
+        # 0 --> no constraint
+        # 1 --> MaxNorm
+        # 2 --> MinMaxNorm
+        # 3 --> UnitNorm
+        # The cases are kept separate for computational reasons:
+        # MinMaxNorm takes almost twice as long to execute as the other operations.
+        if self.max_norm is not None:
+            if self.min_norm is not None:
+                if self.min_norm == 1 and self.max_norm == 1:
+                    self.constraint_type = 3
+                else:
+                    self.constraint_type = 2
+            else:
+                self.constraint_type = 1
+        else:
+            self.constraint_type = 0
+
+    def scale_norm(self, eps=1e-9):
+        """
+        ``scale_norm`` applies the desired constraint on the layer.
+        It is highly based on the Keras implementation, but here
+        MaxNorm, MinMaxNorm, and UnitNorm are all implemented inside
+        this function.
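+
+        Parameters
+        ----------
+        eps: float, optional
+            A small value added to the norm denominator to avoid
+            division by zero during the weight rescaling.
+
+            Default = 1e-9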
+ + """ + if self.constraint_type == 1: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = torch.clamp(norms, 0, self.max_norm) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 2: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = (self.minmax_rate \ + * torch.clamp(norms, self.min_norm, self.max_norm) \ + + (1 - self.minmax_rate) * norms) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 3: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + self.weight = torch.nn.Parameter(self.weight / (eps + norms)) + + def forward(self, input): + """ + :meta private: + """ + if self.constraint_type != 0: + self.scale_norm() + return self._conv_forward(input, self.weight, self.bias) + + +class ConstrainedConv2d(nn.Conv2d): + """ + Pytorch implementation of the 2D Convolutional layer with the possibility + to add a MaxNorm, MinMaxNorm, or UnitNorm constraint along the given axis. + Most of the parameters are the same as described in pytorch Conv2D help. + + Parameters + ---------- + in_channels: int + Number of input channels. + out_channels: int + Number of output channels. + kernel_size: int or tuple + Size of the convolving kernel. + stride: int or tuple, optional + Stride of the convolution. + + Default = 1 + padding: int, tuple or str, optional + Padding added to all four sides of the input. + + Default = 0 + dilation: int or tuple, optional + Spacing between kernel elements. + + Default = 1 + groups: int, optional + Number of blocked connections from input channels to output channels. + + Default = 1 + bias: bool, optional + If True, adds a learnable bias to the output. + + Default = True + padding_mode: str, optional + Any of 'zeros', 'reflect', 'replicate' or 'circular'. + + Default = 'zeros' + device: torch.device or str, optional + The torch device. + dtype: torch.dtype, optional + Layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. + max_norm: float, optional + The maximum norm each hidden unit can have. + If None no constraint will be added. + + Default = 2.0 + min_norm: float, optional + The minimum norm each hidden unit can have. Must be a float + lower than max_norm. If given, MinMaxNorm will be applied in the case + max_norm is also given. Otherwise, it will be ignored. + + Default = None + axis_norm: Union[int, list, tuple], optional + The axis along weights are constrained. It behaves like Keras. So, considering + that a Conv2D layer has shape (output_depth, input_depth, rows, cols), set axis + to [1, 2, 3] will constrain the weights of each filter tensor of size + (input_depth, rows, cols). + + Default = [1,2,3] + minmax_rate: float, optional + A constraint for MinMaxNorm setting how weights will be rescaled at each step. + It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using + minmax_rate = 1 will set a strict enforcement of the constraint, + while rate<1.0 will slowly rescale layer's hidden units at each step. + + Default = 1.0 + + Note + ---- + To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm + constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, + set both min_norm and max_norm to 1.0. 
+
+    Example
+    -------
+    >>> from selfeeg.models import ConstrainedConv2d
+    >>> import torch
+    >>> x = torch.randn(4, 1, 8, 64)
+    >>> mdl = ConstrainedConv2d(1, 4, (1, 15), max_norm = 0.5)
+    >>> mdl.weight = torch.nn.Parameter(mdl.weight*10)
+    >>> out = mdl(x)
+    >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=[1,2,3]))
+    >>> print(out.shape) # should return torch.Size([4, 4, 8, 64])
+    >>> print(torch.isnan(out).sum()) # should return 0
+    >>> print(torch.sum(norms>(0.5+1e-3)).item() == 0) # should return True
+
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding="same",
+        dilation=1,
+        groups=1,
+        bias=True,
+        padding_mode="zeros",
+        device=None,
+        dtype=None,
+        max_norm=2.0,
+        min_norm=None,
+        axis_norm=[1,2,3],
+        minmax_rate=1.0
+    ):
+        super(ConstrainedConv2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+            padding_mode,
+            device,
+            dtype,
+        )
+
+        # Check that max_norm is valid
+        if max_norm is not None:
+            if max_norm <= 0:
+                raise ValueError("max_norm can't be lower than or equal to 0")
+            else:
+                self.max_norm = max_norm
+        else:
+            self.max_norm = max_norm
+
+        # Check that min_norm is valid
+        if min_norm is not None:
+            if min_norm <= 0:
+                raise ValueError("min_norm can't be lower than or equal to 0")
+            else:
+                self.min_norm = min_norm
+        else:
+            self.min_norm = min_norm
+
+        # If both are given, check that max_norm is bigger than min_norm
+        if (self.min_norm is not None) and (self.max_norm is not None):
+            if self.min_norm > self.max_norm:
+                raise ValueError("max_norm can't be lower than min_norm")
+
+        # Check that minmax_rate is in (0, 1]
+        if minmax_rate <= 0.0 or minmax_rate > 1.0:
+            raise ValueError("minmax_rate must be in (0, 1]")
+        self.minmax_rate = minmax_rate
+
+        # Check that axis_norm is valid
+        if type(axis_norm) not in [tuple, list, int]:
+            raise TypeError("axis_norm must be a tuple, list, or int")
+        else:
+            if type(axis_norm) == int:
+                if axis_norm < 0 or axis_norm > 3:
+                    raise ValueError("Conv2d has 4 axes. Values must be in [0, 3]")
+            else:
+                for i in axis_norm:
+                    if i < 0 or i > 3:
+                        raise ValueError("Axis values must be in [0, 3]")
+        self.axis_norm = axis_norm
+
+        # Set the constraint case:
+        # 0 --> no constraint
+        # 1 --> MaxNorm
+        # 2 --> MinMaxNorm
+        # 3 --> UnitNorm
+        # The cases are kept separate for computational reasons:
+        # MinMaxNorm takes almost twice as long to execute as the other operations.
+        if self.max_norm is not None:
+            if self.min_norm is not None:
+                if self.min_norm == 1 and self.max_norm == 1:
+                    self.constraint_type = 3
+                else:
+                    self.constraint_type = 2
+            else:
+                self.constraint_type = 1
+        else:
+            self.constraint_type = 0
+
+    def scale_norm(self, eps=1e-9):
+        """
+        ``scale_norm`` applies the desired constraint on the layer.
+        It is highly based on the Keras implementation, but here
+        MaxNorm, MinMaxNorm, and UnitNorm are all implemented inside
+        this function.
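+
+        Parameters
+        ----------
+        eps: float, optional
+            A small value added to the norm denominator to avoid
+            division by zero during the weight rescaling.
+
+            Default = 1e-9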
+ + """ + if self.constraint_type == 1: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = torch.clamp(norms, 0, self.max_norm) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 2: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = (self.minmax_rate \ + * torch.clamp(norms, self.min_norm, self.max_norm) \ + + (1 - self.minmax_rate) * norms) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 3: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + self.weight = torch.nn.Parameter(self.weight / (eps + norms)) + + def forward(self, input): + """ + :meta private: + """ + if self.constraint_type != 0: + self.scale_norm() + return self._conv_forward(input, self.weight, self.bias) + + +class DepthwiseConv2d(nn.Conv2d): + """ + Pytorch implementation of the Depthwise Convolutional layer with + the possibility to add a MaxNorm, MinMaxNorm, or UnitNorm constraint along + the given axis. Most of the parameters are the same as described in pytorch + Conv2D help. + + Parameters + ---------- + in_channels: int + Number of input channels. + depth_multiplier: int + The depth multiplier. Output channels will be depth_multiplier*in_channels. + kernel_size: int or tuple + Size of the convolving kernel. + stride: int or tuple, optional + Stride of the convolution. + + Default = 1 + padding: int, tuple or str, optional + Padding added to all four sides of the input. + + Default = 0 + dilation: int or tuple, optional + Spacing between kernel elements. + + Default = 1 + bias: bool, optional + If True, adds a learnable bias to the output. + + Default = True + max_norm: float, optional + The maximum norm each hidden unit can have. + If None no constraint will be added. + + Default = 2.0 + min_norm: float, optional + The minimum norm each hidden unit can have. Must be a float + lower than max_norm. If given, MinMaxNorm will be applied in the case + max_norm is also given. Otherwise, it will be ignored. + + Default = None + axis_norm: Union[int, list, tuple], optional + The axis along weights are constrained. It behaves like Keras. So, considering + that a Conv2D layer has shape (output_depth, input_depth, rows, cols), set axis + to [1, 2, 3] will constrain the weights of each filter tensor of size + (input_depth, rows, cols). + + Default = [1,2,3] + minmax_rate: float, optional + A constraint for MinMaxNorm setting how weights will be rescaled at each step. + It behaves like Keras `rate` argument of MinMaxNorm contraint. So, using + minmax_rate = 1 will set a strict enforcement of the constraint, + while rate<1.0 will slowly rescale layer's hidden units at each step. + + Default = 1.0 + + Note + ---- + To Apply a MaxNorm constraint, set only max_norm. To apply a MinMaxNorm + constraint, set both min_norm and max_norm. To apply a UnitNorm constraint, + set both min_norm and max_norm to 1.0. 
+
+    Example
+    -------
+    >>> from selfeeg.models import DepthwiseConv2d
+    >>> import torch
+    >>> x = torch.randn(4,1,8,64)
+    >>> mdl = DepthwiseConv2d(1, 2, (1, 15), max_norm = 0.5)
+    >>> mdl.weight = torch.nn.Parameter(mdl.weight*10)
+    >>> out = mdl(x)
+    >>> norms = torch.sqrt(torch.sum(torch.square(mdl.weight), axis=[1,2,3]))
+    >>> print(out.shape) # should return torch.Size([4, 2, 8, 64])
+    >>> print(torch.isnan(out).sum()) # should return 0
+    >>> print(torch.sum(norms>(0.5+1e-3)).item() == 0) # should return True
+
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        depth_multiplier,
+        kernel_size,
+        stride=1,
+        padding="same",
+        dilation=1,
+        bias=False,
+        max_norm=2.0,
+        min_norm=None,
+        axis_norm=[1,2,3],
+        minmax_rate=1.0
+    ):
+        super(DepthwiseConv2d, self).__init__(
+            in_channels,
+            depth_multiplier * in_channels,
+            kernel_size,
+            groups=in_channels,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias
+        )
+
+        # Check that max_norm is valid
+        if max_norm is not None:
+            if max_norm <= 0:
+                raise ValueError("max_norm can't be lower than or equal to 0")
+            else:
+                self.max_norm = max_norm
+        else:
+            self.max_norm = max_norm
+
+        # Check that min_norm is valid
+        if min_norm is not None:
+            if min_norm <= 0:
+                raise ValueError("min_norm can't be lower than or equal to 0")
+            else:
+                self.min_norm = min_norm
+        else:
+            self.min_norm = min_norm
+
+        # If both are given, check that max_norm is bigger than min_norm
+        if (self.min_norm is not None) and (self.max_norm is not None):
+            if self.min_norm > self.max_norm:
+                raise ValueError("max_norm can't be lower than min_norm")
+
+        # Check that minmax_rate is in (0, 1]
+        if minmax_rate <= 0.0 or minmax_rate > 1.0:
+            raise ValueError("minmax_rate must be in (0, 1]")
+        self.minmax_rate = minmax_rate
+
+        # Check that axis_norm is valid
+        if type(axis_norm) not in [tuple, list, int]:
+            raise TypeError("axis_norm must be a tuple, list, or int")
+        else:
+            if type(axis_norm) == int:
+                if axis_norm < 0 or axis_norm > 3:
+                    raise ValueError("Conv2d has 4 axes. Values must be in [0, 3]")
+            else:
+                for i in axis_norm:
+                    if i < 0 or i > 3:
+                        raise ValueError("Axis values must be in [0, 3]")
+        self.axis_norm = axis_norm
+
+        # Set the constraint case:
+        # 0 --> no constraint
+        # 1 --> MaxNorm
+        # 2 --> MinMaxNorm
+        # 3 --> UnitNorm
+        # The cases are kept separate for computational reasons:
+        # MinMaxNorm takes almost twice as long to execute as the other operations.
+        if self.max_norm is not None:
+            if self.min_norm is not None:
+                if self.min_norm == 1 and self.max_norm == 1:
+                    self.constraint_type = 3
+                else:
+                    self.constraint_type = 2
+            else:
+                self.constraint_type = 1
+        else:
+            self.constraint_type = 0
+
+    def scale_norm(self, eps=1e-9):
+        """
+        ``scale_norm`` applies the desired constraint on the layer.
+        It is highly based on the Keras implementation, but here
+        MaxNorm, MinMaxNorm, and UnitNorm are all implemented inside
+        this function.
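+
+        Parameters
+        ----------
+        eps: float, optional
+            A small value added to the norm denominator to avoid
+            division by zero during the weight rescaling.
+
+            Default = 1e-9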
+ + """ + if self.constraint_type == 1: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = torch.clamp(norms, 0, self.max_norm) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 2: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + desired = (self.minmax_rate \ + * torch.clamp(norms, self.min_norm, self.max_norm) \ + + (1 - self.minmax_rate) * norms) + self.weight = torch.nn.Parameter(self.weight * (desired / (eps + norms))) + + elif self.constraint_type == 3: + norms = torch.sqrt( + torch.sum(torch.square(self.weight), axis=self.axis_norm, keepdims=True) + ) + self.weight = torch.nn.Parameter(self.weight / (eps + norms)) + + def forward(self, input): + """ + :meta private: + """ + if self.constraint_type != 0: + self.scale_norm() + return self._conv_forward(input, self.weight, self.bias) + + +class SeparableConv2d(nn.Module): + """ + Pytorch implementation of the Separable Convolutional layer with the possibility of + adding a norm constraint on the depthwise filters (feature) dimension. + The layer applies first a depthwise conv2d, then a pointwise conv2d (kernel size = 1) + Most of the parameters are the same as described in pytorch conv2D help. + + Parameters + ---------- + in_channels: int + Number of input channels. + out_channels: int + Number of output channels. + kernel_size: int or tuple + Size of the convolving kernel + stride: int or tuple, optional + Stride of the convolution. + + Default = 1 + padding: int, tuple or str, optional + Padding added to all four sides of the input. + + Default = 0 + dilation: int or tuple, optional + Spacing between kernel elements. + + Default = 1 + bias: bool, optional + If True, adds a learnable bias to the output. + + Default = True + depth_multiplier: int, optional + The depth multiplier of the depthwise block. + + Default = 1 + depth_max_norm: float, optional + The maximum norm each hidden unit in the depthwise layer can have. + If None no constraint will be added. + + Default = None + depth_min_norm: float, optional + The minimum norm each hidden unit in the depthwise layer can have. + Must be a float lower than max_norm. If given, MinMaxNorm will be applied + in the case max_norm is also given. Otherwise, it will be ignored. + + Default = None + depth_minmax_rate: float, optional + A constraint for depthwise's MinMaxNorm setting how weights will be rescaled + at each step. It behaves like Keras `rate` argument of MinMaxNorm contraint. + So, using minmax_rate = 1 will set a strict enforcement of the constraint, + while rate<1.0 will slowly rescale layer's hidden units at each step. + + Default = 1.0 + axis_norm: Union[int, list, tuple], optional + The axis along weights are constrained. It behaves like Keras. So, considering + that a Conv2D layer has shape (output_depth, input_depth), set axis + to 1 will constrain the weights of each filter tensor of size + (input_depth,). + + Default = 1 + point_max_norm: float, optional + Same as depth_max_norm, but applied to the pointwise Convolutional layer. + + Default = None + point_min_norm: float, optional + Same as depth_min_norm, but applied to the pointwise Convolutional layer. + + Default = None + point_minmax_rate: float, optional + Same as depth_minmax_rate, but applied to the pointwise Convolutional layer. 
+ + Default = 1.0 + + Example + ------- + >>> from selfeeg.models import SeparableConv2d + >>> import torch + >>> x = torch.randn(4, 1, 8, 64) + >>> mdl = SeparableConv2d(1,4, (1,15), depth_multiplier=4) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 4, 8, 64]) + >>> print(torch.isnan(out).sum()) # shoud return 0 + + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding="same", + dilation=1, + bias=False, + depth_multiplier=1, + depth_max_norm=None, + depth_min_norm=None, + depth_minmax_rate=1.0, + point_max_norm=None, + point_min_norm=None, + point_minmax_rate=1.0, + axis_norm=[1,2,3] + ): + super(SeparableConv2d, self).__init__() + self.depthwise = DepthwiseConv2d( + in_channels, + depth_multiplier, + kernel_size, + stride, + padding, + dilation, + bias, + max_norm=depth_max_norm, + min_norm=depth_min_norm, + axis_norm=axis_norm, + minmax_rate=depth_minmax_rate + ) + self.pointwise = ConstrainedConv2d( + in_channels * depth_multiplier, + out_channels, + kernel_size=1, + bias=bias, + max_norm=point_max_norm, + min_norm=point_min_norm, + axis_norm=axis_norm, + minmax_rate=point_minmax_rate + ) + + def forward(self, input): + """ + :meta private: + """ + out = self.depthwise(input) + out = self.pointwise(out) + return out diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 8a89579..6637817 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -1,14 +1,18 @@ import torch import torch.nn as nn import torch.nn.functional as F +from .layers import ( + ConstrainedConv1d, + ConstrainedConv2d, + ConstrainedDense, + DepthwiseConv2d, + SeparableConv2d +) __all__ = [ "BasicBlock1", - "ConstrainedConv2d", - "ConstrainedDense", "DeepConvNet", "DeepConvNetEncoder", - "DepthwiseConv2d", "EEGInception", "EEGInceptionEncoder", "EEGNet", @@ -17,7 +21,6 @@ "EEGSymEncoder", "ResNet1D", "ResNet1DEncoder", - "SeparableConv2d", "ShallowNet", "ShallowNetEncoder", "StagerNet", @@ -29,422 +32,6 @@ ] -# ### Special Kernels not implemented in pytorch -class DepthwiseConv2d(nn.Conv2d): - """ - Pytorch implementation of the Depthwise Convolutional layer with the possibility to - add a norm constraint on the filter (feature) dimension. - Most of the parameters are the same as described in pytorch conv2D help. - - Parameters - ---------- - in_channels: int - Number of input channels. - depth_multiplier: int - The depth multiplier. Output channels will be depth_multiplier*in_channels. - kernel_size: int or tuple - Size of the convolving kernel. - stride: int or tuple, optional - Stride of the convolution. - - Default = 1 - padding: int, tuple or str, optional - Padding added to all four sides of the input. - - Default = 0 - dilation: int or tuple, optional - Spacing between kernel elements. - - Default = 1 - bias: bool, optional - If True, adds a learnable bias to the output. - - Default = True - max_norm: float, optional - The maximum norm each filter can have. If None no constraint will be included. 
- - Default = None - - Example - ------- - >>> import selfeeg.models - >>> import torch - >>> x = torch.randn(4,1,8,64) - >>> mdl = models.DepthwiseConv2d(1,2,(1,15)) - >>> out = mdl(x) - >>> print(out.shape) # shoud return torch.Size([4, 2, 8, 64]) - >>> print(torch.isnan(out).sum()) # shoud return 0 - - """ - - def __init__( - self, - in_channels, - depth_multiplier, - kernel_size, - stride=1, - padding="same", - dilation=1, - bias=False, - max_norm=None, - ): - super(DepthwiseConv2d, self).__init__( - in_channels, - depth_multiplier * in_channels, - kernel_size, - groups=in_channels, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - if max_norm is not None: - if max_norm <= 0: - raise ValueError("max_norm can't be lower or equal than 0") - else: - self.max_norm = max_norm - else: - self.max_norm = max_norm - - @torch.no_grad() - def scale_norm(self, eps=1e-9): - """ - Citing the Tensorflow documentation, the implementation tries to replicate this - - integer, axis along which to calculate weight norms. - For instance, in a Dense layer the weight - matrix has shape (input_dim, output_dim), set axis to 0 to constrain each - weight vector of length (input_dim,). In a Conv2D layer with - data_format="channels_last", the weight tensor has - shape (rows, cols, input_depth, output_depth), set axis to [0, 1, 2] - to constrain the weights - of each filter tensor of size (rows, cols, input_depth). - - :meta private: - """ - # calculate the norm of each filter of size (row, cols, input_depth), here (1, kernel_size) - if self.kernel_size[1] > 1: - norm = self.weight.norm(dim=2, keepdim=True).norm(dim=3, keepdim=True) - else: - norm = self.weight.norm(dim=2, keepdim=True) - - # rescale only those filters which have a norm bigger than the maximum allowed - if (norm > self.max_norm).sum() > 0: - desired = torch.clamp(norm, 0, self.max_norm) - self.weight = torch.nn.Parameter(self.weight * desired / (eps + norm)) - - def forward(self, input): - """ - :meta private: - """ - if self.max_norm is not None: - self.scale_norm(self.max_norm) - return self._conv_forward(input, self.weight, self.bias) - - -class SeparableConv2d(nn.Module): - """ - Pytorch implementation of the Separable Convolutional layer with the possibility of - adding a norm constraint on the depthwise filters (feature) dimension. - The layer applies first a depthwise conv2d, then a pointwise conv2d (kernel size = 1) - Most of the parameters are the same as described in pytorch conv2D help. - - Parameters - ---------- - in_channels: int - Number of input channels. - out_channels: int - Number of output channels. - kernel_size: int or tuple - Size of the convolving kernel - stride: int or tuple, optional - Stride of the convolution. - - Default = 1 - padding: int, tuple or str, optional - Padding added to all four sides of the input. - - Default = 0 - dilation: int or tuple, optional - Spacing between kernel elements. - - Default = 1 - bias: bool, optional - If True, adds a learnable bias to the output. - - Default = True - depth_multiplier: int, optional - The depth multiplier of the depthwise block. - - Default = 1 - depth_max_norm: float, optional - The maximum norm each filter can have in the depthwise block. - If None no constraint will be included. 
- - Default = None - - Example - ------- - >>> import selfeeg.models - >>> import torch - >>> x = torch.randn(4,1,8,64) - >>> mdl = models.SeparableConv2d(1,4,(1,15), depth_multiplier=4) - >>> out = mdl(x) - >>> print(out.shape) # shoud return torch.Size([4, 4, 8, 64]) - >>> print(torch.isnan(out).sum()) # shoud return 0 - - """ - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding="same", - dilation=1, - bias=False, - depth_multiplier=1, - depth_max_norm=None, - ): - super(SeparableConv2d, self).__init__() - self.depthwise = DepthwiseConv2d( - in_channels, - depth_multiplier, - kernel_size, - stride, - padding, - dilation, - bias, - max_norm=None, - ) - self.pointwise = nn.Conv2d( - in_channels * depth_multiplier, out_channels, kernel_size=1, bias=bias - ) - - def forward(self, x): - """ - :meta private: - """ - out = self.depthwise(x) - out = self.pointwise(out) - return out - - -class ConstrainedDense(nn.Linear): - """ - Pytorch implementation of the Dense layer with the possibility of adding a norm constraint. - Most of the parameters are the same as described in pytorch Linear help. - - Parameters - ---------- - in_features: int - Number of input features. - out_channels: int - Number of output features. - bias: bool, optional - If True, adds a learnable bias to the output. - - Default = True - device: torch.device or str, optional - The torch device. - dtype: torch dtype, optional - layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. - max_norm: float, optional - The maximum norm of the layer. If None no constraint will be included. - - Default = None - - Example - ------- - >>> import selfeeg.models - >>> import torch - >>> x = torch.randn(4,64) - >>> mdl = models.ConstrainedDense(64,32) - >>> out = mdl(x) - >>> print(out.shape) # shoud return torch.Size([4, 32]) - >>> print(torch.isnan(out).sum()) # shoud return 0 - - """ - - def __init__( - self, in_features, out_features, bias=True, device=None, dtype=None, max_norm=None - ): - super(ConstrainedDense, self).__init__(in_features, out_features, bias, device, dtype) - - if max_norm is not None: - if max_norm <= 0: - raise ValueError("max_norm can't be lower or equal than 0") - else: - self.max_norm = max_norm - else: - self.max_norm = max_norm - - @torch.no_grad() - def scale_norm(self, eps=1e-9): - """ - Citing the Tensorflow documentation, the implementation tries to replicate this - - integer, axis along which to calculate weight norms. For instance, in a Dense - layer the weight matrix has shape - (input_dim, output_dim), set axis to 0 to constrain each weight vector of length - (input_dim,). - In a Conv2D layer with data_format="channels_last", the weight tensor has shape (rows, - cols, input_depth, output_depth), - set axis to [0, 1, 2] to constrain the weights of each filter tensor of size (rows, - cols, input_depth). 
- - :meta private: - """ - # calculate the norm of each filter of size (row, cols, input_depth), - # here (1, kernel_size) - norm = self.weight.norm(dim=1, keepdim=True) - - # rescale only those filters which have a norm bigger than the maximum allowed - if (norm > self.max_norm).sum() > 0: - desired = torch.clamp(norm, 0, self.max_norm) - self.weight = torch.nn.Parameter(self.weight * desired / (eps + norm)) - - def forward(self, input): - """ - :meta private: - """ - if self.max_norm is not None: - self.scale_norm(self.max_norm) - return F.linear(input, self.weight, self.bias) - - -class ConstrainedConv2d(nn.Conv2d): - """ - Pytorch implementation of the Convolutional 2D layer with the possibilty of - adding a max_norm constraint on the filter (feature) dimension. - Most of the parameters are the same as described in pytorch conv2D help. - - Parameters - ---------- - in_channels: int - Number of input channels. - out_channels: int - Number of output channels. - kernel_size: int or tuple - Size of the convolving kernel. - stride: int or tuple, optional - Stride of the convolution. - - Default = 1 - padding: int, tuple or str, optional - Padding added to all four sides of the input. - - Default = 0 - dilation: int or tuple, optional - Spacing between kernel elements. - - Default = 1 - groups: int, optional - Number of blocked connections from input channels to output channels. - - Default = 1 - bias: bool, optional - If True, adds a learnable bias to the output. - - Default = True - padding_mode: str, optional - Any of 'zeros', 'reflect', 'replicate' or 'circular'. - - Default = 'zeros' - device: torch.device or str, optional - The torch device. - dtype: torch.dtype, optional - Layer dtype, i.e., the data type of the torch.Tensor defining the layer weights. - max_norm: float, optional - The maximum norm each filter can have. If None no constraint will be included. - - Default = None - - Example - ------- - >>> import selfeeg.models - >>> import torch - >>> x = torch.randn(4,1,8,64) - >>> mdl = models.ConstrainedConv2d(1,4,(1,15)) - >>> out = mdl(x) - >>> print(out.shape) # shoud return torch.Size([4, 4, 8, 64]) - >>> print(torch.isnan(out).sum()) # shoud return 0 - - """ - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding="same", - dilation=1, - groups=1, - bias=True, - padding_mode="zeros", - device=None, - dtype=None, - max_norm=None, - ): - super(ConstrainedConv2d, self).__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - device, - dtype, - ) - - if max_norm is not None: - if max_norm <= 0: - raise ValueError("max_norm can't be lower or equal than 0") - else: - self.max_norm = max_norm - else: - self.max_norm = max_norm - - @torch.no_grad() - def scale_norm(self, eps=1e-9): - """ - Citing the Tensorflow documentation, the implementation tries to replicate this - integer, axis along which to calculate weight norms. For instance, in a Dense - layer the weight matrix has shape - (input_dim, output_dim), set axis to 0 to constrain each weight vector of length - (input_dim,). - In a Conv2D layer with data_format="channels_last", the weight tensor has shape (rows, - cols, input_depth, output_depth), - set axis to [0, 1, 2] to constrain the weights of each filter tensor of size (rows, - cols, input_depth). 
- - :meta private: - """ - # calculate the norm of each filter of size - # (row, cols, input_depth), here (1, kernel_size) - if self.kernel_size[1] > 1: - norm = self.weight.norm(dim=2, keepdim=True).norm(dim=3, keepdim=True) - else: - norm = self.weight.norm(dim=2, keepdim=True) - - # rescale only those filters which have a norm bigger than the maximum allowed - if (norm > self.max_norm).sum() > 0: - desired = torch.clamp(norm, 0, self.max_norm) - self.weight = torch.nn.Parameter(self.weight * desired / (eps + norm)) - - def forward(self, input): - """ - :meta private: - """ - if self.max_norm is not None: - self.scale_norm(self.max_norm) - return self._conv_forward(input, self.weight, self.bias) - # ------------------------------ # EEGNet diff --git a/test/EEGself/models/layers_test.py b/test/EEGself/models/layers_test.py new file mode 100644 index 0000000..77c545a --- /dev/null +++ b/test/EEGself/models/layers_test.py @@ -0,0 +1,240 @@ +import itertools +import os +import sys +import unittest + +import numpy as np +import torch + +from selfeeg import models + + +class TestModels(unittest.TestCase): + + def makeGrid(self, pars_dict): + keys = pars_dict.keys() + combinations = itertools.product(*pars_dict.values()) + ds = [dict(zip(keys, cc)) for cc in combinations] + return ds + + @classmethod + def setUpClass(cls): + cls.device = ( + torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") + ) + if cls.device.type == "cpu": + cls.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + print("\n----------------------------") + print("TESTING MODELS.LAYERS MODULE") + if cls.device.type != "cpu": + print("Found gpu device: testing module with both cpu and gpu") + else: + print("Didn't found cuda device: testing module with only cpu") + print("----------------------------") + cls.N = 2 + cls.Chan = 8 + cls.Samples = 2048 + cls.x = torch.randn(cls.N, cls.Chan, cls.Samples) + cls.xl = torch.randn(cls.N, 1, 16, cls.Samples) + cls.xd = torch.randn(cls.N, 128) + if cls.device.type != "cpu": + cls.x2 = torch.randn(cls.N, cls.Chan, cls.Samples).to(device=cls.device) + cls.xl2 = torch.randn(cls.N, 1, 16, cls.Samples).to(device=cls.device) + cls.xd2 = torch.randn(cls.N, 128).to(device=cls.device) + + def setUp(self): + torch.manual_seed(1234) + + def test_ConstrainedConv1d(self): + print("Testing conv1d with norm constraint...", end="", flush=True) + Conv_args = { + "in_channels": [8], + "out_channels": [4, 16], + "kernel_size": [16], + "stride": [1, 2, 3], + "dilation": [1, 2], + "bias": [True, False], + "max_norm": [None, 1, 2], + "min_norm": [None, 1], + "padding": ["valid"], + } + Conv_args = self.makeGrid(Conv_args) + for i in Conv_args: + model = models.ConstrainedConv1d(**i) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.x) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=[1,2])) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + + if self.device.type != "cpu": + for i in Conv_args: + model = models.ConstrainedConv1d(**i).to(device=self.device) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.x2) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=[1,2])) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if 
i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + print( + " Constrained conv1d OK: tested", len(Conv_args), " combinations of input arguments" + ) + + def test_ConstrainedConv2d(self): + print("Testing conv2d with norm constraint...", end="", flush=True) + Conv_args = { + "in_channels": [1], + "out_channels": [5, 16], + "kernel_size": [(1, 64), (5, 1), (5, 64)], + "stride": [1, 2, 3], + "dilation": [1, 2], + "bias": [True, False], + "max_norm": [None, 1, 2, 3], + "min_norm": [None, 1], + "padding": ["valid"], + } + Conv_args = self.makeGrid(Conv_args) + for i in Conv_args: + model = models.ConstrainedConv2d(**i) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.xl) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=[1,2,3])) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + + if self.device.type != "cpu": + for i in Conv_args: + model = models.ConstrainedConv2d(**i).to(device=self.device) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.xl2) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=[1,2,3])) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + print( + " Constrained conv2d OK: tested", len(Conv_args), " combinations of input arguments" + ) + + def test_ConstrainedDense(self): + print("Testing Dense layer with max norm constraint...", end="", flush=True) + Dense_args = { + "in_features": [128], + "out_features": [32], + "bias": [True, False], + "max_norm": [None, 1, 3], + "min_norm": [None, 1], + } + Dense_args = self.makeGrid(Dense_args) + for i in Dense_args: + model = models.ConstrainedDense(**i) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.xd) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=1)) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], 32) + + if self.device.type != "cpu": + for i in Dense_args: + model = models.ConstrainedDense(**i).to(device=self.device) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.xd2) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=1)) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], 32) + print(" Dense layer OK: tested", len(Dense_args), " combinations of input arguments") + + def test_DepthwiseConv2d(self): + print("Testing Depthwise conv2d with norm constraint...", end="", flush=True) + Depthwise_args = { + "in_channels": [1], + "depth_multiplier": [2, 3, 4], + "kernel_size": [(1, 64), (5, 1), (5, 64)], + "stride": [1, 2, 3], + "dilation": [1, 2], + "bias": [True, False], + "max_norm": [None, 1, 3], + "min_norm": [None, 1], + "padding": ["valid"], 
+ } + Depthwise_args = self.makeGrid(Depthwise_args) + for i in Depthwise_args: + model = models.DepthwiseConv2d(**i) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.xl) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=[1,2,3])) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["depth_multiplier"]) + + if self.device.type != "cpu": + for i in Depthwise_args: + model = models.DepthwiseConv2d(**i).to(device=self.device) + model.weight = torch.nn.Parameter(model.weight * 10) + out = model(self.xl2) + if i["max_norm"] is not None: + norms = torch.sqrt(torch.sum(torch.square(model.weight), axis=[1,2,3])) + self.assertTrue(torch.sum(norms>(i["max_norm"]+1e-3)).item() == 0) + if i["min_norm"] is not None: + self.assertTrue(torch.sum(norms<(i["min_norm"]-1e-3)).item() == 0) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["depth_multiplier"]) + print( + " Depthwise conv2d OK: tested", + len(Depthwise_args), + " combinations of input arguments", + ) + + def test_SeparableConv2d(self): + print("Testing Separable conv2d with norm constraint...", end="", flush=True) + Separable_args = { + "in_channels": [1], + "out_channels": [5, 16], + "depth_multiplier": [1, 3], + "kernel_size": [(1, 64), (5, 1), (5, 64)], + "stride": [1, 2, 3], + "dilation": [1, 2], + "bias": [True, False], + "depth_max_norm": [None, 1, 2], + "depth_min_norm": [None, 1], + "point_max_norm": [None, 1, 2], + "point_min_norm": [None, 1], + "padding": ["valid"], + } + Separable_args = self.makeGrid(Separable_args) + for i in Separable_args: + model = models.SeparableConv2d(**i) + out = model(self.xl) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["out_channels"]) + + if self.device.type != "cpu": + for i in Separable_args: + model = models.SeparableConv2d(**i).to(device=self.device) + out = model(self.xl2) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["out_channels"]) + print( + " Separable conv2d OK: tested", len(Separable_args), "combinations of input arguments" + ) + +if __name__ == "__main__": + unittest.main() diff --git a/test/EEGself/models/zoo_test.py b/test/EEGself/models/zoo_test.py index cffb70e..b1b9789 100644 --- a/test/EEGself/models/zoo_test.py +++ b/test/EEGself/models/zoo_test.py @@ -24,13 +24,13 @@ def setUpClass(cls): ) if cls.device.type == "cpu": cls.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - print("\n---------------------") - print("TESTING MODELS MODULE") + print("\n-------------------------") + print("TESTING MODELS.ZOO MODULE") if cls.device.type != "cpu": print("Found gpu device: testing module with both cpu and gpu") else: print("Didn't found cuda device: testing module with only cpu") - print("---------------------") + print("-------------------------") cls.N = 2 cls.Chan = 8 cls.Samples = 2048 @@ -45,145 +45,6 @@ def setUpClass(cls): def setUp(self): torch.manual_seed(1234) - def test_DepthwiseConv2d(self): - print("Testing Depthwise conv2d with max norm constraint...", end="", flush=True) - Depthwise_args = { - "in_channels": [1], - "depth_multiplier": [2, 3, 4], - "kernel_size": [(1, 64), (5, 1), (5, 64)], - "stride": [1, 2, 3], - "dilation": [1, 2], - "bias": [True, False], - "max_norm": [None, 2, 3], - "padding": ["valid"], - } - Depthwise_args = self.makeGrid(Depthwise_args) - for i in 
Depthwise_args: - model = models.DepthwiseConv2d(**i) - model.weight = torch.nn.Parameter(model.weight * 10) - out = model(self.xl) - if i["max_norm"] is not None: - norm = model.weight.norm(dim=2, keepdim=True).norm(dim=3, keepdim=True).squeeze() - self.assertEqual((norm > i["max_norm"]).sum(), 0) - self.assertEqual(torch.isnan(out).sum(), 0) - self.assertEqual(out.shape[1], i["depth_multiplier"]) - - if self.device.type != "cpu": - for i in Depthwise_args: - model = models.DepthwiseConv2d(**i).to(device=self.device) - model.weight = torch.nn.Parameter(model.weight * 10) - out = model(self.xl2) - if i["max_norm"] is not None: - norm = ( - model.weight.norm(dim=2, keepdim=True).norm(dim=3, keepdim=True).squeeze() - ) - self.assertEqual((norm > i["max_norm"]).sum(), 0) - self.assertEqual(torch.isnan(out).sum(), 0) - self.assertEqual(out.shape[1], i["depth_multiplier"]) - print( - " Depthwise conv2d OK: tested", - len(Depthwise_args), - " combinations of input arguments", - ) - - def test_SeparableConv2d(self): - print("Testing Separable conv2d with norm constraint...", end="", flush=True) - Separable_args = { - "in_channels": [1], - "out_channels": [5, 16], - "depth_multiplier": [1, 3], - "kernel_size": [(1, 64), (5, 1), (5, 64)], - "stride": [1, 2, 3], - "dilation": [1, 2], - "bias": [True, False], - "depth_max_norm": [None, 2, 3], - "padding": ["valid"], - } - Separable_args = self.makeGrid(Separable_args) - for i in Separable_args: - model = models.SeparableConv2d(**i) - out = model(self.xl) - self.assertEqual(torch.isnan(out).sum(), 0) - self.assertEqual(out.shape[1], i["out_channels"]) - - if self.device.type != "cpu": - for i in Separable_args: - model = models.SeparableConv2d(**i).to(device=self.device) - out = model(self.xl2) - self.assertEqual(torch.isnan(out).sum(), 0) - self.assertEqual(out.shape[1], i["out_channels"]) - print( - " Separable conv2d OK: tested", len(Separable_args), "combinations of input arguments" - ) - - def test_ConstrainedConv2d(self): - print("Testing conv2d with max norm constraint...", end="", flush=True) - Conv_args = { - "in_channels": [1], - "out_channels": [5, 16], - "kernel_size": [(1, 64), (5, 1), (5, 64)], - "stride": [1, 2, 3], - "dilation": [1, 2], - "bias": [True, False], - "max_norm": [None, 2, 3], - "padding": ["valid"], - } - Conv_args = self.makeGrid(Conv_args) - for i in Conv_args: - model = models.ConstrainedConv2d(**i) - model.weight = torch.nn.Parameter(model.weight * 10) - out = model(self.xl) - if i["max_norm"] is not None: - norm = model.weight.norm(dim=2, keepdim=True).norm(dim=3, keepdim=True).squeeze() - self.assertEqual((norm > i["max_norm"]).sum(), 0) - self.assertEqual(torch.isnan(out).sum(), 0) - - if self.device.type != "cpu": - for i in Conv_args: - model = models.ConstrainedConv2d(**i).to(device=self.device) - model.weight = torch.nn.Parameter(model.weight * 10) - out = model(self.xl2) - if i["max_norm"] is not None: - norm = ( - model.weight.norm(dim=2, keepdim=True).norm(dim=3, keepdim=True).squeeze() - ) - self.assertEqual((norm > i["max_norm"]).sum(), 0) - self.assertEqual(torch.isnan(out).sum(), 0) - print( - " Constrained conv2d OK: tested", len(Conv_args), " combinations of input arguments" - ) - - def test_ConstrainedDense(self): - print("Testing Dense layer with max norm constraint...", end="", flush=True) - Dense_args = { - "in_features": [128], - "out_features": [32], - "bias": [True, False], - "max_norm": [None, 2, 3], - } - Dense_args = self.makeGrid(Dense_args) - for i in Dense_args: - model = 
models.ConstrainedDense(**i) - model.weight = torch.nn.Parameter(model.weight * 10) - out = model(self.xd) - if i["max_norm"] is not None: - norm = model.weight.norm(dim=1, keepdim=True) - self.assertEqual((norm > i["max_norm"]).sum(), 0) - self.assertEqual(torch.isnan(out).sum(), 0) - self.assertEqual(out.shape[1], 32) - - if self.device.type != "cpu": - for i in Dense_args: - model = models.ConstrainedDense(**i).to(device=self.device) - model.weight = torch.nn.Parameter(model.weight * 10) - out = model(self.xd2) - if i["max_norm"] is not None: - norm = model.weight.norm(dim=1, keepdim=True) - self.assertEqual((norm > i["max_norm"]).sum(), 0) - self.assertEqual(torch.isnan(out).sum(), 0) - self.assertEqual(out.shape[1], 32) - print(" Dense layer OK: tested", len(Dense_args), " combinations of input arguments") - def test_DeepConvNet(self): print("Testing DeepConvNet...", end="", flush=True) DCN_args = { @@ -258,8 +119,6 @@ def test_EEGInception(self): self.assertGreaterEqual(out.min(), 0) print(" EEGInception OK: tested", len(EEGin_args), " combinations of input arguments") - # In[8]: - def test_EEGNet(self): print("Testing EEGnet...", end="", flush=True) EEGnet_args = { @@ -296,8 +155,6 @@ def test_EEGNet(self): self.assertGreaterEqual(out.min(), 0) print(" EEGnet OK: tested", len(EEGnet_args), " combinations of input arguments") - # In[9]: - def test_EEGSym(self): print("Testing EEGsym...", end="", flush=True) EEGsym_args = { @@ -334,9 +191,7 @@ def test_EEGSym(self): self.assertGreaterEqual(out.min(), 0) print(" EEGsym OK: tested", len(EEGsym_args), " combinations of input arguments") - # In[10]: - - def test_EEGSym(self): + def test_ResNet(self): print("Testing ResNet...", end="", flush=True) EEGres_args = { "nb_classes": [2, 4], @@ -370,8 +225,6 @@ def test_EEGSym(self): self.assertGreaterEqual(out.min(), 0) print(" ResNet OK: tested", len(EEGres_args), " combinations of input arguments") - # In[11]: - def test_ShallowNet(self): print("Testing ShallowNet...", end="", flush=True) EEGsha_args = { @@ -474,8 +327,6 @@ def test_STNet(self): def test_TinySleepNet(self): print("Testing TinySleepNet...", end="", flush=True) - # nb_classes, Chans, Fs, F=128, kernlength=8, pool=8, - # dropRate=0.5, batch_momentum=0.1, max_dense_norm=2.0, return_logits=True EEGsleep_args = { "nb_classes": [2, 4], "Chans": [self.Chan],