
Commit c421b81
Merge pull request #14 from fedepup/main
faster subarray search
fedepup authored Oct 19, 2024
2 parents 024402b + 4f32937 commit c421b81
Showing 12 changed files with 150 additions and 188 deletions.
38 changes: 2 additions & 36 deletions docs/faq.rst
@@ -53,15 +53,14 @@ from the file name really refer to the specific dataset or subject.
You can give anything you want as long as the previous reasoning about the
identification of unique pairs is satisfied.

-3) **Can I implement a Leave One Subject Out cross-validation?**
+3) **Can I implement a Leave-One-Subject-Out (LOSO) cross-validation?**

Of course. You just need to call the ``GetEEGSplitTableKfold`` function,
setting the validation split to subject mode and setting the number of folds
equal to the number of subjects. Remember to add a subject_id extractor if
needed and, if you have enough data to create a separate test set, to also set
the test split mode to subject and adjust the number of folds according to the
number of subjects minus the ones put in the test set, as in the sketch below.
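
A minimal sketch of a LOSO split (using the snake_case name
``get_eeg_split_table_kfold`` defined in ``selfeeg/dataloading/load.py``; the
exact argument names, e.g. ``kfold`` and ``val_split_mode``, are assumptions to
verify against the function signature)::

    import selfeeg.dataloading as dl

    # EEGlen: the per-file partition table computed beforehand
    # (e.g., with dl.get_eeg_partition_number; assumed available here)
    n_subjects = 20  # hypothetical number of subjects in the dataset

    split_table = dl.get_eeg_split_table_kfold(
        EEGlen,
        kfold=n_subjects,          # one fold per subject -> LOSO
        test_ratio=0.0,            # no separate test set in this sketch
        val_split_mode="subject",  # validation split at the subject level
    )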


Augmentation module
@@ -93,36 +92,3 @@ check how augmentations specifically perform on your device.
:header-rows: 1
:widths: 15, 14, 14, 14, 14, 14, 14
:class: longtable



SSL module
----------

1) **The SSL module implements only Contrastive Learning algorithms. Is it
possible to use selfEEG to create predictive or generative pretraining tasks?**

It depends on the type of pretraining task you want to define.
However, by defining the right dataloader and loss and passing them to the
fine-tuning function of the ssl module, it is possible to construct simple
predictive or generative pretraining tasks. For example, a simple strategy can be:

1. Define an EEGDataset class without label extraction.
2. Define a custom Augmenter.
3. Define a custom ``collate_fn`` function to give to the PyTorch Dataloader.
4. Define the loss function and other training parameters.
5. Run the fine-tuning function with the Dataloader, loss, and other parameters.

The important step here is the definition of the ``collate_fn`` function (see
`here <https://discuss.pytorch.org/t/how-to-create-a-dataloader-with-variable-size-input/8278?u=ptrblck>`_
on how to create custom collate_fn functions), which is used to create the
pretraining target. For example:

1. Reconstructive pretraining (generative): create an augmented batch with the
augmenter, then return the augmented batch as the input and the original batch
as the target (see the sketch after this list).
2. Predict if the sample was augmented (predictive): apply an augmentation
to a random number of samples before constructing the batch, then return the
constructed batch and a binary label (1: augmented sample, 0: original sample).
3. Predict the type of augmentation applied (predictive): similar to point 2,
with a multiclass label.
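
As a concrete illustration, here is a minimal sketch of such ``collate_fn``
functions for cases 1 and 2. The ``augmenter`` callable and the batch layout
(a list of ``[channels, samples]`` tensors coming from an EEGDataset without
label extraction) are assumptions made for the example::

    import torch

    def reconstructive_collate_fn(batch):
        # case 1: input = augmented data, target = original data
        X = torch.stack(batch, dim=0)
        Xaug = augmenter(X)  # augmenter defined elsewhere
        return Xaug, X

    def augmented_or_not_collate_fn(batch):
        # case 2: augment a random subset of the samples, label them 1
        X = torch.stack(batch, dim=0)
        y = (torch.rand(X.shape[0]) < 0.5).long()
        if y.any():
            X[y == 1] = augmenter(X[y == 1])
        return X, y

    # dataset: the EEGDataset instance defined at step 1
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=64, collate_fn=reconstructive_collate_fn
    )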
19 changes: 6 additions & 13 deletions selfeeg/dataloading/load.py
@@ -597,7 +597,7 @@ def get_eeg_split_table(
    # single results will be then concatenated and sorted to preserve index pos
    if stratified:
        if (test_ratio == None) and (val_ratio == None):
-           print("STRATIFICATION can be applied only if at " "least one split ratio is given.")
+           print("STRATIFICATION can be applied only if at least one split ratio is given.")
        else:
            N_classes = np.unique(labels)
            classSplit = [None] * len(N_classes)
@@ -727,10 +727,8 @@ def get_eeg_split_table(
            # get split subarray
            arr = group1["N_samples"].values.tolist()
            target = test_ratio * alldatasum
-           final_idx, subarray = get_subarray_closest_sum(
-               arr, target, split_tolerance, perseverance
-           )
-           final_idx.sort()
+           final_idx = get_subarray_closest_sum(arr, target, split_tolerance, perseverance, False)
+           # final_idx.sort()

            # update split list according to returned subarray
            # and test split mode
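
This hunk is the heart of the "faster subarray search" change: the helper now
returns only the selected indices, and the trailing ``False`` presumably skips
building and sorting the subarray itself. As a rough, hypothetical sketch of
what a closest-sum subarray search computes (not selfEEG's actual
implementation)::

    import random

    def subset_closest_sum(arr, target, tolerance=0.01, perseverance=1000):
        # Randomized search for indices of arr whose values sum
        # as close as possible to target.
        best_idx = list(range(len(arr)))
        best_gap = abs(sum(arr) - target)
        idx = list(range(len(arr)))
        for _ in range(perseverance):
            random.shuffle(idx)
            picked, total = [], 0
            for i in idx:
                if total + arr[i] <= target:
                    picked.append(i)
                    total += arr[i]
            gap = abs(total - target)
            if gap < best_gap:
                best_idx, best_gap = picked, gap
            if best_gap <= tolerance * target:
                break
        return best_idx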
@@ -807,10 +805,8 @@ def get_eeg_split_table(
                target = val_ratio * alldatasum
            else:
                target = val_ratio * sum(arr)
-           final_idx, subarray = get_subarray_closest_sum(
-               arr, target, split_tolerance, perseverance
-           )
-           final_idx.sort()
+           final_idx = get_subarray_closest_sum(arr, target, split_tolerance, perseverance, False)
+           # final_idx.sort()

            if val_split_mode == 2:
                fileName = group2.iloc[final_idx]["file_name"].values.tolist()
@@ -1076,9 +1072,6 @@ def get_eeg_split_table_kfold(
    if (test_ratio is None) and (test_data_id is None):
        test_ratio = 0.0

-   if seed is not None:
-       random.seed(seed)
-
    # FIRST STEP: Create test set or exclude data if necessary
    # the result of this function call will be an initialization of the split table
    # if no data need to be excluded or placed in a test set, the split_set column
@@ -1300,7 +1293,7 @@ def check_split(
    totval = EEGlen2.iloc[val_list]["N_samples"].sum()
    tottest = EEGlen2.iloc[test_list]["N_samples"].sum()
    class_ratio = np.full([3, Nlab], np.nan)
-
+   # iterate through train/validation/test sets
    which_to_iter = (n for n, i in enumerate([tottrain, totval, tottest]) if i)
    for i in which_to_iter:
14 changes: 6 additions & 8 deletions selfeeg/ssl/generative.py
@@ -13,9 +13,7 @@

from .base import EarlyStopping, SSLBase

-__all__ = [
-    "ReconstructiveSSL"
-]
+__all__ = ["ReconstructiveSSL"]


class ReconstructiveSSL(SSLBase):
@@ -151,7 +149,7 @@ def fit(
            If an EarlyStopping instance is given with monitoring loss set to
            validation loss, but no validation dataloader is given, monitoring
            loss will be automatically set to training loss.
        validation_dataloader: Dataloader, optional
            the pytorch Dataloader used to get the validation batches. It
            must return a batch as a single tensor X, thus without label tensor Y.
@@ -257,10 +255,10 @@ def fit(

            if X.device.type != device.type:
                X = X.to(device=device)

            Xaug = augmenter(X)
            Xrec = self(Xaug)

            val_loss = self.evaluate_loss(loss_func, [Xrec, X], loss_args)
            val_loss_tot += val_loss.item()
            if verbose:
@@ -343,7 +341,7 @@ def test(

            Xaug = augmenter(X)
            Xrec = self(Xaug)

            test_loss = self.evaluate_loss(loss_func, [Xrec, X], loss_args)
            test_loss_tot += test_loss
            # verbose print
@@ -353,4 +351,4 @@ def test(
                pbar.update()

        test_loss_tot /= batch_idx + 1
        return test_loss_tot
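
For context, a hypothetical end-to-end use of this class could look as follows;
the constructor arguments and the ``train_dataloader``/``epochs`` parameter
names are assumptions to check against the class documentation, while
``augmenter``, ``loss_func``, and ``validation_dataloader`` appear in the diff
above::

    import torch
    from selfeeg.ssl import ReconstructiveSSL

    # encoder and decoder: any torch modules mapping EEG -> latent -> EEG
    model = ReconstructiveSSL(encoder, decoder)

    model.fit(
        train_dataloader=train_loader,   # unlabeled EEG batches (tensor X only)
        epochs=10,
        augmenter=augmenter,             # corrupts X; the model reconstructs it
        loss_func=torch.nn.functional.mse_loss,
        validation_dataloader=val_loader,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )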