diff --git a/docs/faq.rst b/docs/faq.rst
index 26047b9..794d845 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -53,15 +53,14 @@ from the file name really refer to the specific dataset or subject.
 You can give anything you want as long as the previous reasoning about the
 identification of unique pairs is satisfied.

-3) **Can I implement a Leave One Subject Out cross-validation?**
+3) **Can I implement a Leave-One-Subject-Out (LOSO) cross-validation?**

 Of course. You just need to call the ``GetEEGSplitTableKfold`` function,
 setting validation split to subject mode and setting the number of folds
 equals to the number of subjects. Remember to add a subject_id extractor
 if needed and, if you have enough data to create a separate test set,
 to also set the test split mode to subject and adjust the number of folds
 according to the number of subjects
-minus ones put in the test set
-
+minus the ones put in the test set.


 Augmentation module
@@ -93,36 +92,3 @@ check how augmentations specifically perform on your device.
    :header-rows: 1
    :widths: 15, 14, 14, 14, 14, 14, 14
    :class: longtable
-
-
-
-SSL module
－----------
-
-1) **The SSL module implements only Contrastive Learning algorithms, is it possible
-to use selfEEG to create predictive or generative pretraining task?**
-
-It depends on the type of pretraining task you want to define.
-However, by defining the right dataloader and loss, and give them to the
-fine-tuning tuning function of the ssl module, it is possible to construct simple
-predictive or generative pretraining task. For example, a simple strategy can be:
-
-1. Define an EEGDataset class without the label extraction.
-2. Define a custom Augmenter.
-3. Define a custom ``collite_fn`` function to give to the PyTorch Dataloader.
-4. Define the loss function and other training parameters.
-5. Run the fine-tuning function with the Dataloader, loss, and other parameters.
-
-The important step here is the definition of the ``collite_fn`` function (see
-`here `_
-on how to create custom collite_fn functions), which is used to create the
-pretraining target. For example:
-
-1. Reconstructive pretraining (generative): create an augmented batch with the
-   augmenter, then return the augmented batch as the input, and the original batch
-   as the target.
-2. Predict if the sample was augmented (predictive): apply an augmentation
-   to a random number of samples before constructing the batch, then return the
-   constructed batch and a binary label (1: augmented sample, 0: original sample)
-3. Predict the type of augmentation applied (predictive): Similar to point 2,
-   with a multiclass label.
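A minimal sketch of the LOSO recipe described in the faq.rst hunk above. The snake_case name get_eeg_split_table_kfold is taken from selfeeg/dataloading/load.py below; the keyword names kfold and val_split_mode, the "subject" mode spelling, and all numeric values are assumptions for illustration, not verbatim selfeeg API:

    import selfeeg.dataloading as dl

    # Partition table: path, sampling frequency, window length and overlap
    # are placeholder values, not selfeeg defaults.
    EEGlen = dl.get_eeg_partition_number("Simulated_EEG", 128, 2, 0.15)

    n_subjects = 10  # hypothetical number of subjects in the dataset
    EEGsplit = dl.get_eeg_split_table_kfold(
        EEGlen,
        kfold=n_subjects,          # one fold per subject -> leave-one-subject-out
        val_split_mode="subject",  # assumed spelling of the subject-level split mode
        test_ratio=0.0,            # no held-out test set in this sketch
        seed=42,
    )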
diff --git a/selfeeg/dataloading/load.py b/selfeeg/dataloading/load.py
index f067a64..e6fa554 100644
--- a/selfeeg/dataloading/load.py
+++ b/selfeeg/dataloading/load.py
@@ -597,7 +597,7 @@ def get_eeg_split_table(
     # single results will be then concatenated and sorted to preserve index pos
     if stratified:
         if (test_ratio == None) and (val_ratio == None):
-            print("STRATIFICATION can be applied only if at " "least one split ratio is given.")
+            print("STRATIFICATION can be applied only if at least one split ratio is given.")
         else:
             N_classes = np.unique(labels)
             classSplit = [None] * len(N_classes)
@@ -727,10 +727,8 @@ def get_eeg_split_table(
             # get split subarray
             arr = group1["N_samples"].values.tolist()
             target = test_ratio * alldatasum
-            final_idx, subarray = get_subarray_closest_sum(
-                arr, target, split_tolerance, perseverance
-            )
-            final_idx.sort()
+            final_idx = get_subarray_closest_sum(arr, target, split_tolerance, perseverance, False)

             # update split list according to returned subarray
             # and test split mode
@@ -807,10 +805,8 @@ def get_eeg_split_table(
                     target = val_ratio * alldatasum
                 else:
                     target = val_ratio * sum(arr)
-                final_idx, subarray = get_subarray_closest_sum(
-                    arr, target, split_tolerance, perseverance
-                )
-                final_idx.sort()
+                final_idx = get_subarray_closest_sum(arr, target, split_tolerance, perseverance, False)

                 if val_split_mode == 2:
                     fileName = group2.iloc[final_idx]["file_name"].values.tolist()
@@ -1076,9 +1072,6 @@ def get_eeg_split_table_kfold(
     if (test_ratio is None) and (test_data_id is None):
         test_ratio = 0.0

-    if seed is not None:
-        random.seed(seed)
-
     # FIRST STEP: Create test set or exclude data if necessary
     # the result of this function call will be an initialization of the split table
     # if no data need to be excluded or placed in a test set, the split_set column
@@ -1300,7 +1293,7 @@ def check_split(
     totval = EEGlen2.iloc[val_list]["N_samples"].sum()
     tottest = EEGlen2.iloc[test_list]["N_samples"].sum()
     class_ratio = np.full([3, Nlab], np.nan)
-
+    # iterate through train/validation/test sets
     which_to_iter = (n for n, i in enumerate([tottrain, totval, tottest]) if i)
     for i in which_to_iter:
diff --git a/selfeeg/ssl/generative.py b/selfeeg/ssl/generative.py
index 48a85a8..0147124 100644
--- a/selfeeg/ssl/generative.py
+++ b/selfeeg/ssl/generative.py
@@ -13,9 +13,7 @@
 from .base import EarlyStopping, SSLBase


-__all__ = [
-    "ReconstructiveSSL"
-]
+__all__ = ["ReconstructiveSSL"]


 class ReconstructiveSSL(SSLBase):
@@ -151,7 +149,7 @@ def fit(
            If an EarlyStopping instance is given with monitoring loss set to
            validation loss, but no validation dataloader is given, monitoring
            loss will be automatically set to training loss.
-
+
        validation_dataloader: Dataloader, optional
            the pytorch Dataloader used to get the validation batches. It
            must return a batch as a single tensor X, thus without label tensor Y.
@@ -257,10 +255,10 @@ def fit(
                    if X.device.type != device.type:
                        X = X.to(device=device)
-
+
                    Xaug = augmenter(X)
                    Xrec = self(Xaug)
-
+
                    val_loss = self.evaluate_loss(loss_func, [Xrec, X], loss_args)
                    val_loss_tot += val_loss.item()
                    if verbose:
@@ -343,7 +341,7 @@ def test(
                Xaug = augmenter(X)
                Xrec = self(Xaug)
-
+
                test_loss = self.evaluate_loss(loss_func, [Xrec, X], loss_args)
                test_loss_tot += test_loss
                # verbose print
@@ -353,4 +351,4 @@ def test(
            pbar.update()

    test_loss_tot /= batch_idx + 1
-    return test_loss_tot
\ No newline at end of file
+    return test_loss_tot
diff --git a/selfeeg/ssl/predictive.py b/selfeeg/ssl/predictive.py
index 7bb2799..88d93fe 100644
--- a/selfeeg/ssl/predictive.py
+++ b/selfeeg/ssl/predictive.py
@@ -13,9 +13,8 @@
 from .base import EarlyStopping, SSLBase


-__all__ = [
-    "PredictiveSSL"
-]
+__all__ = ["PredictiveSSL"]
+

 class PredictiveSSL(SSLBase):
     """
@@ -23,7 +22,7 @@ class PredictiveSSL(SSLBase):
     Contrary to contrastive, this pretraining performs a classification or
     regression task with a generated pseudo-label.
     A trivial example is the model trying to predict which random augmentation from
-    a given set was applied to each sample of the batch.
+    a given set was applied to each sample of the batch.

     Parameters
     ----------
@@ -79,10 +78,7 @@ class PredictiveSSL(SSLBase):
     """

     def __init__(
-        self,
-        encoder: nn.Module,
-        head: Union[list[int], nn.Module],
-        return_logits: bool = True
+        self, encoder: nn.Module, head: Union[list[int], nn.Module], return_logits: bool = True
     ):

         super(PredictiveSSL, self).__init__(encoder)
@@ -169,12 +165,12 @@ def fit(
            expected to take in input a batch tensor X and return both the augmented
            version of X and the pseudo-label tensor Y.
            It is highly suggested to resort to the selfeeg's augmentation module,
-           which implements different data augmentation functions and classes
+           which implements different data augmentation functions and classes
            to combine them. RandomAug, for example, can also return the index
            of the chosen augmentation to be used as a pseudo-label.

            Default = None
-
+
            Note
            ----
            This argument is optional because of the alternative way to provide
@@ -221,7 +217,7 @@ def fit(
            The number of times the augmenter is called for a single batch.
            Each call selects an equal portion of samples in the batch and
            gives it to the augmenter.
-
+
            Default = 2

            Note
@@ -232,17 +228,17 @@ def fit(
            compose submodules operate at the batch level, but one might want to
            generate batches with multiple labels and not one with only a single label.
            augmenter_batch_calls solves this problem.
-
+
        labels_on_dataloader: boolean, optional
            Set this to True if the dataloader already provides a set of pseudo-labels.
            If ``True`` augmenter and augmenter_batch_calls will be ignored.
-
+
            Note
            ----
            if you want to pretrain the model by simply solving another task
            and you need more functionalities, you can consider using the
            ``fine_tune`` function, which acts as a generic supervised training.
-
+
            Default = False
        verbose: bool, optional
            Whether to print a progression bar or not.
@@ -269,13 +265,13 @@ def fit(
        """

        # Various checks on input parameters.
-        if (augmenter is None) and not(labels_on_dataloader):
+        if (augmenter is None) and not (labels_on_dataloader):
             raise ValueError(
                 "at least an augmenter or a dataloader that can output pseudo-labels must be given"
             )
         if augmenter_batch_calls <= 0:
             raise ValueError("augmenter_batch_calls must be an integer greater than 0")
-
+
         # If some arguments weren't given they will be automatically set
         (device, epochs, optimizer, loss_func, perform_validation, loss_info, N_train, N_val) = (
             self._set_fit_args(
@@ -325,41 +321,38 @@ def fit(
                    X = X.to(device=device)
                else:
                    for i in range(len(X)):
-                        X[i] = X[i].to(device=device)
+                        X[i] = X[i].to(device=device)
                if isinstance(Y, torch.Tensor):
                    Y = Y.to(device=device)
                else:
                    for i in range(len(Y)):
                        Y[i] = Y[i].to(device=device)
-            #pseudo-label must be created, need for data augmentation
-            else:
+            # pseudo-labels must be created, so data augmentation is needed
+            else:
                X = X.to(device=device)
-                if augmenter_batch_calls==1:
+                if augmenter_batch_calls == 1:
                    X, Ytrue = augmenter(X)
                else:
                    permidx = torch.randperm(X.shape[0])
-                    piece = X.shape[0]//augmenter_batch_calls
+                    piece = X.shape[0] // augmenter_batch_calls
                    samples = permidx[:piece]
                    X[samples], Ytruei = augmenter(X[samples])
                    if isinstance(Ytruei, torch.Tensor):
                        Ytrue = torch.empty(
-                            X.shape[0], *Ytruei.shape[1:],
-                            dtype=Ytruei.dtype, device=device
+                            X.shape[0], *Ytruei.shape[1:], dtype=Ytruei.dtype, device=device
                        )
                    else:
-                        Ytrue = torch.empty(
-                            X.shape[0], device=device, dtype=type(Ytruei)
-                        )
+                        Ytrue = torch.empty(X.shape[0], device=device, dtype=type(Ytruei))
                    Ytrue[samples] = Ytruei
                    for i in range(1, augmenter_batch_calls):
-                        samples = permidx[piece*i:piece*(i+1)]
+                        samples = permidx[piece * i : piece * (i + 1)]
                        X[samples], Ytrue[samples] = augmenter(X[samples])
-                    samples = permidx[piece*(i+1):]
-                    X[samples], Ytrue[samples] = augmenter(X[samples])
+                    samples = permidx[piece * (i + 1) :]
+                    X[samples], Ytrue[samples] = augmenter(X[samples])

            Yhat = self(X)
            train_loss = self.evaluate_loss(loss_func, [Yhat, Ytrue], loss_args)
-
+
            train_loss.backward()
            optimizer.step()
            train_loss_tot += train_loss.item()
@@ -392,26 +385,28 @@ def fit(
                        X = X.to(device=device)
                    else:
                        for i in range(len(X)):
-                            X[i] = X[i].to(device=device)
+                            X[i] = X[i].to(device=device)
                    if isinstance(Y, torch.Tensor):
                        Y = Y.to(device=device)
                    else:
                        for i in range(len(Y)):
                            Y[i] = Y[i].to(device=device)
-                #pseudo-label must be created, need for data augmentation
-                else:
+                # pseudo-labels must be created, so data augmentation is needed
+                else:
                    X = X.to(device=device)
-                    if augmenter_batch_calls==1:
+                    if augmenter_batch_calls == 1:
                        X, Ytrue = augmenter(X)
                    else:
                        permidx = torch.randperm(X.shape[0])
-                        piece = X.shape[0]//augmenter_batch_calls
+                        piece = X.shape[0] // augmenter_batch_calls
                        samples = permidx[:piece]
                        X[samples], Ytruei = augmenter(X[samples])
                        if isinstance(Ytruei, torch.Tensor):
                            Ytrue = torch.empty(
-                                X.shape[0], *Ytruei.shape[1:],
-                                dtype=Ytruei.dtype, device=device
+                                X.shape[0],
+                                *Ytruei.shape[1:],
+                                dtype=Ytruei.dtype,
+                                device=device,
                            )
                        else:
                            Ytrue = torch.empty(
@@ -419,10 +414,10 @@ def fit(
                            )
                            Ytrue[samples] = Ytruei
                            for i in range(1, augmenter_batch_calls):
-                                samples = permidx[piece*i:piece*(i+1)]
+                                samples = permidx[piece * i : piece * (i + 1)]
                                X[samples], Ytrue[samples] = augmenter(X[samples])
-                            samples = permidx[piece*(i+1):]
-                            X[samples], Ytrue[samples] = augmenter(X[samples])
+                            samples = permidx[piece * (i + 1) :]
+                            X[samples], Ytrue[samples] = augmenter(X[samples])

                    Yhat = self(X)
                    val_loss = self.evaluate_loss(loss_func, [Yhat, Ytrue], loss_args)
@@ -450,9 +445,7 @@ def fit(
                EarlyStopper.rec_best_weights(self)
                updated_mdl = True
            if EarlyStopper():
-                print(
-                    f"no improvement after {EarlyStopper.patience} epochs. Training stopped"
-                )
+                print(f"no improvement after {EarlyStopper.patience} epochs. Training stopped")
                if EarlyStopper.record_best_weights and not (updated_mdl):
                    EarlyStopper.restore_best_weights(self)
                if return_loss_info:
@@ -468,11 +461,11 @@ def fit(
    def test(
        self,
        test_dataloader,
-        augmenter = None,
-        loss_func = None,
+        augmenter=None,
+        loss_func=None,
        loss_args: list or dict = [],
-        augmenter_batch_calls = 2,
-        labels_on_dataloader = False,
+        augmenter_batch_calls=2,
+        labels_on_dataloader=False,
        verbose: bool = True,
        device: str = None,
    ):
@@ -503,7 +496,7 @@ def test(
            file=sys.stdout,
        ) as pbar:
            for batch_idx, X in enumerate(test_dataloader):
-
+                # pseudo-label already in X, no need for data augmentations
                if labels_on_dataloader:
                    Ytrue = X[1]
@@ -512,37 +505,34 @@ def test(
                    X = X.to(device=device)
                else:
                    for i in range(len(X)):
-                        X[i] = X[i].to(device=device)
+                        X[i] = X[i].to(device=device)
                if isinstance(Y, torch.Tensor):
                    Y = Y.to(device=device)
                else:
                    for i in range(len(Y)):
                        Y[i] = Y[i].to(device=device)
-            #pseudo-label must be created, need for data augmentation
-            else:
+            # pseudo-labels must be created, so data augmentation is needed
+            else:
                X = X.to(device=device)
-                if augmenter_batch_calls==1:
+                if augmenter_batch_calls == 1:
                    X, Ytrue = augmenter(X)
                else:
                    permidx = torch.randperm(X.shape[0])
-                    piece = X.shape[0]//augmenter_batch_calls
+                    piece = X.shape[0] // augmenter_batch_calls
                    samples = permidx[:piece]
                    X[samples], Ytruei = augmenter(X[samples])
                    if isinstance(Ytruei, torch.Tensor):
                        Ytrue = torch.empty(
-                            X.shape[0], *Ytruei.shape[1:],
-                            dtype=Ytruei.dtype, device=device
+                            X.shape[0], *Ytruei.shape[1:], dtype=Ytruei.dtype, device=device
                        )
                    else:
-                        Ytrue = torch.empty(
-                            X.shape[0], device=device, dtype=type(Ytruei)
-                        )
+                        Ytrue = torch.empty(X.shape[0], device=device, dtype=type(Ytruei))
                    Ytrue[samples] = Ytruei
                    for i in range(1, augmenter_batch_calls):
-                        samples = permidx[piece*i:piece*(i+1)]
+                        samples = permidx[piece * i : piece * (i + 1)]
                        X[samples], Ytrue[samples] = augmenter(X[samples])
-                    samples = permidx[piece*(i+1):]
-                    X[samples], Ytrue[samples] = augmenter(X[samples])
+                    samples = permidx[piece * (i + 1) :]
+                    X[samples], Ytrue[samples] = augmenter(X[samples])

                Yhat = self(X)
                test_loss = self.evaluate_loss(loss_func, [Yhat, Ytrue], loss_args)
diff --git a/selfeeg/utils/utils.py b/selfeeg/utils/utils.py
index 619b41f..c8958a3 100644
--- a/selfeeg/utils/utils.py
+++ b/selfeeg/utils/utils.py
@@ -24,11 +24,11 @@
 ]


-def subarray_closest_sum(arr: list, n: int, k: float) -> list:
+def subarray_closest_sum(arr: ArrayLike, n: int, k: float) -> tuple[ArrayLike, int, int, float]:
     """
     returns a subarray whose element sum is closest to k.

-    This function is taken from geeksforgeeks at the following link [link1]_
+    This function is inspired by [link1]_

     It is important to note that this function returns a subarray and not a
     subset of the array. A subset is a collection of elements in the array taken
@@ -38,7 +38,7 @@

     Parameters
     ----------
-    arr: list
+    arr: ArrayLike
         The array to search.
     n: int
         The length of the array.
@@ -47,61 +47,57 @@

     Returns
     -------
-    best_arr: list
+    best_arr: ArrayLike
        The subarray whose element sum is closest to k.
+    best_start: int
+        The starting index of the subarray.
+    best_end: int
+        The ending index of the subarray.
+    min_diff: float
+        Absolute difference between the target value and the sum of the
+        subarray's values.

     References
     ----------
     .. [link1] https://www.geeksforgeeks.org/subarray-whose-sum-is-closest-to-k/

     """
-
-    # Initialize start and end pointers, current sum, and minimum difference
-    best_arr = []
+    # Initialize start and end pointers, current sum, minimum difference
+    # and best start and end pointers
     start = 0
     end = 0
+    best_start = 0
+    best_end = 0
     curr_sum = arr[0]
-    min_diff = float("inf")

+    # Initialize the minimum difference between the subarray sum and K
     min_diff = abs(curr_sum - k)

+    # Traverse through the array
     while end < n - 1:
+
-        # If the current sum is less than K, move the end pointer to the right
+        # If the current sum is less than k, move the end pointer to the right
         if curr_sum < k:
             end += 1
             curr_sum += arr[end]
-        # If the current sum is greater than or equal to K,
-        # move the start pointer to the right
+        # Otherwise, move the start pointer to the right
         else:
             curr_sum -= arr[start]
             start += 1

-        # Update the minimum difference between the subarray sum and K
+        # Update the minimum difference and store best subarray pointers
         if abs(curr_sum - k) < min_diff:
             min_diff = abs(curr_sum - k)
+            best_start = start
+            best_end = end

+        # if minimum difference is zero, return the optimal subarray
+        if min_diff == 0:
+            return arr[best_start : best_end + 1], best_start, best_end, min_diff

-    # Print the subarray with the sum closest to K
-    start = 0
-    end = 0
-    curr_sum = arr[0]
-
-    while end < n - 1:
-        if curr_sum < k:
-            end += 1
-            curr_sum += arr[end]
-        else:
-            curr_sum -= arr[start]
-            start += 1
-        # Print the subarray with the sum closest to K
-        if abs(curr_sum - k) == min_diff:
-            for i in range(start, end + 1):
-                best_arr.append(arr[i])
-            break
-    return best_arr
+    return arr[best_start : best_end + 1], best_start, best_end, min_diff


 def get_subarray_closest_sum(
-    arr: Sequence[int],
+    arr: ArrayLike,
     target: float,
     tolerance: float = 0.01,
     perseverance: int = 1000,
@@ -122,7 +118,7 @@

     Parameters
     ----------
-    arr: list
+    arr: ArrayLike
         The array to search.
     target: float
         The target sum.
@@ -145,7 +141,7 @@
     final_idx: list
         A list with the index of the identified subarray.
     best_sub_arr: list, optional
-        The identified subarray.
+        The identified subarray, i.e., the elements of arr at the indices in final_idx.
     Example
     -------

@@ -161,36 +157,45 @@ def get_subarray_closest_sum(
     if tolerance < 0 or tolerance > 1:
         raise ValueError("tolerance must be in [0,1]")
-    if not (isinstance(perseverance, int)):
+    else:
+        upper_bound = target * tolerance
+    if not isinstance(perseverance, int):
         perseverance = int(perseverance)

-    # np.argsort
+    arr_original = arr
     idx = range(len(arr))
     N = len(arr)
-    best_sub_arr = []
+    subarr_diff = 0
+    best_idx = []
+    best_start = 0
+    best_end = 0
+    best_subarr_diff = float("inf")
+    starti = 0
+    endi = 0
+
     for _ in range(perseverance):
-
         c = list(zip(arr, idx))
         random.shuffle(c)
         arr, idx = zip(*c)
+        _, starti, endi, subarr_diff = subarray_closest_sum(arr, N, target)
+        if subarr_diff < best_subarr_diff:
+            best_subarr_diff = subarr_diff
+            best_idx = idx
+            best_start = starti
+            best_end = endi
+        if best_subarr_diff < upper_bound:
+            break

-        sub_arr = subarray_closest_sum(arr, N, target)
-        # print(sub_arr, abs(sum(sub_arr)- target), abs(sum(best_sub_arr)-target))
-        if (abs(sum(sub_arr) - target)) < (abs(sum(best_sub_arr) - target)):
-            best_sub_arr = sub_arr
-        if (target * (1 - tolerance)) < sum(sub_arr) < (target * (1 + tolerance)):
-            best_sub_arr = sub_arr
-            break

     # get final list
-    best_sub2 = copy.deepcopy(best_sub_arr)
-    final_idx = []
-    for i in range(len(arr)):
-        if arr[i] in best_sub2:
-            final_idx.append(idx[i])
-            best_sub2.remove(arr[i])
+    final_idx = list(best_idx[best_start : best_end + 1])
+    final_idx.sort()

     if return_subarray:
-        return final_idx, best_sub_arr
+        best_subarr = list(map(arr_original.__getitem__, final_idx))
+        return final_idx, best_subarr
     else:
         return final_idx
diff --git a/test/EEGself/augmentation/functional_test.py b/test/EEGself/augmentation/functional_test.py
index 688be08..e4d3f91 100644
--- a/test/EEGself/augmentation/functional_test.py
+++ b/test/EEGself/augmentation/functional_test.py
@@ -30,9 +30,10 @@ def setUpClass(cls):
         else:
             cls.device = torch.device("cpu")

-        if cls.device.type == "mps":
+        if cls.device.type != "cpu":
             try:
-                xx = torch.randn(2, 2).to(device=cls.device)
+                xx = torch.randn(1024).to(device=cls.device)
+                xx = aug.add_band_noise(xx, "theta", 128)
             except Exception:
                 cls.device = torch.device("cpu")
diff --git a/test/EEGself/dataloading/load_test.py b/test/EEGself/dataloading/load_test.py
index 1cffa64..4b944a1 100644
--- a/test/EEGself/dataloading/load_test.py
+++ b/test/EEGself/dataloading/load_test.py
@@ -287,13 +287,12 @@ def test_get_eeg_split_table(self):
             "subject_id_extractor": [None],
             "save": [True],
             "split_tolerance": [0.001],
-            "perseverance": [10000],
+            "perseverance": [1000],
             "save_path": ["tmpsave/results1.csv"],
             "seed": [self.seed],
         }
         input_grid = self.makeGrid(input_grid)
         for n, i in enumerate(input_grid):
-
             if n == 500:
                 print("proceding...", end="", flush=True)
             elif n == 1000:
@@ -327,7 +326,7 @@ def test_get_eeg_split_table(self):
             else:
                 if i["test_data_id"] is None:
                     ratio = abs(0.2 - EEGlen["N_samples"][EEGsplit["split_set"] == 2].sum() / tot)
-                    self.assertTrue(ratio < 1e-2)
+                    self.assertTrue(ratio < 2.5e-2)
                 if not (i["stratified"]) and i["test_split_mode"] == 0:
                     EEGsplit["dataid"] = EEGsplit["file_name"].str[0]
                     group = EEGsplit.groupby(["dataid", "split_set"])
@@ -355,7 +354,7 @@ def test_get_eeg_split_table(self):
                     ratio = abs(
                         thresh - EEGlen["N_samples"][EEGsplit["split_set"] == 1].sum() / tot
                     )
-                    self.assertTrue(ratio < 1e-2)
+                    self.assertTrue(ratio < 2.5e-2)
                 EEGsplit = dl.get_eeg_split_table(
                    EEGlen,
@@ -406,8 +405,8 @@ def test_get_eeg_split_table_kfold(self):
             "stratified": [False, True],
             "labels": [Labels],
             "save": [True],
-            "split_tolerance": [0.005],
-            "perseverance": [5000],
+            "split_tolerance": [0.01],
+            "perseverance": [1000],
             "save_path": ["tmpsave/results1.csv"],
         }
         input_grid = self.makeGrid(input_grid)
@@ -488,7 +487,7 @@ def test_EEGDataset(self):
         print("   EEGDataset OK")

     def test_EEGSampler(self):
-        print("Testing Sampler on both mode...", end="", flush=True)
+        print("Testing Sampler in both modes...", end="", flush=True)
         EEGlen = dl.get_eeg_partition_number(
             self.eegpath,
             self.freq,
diff --git a/test/EEGself/losses/losses_test.py b/test/EEGself/losses/losses_test.py
index 035076d..dcfd19d 100644
--- a/test/EEGself/losses/losses_test.py
+++ b/test/EEGself/losses/losses_test.py
@@ -26,9 +26,11 @@ def setUpClass(cls):
         else:
             cls.device = torch.device("cpu")

-        if cls.device.type == "mps":
+        if cls.device.type != "cpu":
             try:
-                xx = torch.randn(2, 2).to(device=cls.device)
+                xx = torch.randn(64, 128).to(device=cls.device)
+                yy = torch.randn(64, 128).to(device=cls.device)
+                xx = losses.barlow_loss(xx, yy)
             except Exception:
                 cls.device = torch.device("cpu")
diff --git a/test/EEGself/models/layers_test.py b/test/EEGself/models/layers_test.py
index a12641b..ebbd3b7 100644
--- a/test/EEGself/models/layers_test.py
+++ b/test/EEGself/models/layers_test.py
@@ -26,9 +26,11 @@ def setUpClass(cls):
         else:
             cls.device = torch.device("cpu")

-        if cls.device.type == "mps":
+        if cls.device.type != "cpu":
             try:
-                xx = torch.randn(2, 2).to(device=cls.device)
+                xx = torch.randn(2, 8, 2048).to(device=cls.device)
+                lay = models.ConstrainedConv1d(8, 4, 16).to(device=cls.device)
+                xx = lay(xx)
             except Exception:
                 cls.device = torch.device("cpu")
diff --git a/test/EEGself/models/zoo_test.py b/test/EEGself/models/zoo_test.py
index 7bcb655..f7553d7 100644
--- a/test/EEGself/models/zoo_test.py
+++ b/test/EEGself/models/zoo_test.py
@@ -26,9 +26,11 @@ def setUpClass(cls):
         else:
             cls.device = torch.device("cpu")

-        if cls.device.type == "mps":
+        if cls.device.type != "cpu":
             try:
-                xx = torch.randn(2, 2).to(device=cls.device)
+                xx = torch.randn(2, 8, 2048).to(device=cls.device)
+                model = models.EEGNet(2, 8, 2048).to(device=cls.device)
+                xx = model(xx)
             except Exception:
                 cls.device = torch.device("cpu")
diff --git a/test/EEGself/ssl/ssl_test.py b/test/EEGself/ssl/ssl_test.py
index ecdd389..82fc6a2 100644
--- a/test/EEGself/ssl/ssl_test.py
+++ b/test/EEGself/ssl/ssl_test.py
@@ -62,9 +62,10 @@ def setUpClass(cls):
         else:
             cls.device = torch.device("cpu")

-        if cls.device.type == "mps":
+        if cls.device.type != "cpu":
             try:
-                xx = torch.randn(2, 2).to(device=cls.device)
+                xx = torch.randn(1024).to(device=cls.device)
+                xx = aug.add_band_noise(xx, "theta", 128)
             except Exception:
                 cls.device = torch.device("cpu")
diff --git a/test/EEGself/utils/utils_test.py b/test/EEGself/utils/utils_test.py
index a97ed42..5d784d2 100644
--- a/test/EEGself/utils/utils_test.py
+++ b/test/EEGself/utils/utils_test.py
@@ -21,9 +21,12 @@ def setUpClass(cls):
         else:
             cls.device = torch.device("cpu")

-        if cls.device.type == "mps":
+        if cls.device.type != "cpu":
             try:
-                xx = torch.randn(2, 2).to(device=cls.device)
+                xx = torch.zeros(16, 32, 1024)
+                xx = xx + torch.sin(torch.linspace(0, 8 * torch.pi, 1024)) * 500
+                xx = xx.to(device=cls.device)
+                xx = utils.scale_range_soft_clip(xx, "scale", "uV")
             except Exception:
                 cls.device = torch.device("cpu")
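To illustrate the refactored contract of get_subarray_closest_sum from the utils.py hunks above (indices now come back sorted, and the subarray itself is rebuilt from them only on request), a short usage sketch; the values are illustrative and the import path is an assumption based on the repository layout:

    from selfeeg.utils import get_subarray_closest_sum

    arr = [25, 40, 30, 10, 50, 45]

    # As called in load.py above: only the sorted index list is returned.
    final_idx = get_subarray_closest_sum(
        arr, target=100, tolerance=0.01, perseverance=1000, return_subarray=False
    )
    print(final_idx)  # e.g. [1, 3, 4]: sorted indices into arr

    # With return_subarray=True the matching elements are returned as well.
    final_idx, best_subarr = get_subarray_closest_sum(arr, 100, return_subarray=True)
    print(best_subarr, sum(best_subarr))  # element sum should land close to 100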