Skip to content

Commit

Permalink
Merge branch 'main' into widgets
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia authored Oct 16, 2023
2 parents 86d0739 + 038d1d3 commit 07736fd
Show file tree
Hide file tree
Showing 23 changed files with 971 additions and 152 deletions.
5 changes: 5 additions & 0 deletions .github/actions/build-test-environment/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ runs:
- name: git-annex install
run: |
wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
mkdir /home/runner/work/installation
mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
workdir=$(pwd)
cd /home/runner/work/installation
tar xvzf git-annex-standalone-amd64.tar.gz
echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
cd $workdir
shell: bash
1 change: 1 addition & 0 deletions .github/workflows/test_containers_singularity_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ jobs:
- name: Run test singularity containers with GPU
env:
REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }}
run: |
pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
Expand Down
8 changes: 5 additions & 3 deletions examples/modules_gallery/core/plot_4_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@

###############################################################################
# A :py:class:`~spikeinterface.core.WaveformExtractor` object can be created with the
# :py:func:`~spikeinterface.core.extract_waveforms` function:
# :py:func:`~spikeinterface.core.extract_waveforms` function (this defaults to a sparse
# representation of the waveforms):

folder = 'waveform_folder'
we = extract_waveforms(
Expand Down Expand Up @@ -87,6 +88,7 @@
recording,
sorting,
folder,
sparse=False,
ms_before=3.,
ms_after=4.,
max_spikes_per_unit=500,
Expand Down Expand Up @@ -149,7 +151,7 @@
#
# Option 1) Save a dense waveform extractor to sparse:
#
# In this case, from an existing waveform extractor, we can first estimate a
# In this case, from an existing (dense) waveform extractor, we can first estimate a
# sparsity (which channels each unit is defined on) and then save to a new
# folder in sparse mode:

Expand All @@ -173,7 +175,7 @@


###############################################################################
# Option 2) Directly extract sparse waveforms:
# Option 2) Directly extract sparse waveforms (current spikeinterface default):
#
# We can also directly extract sparse waveforms. To do so, dense waveforms are
# extracted first using a small number of spikes (:code:`'num_spikes_for_sparsity'`)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# because it contains a reference to the "Recording" and the "Sorting" objects:

folder = 'waveforms_mearec'
we = si.extract_waveforms(recording, sorting, folder,
we = si.extract_waveforms(recording, sorting, folder, sparse=False,
ms_before=1, ms_after=2., max_spikes_per_unit=500,
n_jobs=1, chunk_durations='1s')
print(we)
Expand Down
23 changes: 14 additions & 9 deletions examples/modules_gallery/qualitymetrics/plot_4_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
quality metrics.
"""
#############################################################################
# Import the modules and/or functions necessary from spikeinterface

import spikeinterface as si
import spikeinterface.extractors as se
Expand All @@ -15,22 +17,21 @@


##############################################################################
# First, let's download a simulated dataset
# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'
# Let's download a simulated dataset
# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data'
#
# Let's imagine that the ground-truth sorting is in fact the output of a sorter.
#

local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5')
recording, sorting = se.read_mearec(local_path)
recording, sorting = se.read_mearec(file_path=local_path)
print(recording)
print(sorting)

##############################################################################
# First, we extract waveforms and compute their PC scores:
# First, we extract waveforms (to be saved in the folder 'wfs_mearec') and
# compute their PC scores:

folder = 'wfs_mearec'
we = si.extract_waveforms(recording, sorting, folder,
we = si.extract_waveforms(recording, sorting, folder='wfs_mearec',
ms_before=1, ms_after=2., max_spikes_per_unit=500,
n_jobs=1, chunk_size=30000)
print(we)
Expand All @@ -47,12 +48,15 @@
##############################################################################
# We can now threshold each quality metric and select units based on some rules.
#
# The easiest and most intuitive way is to use boolean masking with dataframe:
# The easiest and most intuitive way is to use boolean masking with a dataframe.
#
# Then create a list of unit ids that we want to keep

keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.2) & (metrics['nn_hit_rate'] > 0.90)
print(keep_mask)

keep_unit_ids = keep_mask[keep_mask].index.values
keep_unit_ids = [unit_id for unit_id in keep_unit_ids]
print(keep_unit_ids)

##############################################################################
Expand All @@ -61,4 +65,5 @@

curated_sorting = sorting.select_units(keep_unit_ids)
print(curated_sorting)
se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.pnz')

se.NpzSortingExtractor.write_sorting(sorting=curated_sorting, save_path='curated_sorting.npz')
7 changes: 5 additions & 2 deletions src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
sorting_exists = sorting_folder.exists()

sorter_folder = self.folder / "sorters" / self.key_to_str(key)
sorter_folder_exists = sorting_folder.exists()
sorter_folder_exists = sorter_folder.exists()

if keep:
if sorting_exists:
Expand All @@ -185,6 +185,9 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
if log_file.exists():
log_file.unlink()

if sorter_folder_exists:
shutil.rmtree(sorter_folder)

params = self.cases[key]["run_sorter_params"].copy()
# this ensure that sorter_name is given
recording, _ = self.datasets[self.cases[key]["dataset"]]
Expand Down Expand Up @@ -283,7 +286,7 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs):
# the waveforms depend on the dataset key
wf_folder = base_folder / self.key_to_str(dataset_key)
recording, gt_sorting = self.datasets[dataset_key]
we = extract_waveforms(recording, gt_sorting, folder=wf_folder)
we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs)

def get_waveform_extractor(self, key):
# some recording are not dumpable to json and the waveforms extactor need it!
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ def __init__(
"num_channels": num_channels,
"durations": durations,
"sampling_frequency": sampling_frequency,
"noise_level": noise_level,
"dtype": dtype,
"seed": seed,
"strategy": strategy,
Expand Down Expand Up @@ -876,13 +877,13 @@ def generate_single_fake_waveform(


default_unit_params_range = dict(
alpha=(5_000.0, 15_000.0),
alpha=(6_000.0, 9_000.0),
depolarization_ms=(0.09, 0.14),
repolarization_ms=(0.5, 0.8),
recovery_ms=(1.0, 1.5),
positive_amplitude=(0.05, 0.15),
smooth_ms=(0.03, 0.07),
decay_power=(1.2, 1.8),
decay_power=(1.4, 1.8),
)


Expand Down
27 changes: 23 additions & 4 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,14 +833,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals
sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids)
else:
sparsity = None
we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity)
we.set_params(**self._params)
if self.has_recording():
we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity)
else:
we = WaveformExtractor(
recording=None,
sorting=sorting,
folder=None,
sparsity=sparsity,
rec_attributes=self._rec_attributes,
allow_unfiltered=True,
)
we._params = self._params
# copy memory objects
if self.has_waveforms():
we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}}
for unit_id in unit_ids:
we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id]
we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id]
if self.format == "memory":
we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id]
we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][
unit_id
]
else:
we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id)
we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id)

# finally select extensions data
for ext_name in self.get_available_extension_names():
Expand Down Expand Up @@ -2016,6 +2032,9 @@ def set_params(self, **params):
params = self._set_params(**params)
self._params = params

if self.waveform_extractor.is_read_only():
return

params_to_save = params.copy()
if "sparsity" in params and params["sparsity"] is not None:
assert isinstance(
Expand Down
5 changes: 4 additions & 1 deletion src/spikeinterface/extractors/neoextractors/openephys.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ def __init__(
probe = None

if probe is not None:
self = self.set_probe(probe, in_place=True)
if probe.shank_ids is not None:
self.set_probe(probe, in_place=True, group_mode="by_shank")
else:
self.set_probe(probe, in_place=True)
probe_name = probe.annotations["probe_name"]
# load num_channels_per_adc depending on probe type
if "2.0" in probe_name:
Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .template_metrics import (
TemplateMetricsCalculator,
compute_template_metrics,
calculate_template_metrics,
get_template_metric_names,
)

Expand Down
Loading

0 comments on commit 07736fd

Please sign in to comment.