Skip to content

Commit

Permalink
Merge pull request #2919 from paulrignanese/main
Browse files Browse the repository at this point in the history
Replace toy_example by generate_ground_truth_recording in sorters folder
  • Loading branch information
samuelgarcia authored May 30, 2024
2 parents af3c8b3 + 5f9670f commit 46734eb
Show file tree
Hide file tree
Showing 10 changed files with 17 additions and 23 deletions.
7 changes: 0 additions & 7 deletions src/spikeinterface/preprocessing/tests/test_phase_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@

import scipy.fft

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "preprocessing"
else:
cache_folder = Path("cache_folder") / "preprocessing"

set_global_tmp_folder(cache_folder)


def create_shifted_channel():
duration = 5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from pathlib import Path

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
Expand All @@ -23,7 +24,7 @@ def check_gh_settings():


def generate_run_kwargs():
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording, _ = generate_ground_truth_recording(durations=[30], seed=0, num_channels=64)
test_recording = test_recording.save(name="toy")
test_recording.set_channel_gains(1)
test_recording.set_channel_offsets(0)
Expand Down
5 changes: 2 additions & 3 deletions src/spikeinterface/sorters/external/tests/test_kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import pytest
from pathlib import Path

from spikeinterface import load_extractor
from spikeinterface.extractors import toy_example
from spikeinterface import load_extractor, generate_ground_truth_recording
from spikeinterface.sorters import Kilosort4Sorter, run_sorter
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

Expand All @@ -23,7 +22,7 @@ def setUp(self):
if (cache_folder / "rec").is_dir():
recording = load_extractor(cache_folder / "rec")
else:
recording, _ = toy_example(num_channels=32, duration=60, seed=0, num_segments=1)
recording, _ = generate_ground_truth_recording(num_channels=32, durations=[60], seed=0)
recording = recording.save(folder=cache_folder / "rec", verbose=False, format="binary")
self.recording = recording
print(self.recording)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest

import pytest
from spikeinterface.extractors import toy_example
from spikeinterface.sorters import PyKilosortSorter
from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pathlib import Path

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
Expand All @@ -29,7 +30,7 @@ def check_gh_settings():


def generate_run_kwargs():
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording, _ = generate_ground_truth_recording(durations=[30], seed=0, num_channels=64)
test_recording = test_recording.save(name="toy")
test_recording.set_channel_gains(1)
test_recording.set_channel_offsets(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import pytest

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se

import spikeinterface.sorters as ss

os.environ["SINGULARITY_DISABLE_CACHE"] = "true"
Expand All @@ -23,7 +24,7 @@ def check_gh_settings():


def generate_run_kwargs():
test_recording, _ = se.toy_example(duration=30, seed=0, num_channels=64, num_segments=1)
test_recording, _ = generate_ground_truth_recording(durations=[30], seed=0, num_channels=64)
test_recording = test_recording.save(name="toy")
test_recording.set_channel_gains(1)
test_recording.set_channel_offsets(0)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/external/yass.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class YassSorter(BaseSorter):
1. Retraining Neural Networks (Default)
rec, sort = se.toy_example(duration=300)
rec, sort = generate_ground_truth_recording(durations=[300])
sorting_yass = ss.run_yass(rec, '/home/cat/Downloads/test2')
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/tests/common_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import shutil

from spikeinterface.extractors import toy_example
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter
from spikeinterface.core.snippets_tools import snippets_from_sorting

Expand All @@ -24,7 +24,7 @@ class SorterCommonTestSuite:
SorterClass = None

def setUp(self):
recording, sorting_gt = toy_example(num_channels=4, duration=60, seed=0, num_segments=1)
recording, sorting_gt = generate_ground_truth_recording(num_channels=4, durations=[60], seed=0)
rec_folder = cache_folder / "rec"
if rec_folder.is_dir():
shutil.rmtree(rec_folder)
Expand Down Expand Up @@ -80,7 +80,7 @@ class SnippetsSorterCommonTestSuite:
SorterClass = None

def setUp(self):
recording, sorting_gt = toy_example(num_channels=4, duration=60, seed=0, num_segments=1)
recording, sorting_gt = generate_ground_truth_recording(num_channels=4, durations=[60], seed=0)
snippets_folder = cache_folder / "snippets"
if snippets_folder.is_dir():
shutil.rmtree(snippets_folder)
Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/sorters/tests/test_container_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import os

import spikeinterface as si
from spikeinterface.extractors import toy_example
from spikeinterface import generate_ground_truth_recording

from spikeinterface.sorters.container_tools import find_recording_folders, ContainerClient, install_package_in_container

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
Expand All @@ -21,10 +22,10 @@ def setup_module():
for test_dir in test_dirs:
if test_dir.exists():
shutil.rmtree(test_dir)
rec1, _ = toy_example(num_segments=1)
rec1, _ = generate_ground_truth_recording(durations=[10])
rec1 = rec1.save(folder=cache_folder / "mono")

rec2, _ = toy_example(num_segments=3)
rec2, _ = generate_ground_truth_recording(durations=[10, 10, 10])
rec2 = rec2.save(folder=cache_folder / "multi")


Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/sorters/tests/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from spikeinterface.core import load_extractor

# from spikeinterface.extractors import toy_example
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter_jobs, run_sorter_by_property

Expand Down

0 comments on commit 46734eb

Please sign in to comment.