diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 004fe31203..7241f60a8b 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -37,6 +37,11 @@ runs: - name: git-annex install run: | wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz + mkdir /home/runner/work/installation + mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ + workdir=$(pwd) + cd /home/runner/work/installation tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH + cd $workdir shell: bash diff --git a/.github/workflows/test_containers_singularity_gpu.yml b/.github/workflows/test_containers_singularity_gpu.yml index e74fbeb4a5..d075f5a6ef 100644 --- a/.github/workflows/test_containers_singularity_gpu.yml +++ b/.github/workflows/test_containers_singularity_gpu.yml @@ -46,5 +46,6 @@ jobs: - name: Run test singularity containers with GPU env: REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} + SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }} run: | pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07601cd208..7153a7dfc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer diff --git a/examples/modules_gallery/core/plot_4_waveform_extractor.py b/examples/modules_gallery/core/plot_4_waveform_extractor.py index 6c886c1eb0..bee8f4061b 100644 --- a/examples/modules_gallery/core/plot_4_waveform_extractor.py +++ b/examples/modules_gallery/core/plot_4_waveform_extractor.py @@ -49,7 +49,8 @@ ############################################################################### # A :py:class:`~spikeinterface.core.WaveformExtractor` object can be created with the -# :py:func:`~spikeinterface.core.extract_waveforms` function: +# :py:func:`~spikeinterface.core.extract_waveforms` function (this defaults to a sparse +# representation of the waveforms): folder = 'waveform_folder' we = extract_waveforms( @@ -87,6 +88,7 @@ recording, sorting, folder, + sparse=False, ms_before=3., ms_after=4., max_spikes_per_unit=500, @@ -149,7 +151,7 @@ # # Option 1) Save a dense waveform extractor to sparse: # -# In this case, from an existing waveform extractor, we can first estimate a +# In this case, from an existing (dense) waveform extractor, we can first estimate a # sparsity (which channels each unit is defined on) and then save to a new # folder in sparse mode: @@ -173,7 +175,7 @@ ############################################################################### -# Option 2) Directly extract sparse waveforms: +# Option 2) Directly extract sparse waveforms (current spikeinterface default): # # We can also directly extract sparse waveforms. 
To do so, dense waveforms are # extracted first using a small number of spikes (:code:`'num_spikes_for_sparsity'`) diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 209f357457..7b6aae3e30 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -30,7 +30,7 @@ # because it contains a reference to the "Recording" and the "Sorting" objects: folder = 'waveforms_mearec' -we = si.extract_waveforms(recording, sorting, folder, +we = si.extract_waveforms(recording, sorting, folder, sparse=False, ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_durations='1s') print(we) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index c66f55f221..edd7a85ce5 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -6,6 +6,8 @@ quality metrics. """ +############################################################################# +# Import the modules and/or functions necessary from spikeinterface import spikeinterface as si import spikeinterface.extractors as se @@ -15,22 +17,21 @@ ############################################################################## -# First, let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' +# Let's download a simulated dataset +# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' # # Let's imagine that the ground-truth sorting is in fact the output of a sorter. -# local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') -recording, sorting = se.read_mearec(local_path) +recording, sorting = se.read_mearec(file_path=local_path) print(recording) print(sorting) ############################################################################## -# First, we extract waveforms and compute their PC scores: +# First, we extract waveforms (to be saved in the folder 'wfs_mearec') and +# compute their PC scores: -folder = 'wfs_mearec' -we = si.extract_waveforms(recording, sorting, folder, +we = si.extract_waveforms(recording, sorting, folder='wfs_mearec', ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_size=30000) print(we) @@ -47,12 +48,15 @@ ############################################################################## # We can now threshold each quality metric and select units based on some rules. # -# The easiest and most intuitive way is to use boolean masking with dataframe: +# The easiest and most intuitive way is to use boolean masking with a dataframe. 
+# +# Then create a list of unit ids that we want to keep keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.2) & (metrics['nn_hit_rate'] > 0.90) print(keep_mask) keep_unit_ids = keep_mask[keep_mask].index.values +keep_unit_ids = [unit_id for unit_id in keep_unit_ids] print(keep_unit_ids) ############################################################################## @@ -61,4 +65,5 @@ curated_sorting = sorting.select_units(keep_unit_ids) print(curated_sorting) -se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.pnz') + +se.NpzSortingExtractor.write_sorting(sorting=curated_sorting, save_path='curated_sorting.npz') diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index df0b5296c0..e5f4ce8b31 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -165,7 +165,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True sorting_exists = sorting_folder.exists() sorter_folder = self.folder / "sorters" / self.key_to_str(key) - sorter_folder_exists = sorting_folder.exists() + sorter_folder_exists = sorter_folder.exists() if keep: if sorting_exists: @@ -185,6 +185,9 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True if log_file.exists(): log_file.unlink() + if sorter_folder_exists: + shutil.rmtree(sorter_folder) + params = self.cases[key]["run_sorter_params"].copy() # this ensure that sorter_name is given recording, _ = self.datasets[self.cases[key]["dataset"]] @@ -283,7 +286,7 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): # the waveforms depend on the dataset key wf_folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - we = extract_waveforms(recording, gt_sorting, folder=wf_folder) + we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs) def get_waveform_extractor(self, key): # some recording are not dumpable to json and the waveforms extactor need it! 
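With the groundtruthstudy change above, `extract_waveforms_gt` now forwards its keyword arguments to `extract_waveforms`. A minimal sketch of the intended call (the study folder name and the kwargs are illustrative, assuming a study already created on disk):

    from spikeinterface.comparison import GroundTruthStudy

    study = GroundTruthStudy("my_study_folder")   # hypothetical existing study folder
    # ms_before / ms_after are now actually passed through to extract_waveforms
    study.extract_waveforms_gt(ms_before=1.5, ms_after=2.5)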
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0c67404069..dc84d31987 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -654,6 +654,7 @@ def __init__( "num_channels": num_channels, "durations": durations, "sampling_frequency": sampling_frequency, + "noise_level": noise_level, "dtype": dtype, "seed": seed, "strategy": strategy, @@ -876,13 +877,13 @@ def generate_single_fake_waveform( default_unit_params_range = dict( - alpha=(5_000.0, 15_000.0), + alpha=(6_000.0, 9_000.0), depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), positive_amplitude=(0.05, 0.15), smooth_ms=(0.03, 0.07), - decay_power=(1.2, 1.8), + decay_power=(1.4, 1.8), ) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 0fc5694207..d4ae140b90 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -833,14 +833,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) else: sparsity = None - we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) - we.set_params(**self._params) + if self.has_recording(): + we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) + else: + we = WaveformExtractor( + recording=None, + sorting=sorting, + folder=None, + sparsity=sparsity, + rec_attributes=self._rec_attributes, + allow_unfiltered=True, + ) + we._params = self._params # copy memory objects if self.has_waveforms(): we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} for unit_id in unit_ids: - we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] - we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id] + if self.format == "memory": + we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] + we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][ + unit_id + ] + else: + we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id) + we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id) # finally select extensions data for ext_name in self.get_available_extension_names(): @@ -2016,6 +2032,9 @@ def set_params(self, **params): params = self._set_params(**params) self._params = params + if self.waveform_extractor.is_read_only(): + return + params_to_save = params.copy() if "sparsity" in params and params["sparsity"] is not None: assert isinstance( diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index cd2b6fb941..bb3ae3435a 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -183,7 +183,10 @@ def __init__( probe = None if probe is not None: - self = self.set_probe(probe, in_place=True) + if probe.shank_ids is not None: + self.set_probe(probe, in_place=True, group_mode="by_shank") + else: + self.set_probe(probe, in_place=True) probe_name = probe.annotations["probe_name"] # load num_channels_per_adc depending on probe type if "2.0" in probe_name: diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 
223bda5e30..d7e1ffac01 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -10,7 +10,6 @@ from .template_metrics import ( TemplateMetricsCalculator, compute_template_metrics, - calculate_template_metrics, get_template_metric_names, ) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 681f6f3e84..3f47c505ad 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -4,16 +4,29 @@ 22/04/2020 """ import numpy as np +import warnings +from typing import Optional from copy import deepcopy -from ..core import WaveformExtractor +from ..core import WaveformExtractor, ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import BaseWaveformExtractorExtension -import warnings + + +global DEBUG +DEBUG = False + + +def get_single_channel_template_metric_names(): + return deepcopy(list(_single_channel_metric_name_to_func.keys())) + + +def get_multi_channel_template_metric_names(): + return deepcopy(list(_multi_channel_metric_name_to_func.keys())) def get_template_metric_names(): - return deepcopy(list(_metric_name_to_func.keys())) + return get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() class TemplateMetricsCalculator(BaseWaveformExtractorExtension): @@ -26,20 +39,31 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): """ extension_name = "template_metrics" + min_channels_for_multi_channel_warning = 10 - def __init__(self, waveform_extractor): + def __init__(self, waveform_extractor: WaveformExtractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - def _set_params(self, metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, window_slope_ms=0.7): + def _set_params( + self, + metric_names=None, + peak_sign="neg", + upsampling_factor=10, + sparsity=None, + metrics_kwargs=None, + include_multi_channel_metrics=False, + ): if metric_names is None: - metric_names = get_template_metric_names() - + metric_names = get_single_channel_template_metric_names() + if include_multi_channel_metrics: + metric_names += get_multi_channel_template_metric_names() + metrics_kwargs = metrics_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - window_slope_ms=float(window_slope_ms), + metrics_kwargs=metrics_kwargs, ) return params @@ -60,6 +84,9 @@ def _run(self): unit_ids = self.waveform_extractor.sorting.unit_ids sampling_frequency = self.waveform_extractor.sampling_frequency + metrics_single_channel = [m for m in metric_names if m in get_single_channel_template_metric_names()] + metrics_multi_channel = [m for m in metric_names if m in get_multi_channel_template_metric_names()] + if sparsity is None: extremum_channels_ids = get_template_extremum_channel( self.waveform_extractor, peak_sign=peak_sign, outputs="id" @@ -79,6 +106,8 @@ def _run(self): template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) all_templates = self.waveform_extractor.get_all_templates() + channel_locations = self.waveform_extractor.get_channel_locations() + for unit_index, unit_id in enumerate(unit_ids): template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) @@ -87,6 +116,7 @@ def _run(self): chan_ind = 
self.waveform_extractor.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] + # compute single_channel metrics for i, template_single in enumerate(template.T): if sparsity is None: index = unit_id @@ -100,15 +130,50 @@ def _run(self): template_upsampled = template_single sampling_frequency_up = sampling_frequency - for metric_name in metric_names: + trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + + for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] value = func( template_upsampled, sampling_frequency=sampling_frequency_up, - window_ms=self._params["window_slope_ms"], + trough_idx=trough_idx, + peak_idx=peak_idx, + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value + # compute metrics multi_channel + for metric_name in metrics_multi_channel: + # retrieve template (with sparsity if waveform extractor is sparse) + template = self.waveform_extractor.get_template(unit_id=unit_id) + + if template.shape[1] < self.min_channels_for_multi_channel_warning: + warnings.warn( + f"With less than {self.min_channels_for_multi_channel_warning} channels, " + "multi-channel metrics might not be reliable." + ) + if self.waveform_extractor.is_sparse(): + channel_locations_sparse = channel_locations[self.waveform_extractor.sparsity.mask[unit_index]] + else: + channel_locations_sparse = channel_locations + + if upsampling_factor > 1: + assert isinstance(upsampling_factor, (int, np.integer)), "'upsample' must be an integer" + template_upsampled = resample_poly(template, up=upsampling_factor, down=1, axis=0) + sampling_frequency_up = upsampling_factor * sampling_frequency + else: + template_upsampled = template + sampling_frequency_up = sampling_frequency + + func = _metric_name_to_func[metric_name] + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self._params["metrics_kwargs"], + ) + template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics def get_data(self): @@ -132,14 +197,31 @@ def get_extension_function(): WaveformExtractor.register_extension(TemplateMetricsCalculator) +_default_function_kwargs = dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.1, + depth_direction="y", + min_channels_for_velocity=5, + min_r2_velocity=0.5, + exp_peak_function="ptp", + min_r2_exp_decay=0.5, + spread_threshold=0.2, + spread_smooth_um=20, + column_range=None, +) + + def compute_template_metrics( waveform_extractor, - load_if_exists=False, - metric_names=None, - peak_sign="neg", - upsampling_factor=10, - sparsity=None, - window_slope_ms=0.7, + load_if_exists: bool = False, + metric_names: Optional[list[str]] = None, + peak_sign: Optional[str] = "neg", + upsampling_factor: int = 10, + sparsity: Optional[ChannelSparsity] = None, + include_multi_channel_metrics: bool = False, + metrics_kwargs: dict = None, + debug_plots: bool = False, ): """ Compute template metrics including: @@ -148,6 +230,14 @@ def compute_template_metrics( * halfwidth * repolarization_slope * recovery_slope + * num_positive_peaks + * num_negative_peaks + + Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): + * velocity_above + * velocity_below + * exp_decay + * spread Parameters ---------- @@ -157,34 +247,77 @@ def compute_template_metrics( Whether to load precomputed template metrics, if they already exist. 
metric_names : list, optional List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None - peak_sign : str, optional - "pos" | "neg", by default 'neg' - upsampling_factor : int, optional - Upsample factor, by default 10 - sparsity: dict or None - Default is sparsity=None and template metric is computed on extremum channel only. - If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. - For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. - window_slope_ms: float - Window in ms after the positiv peak to compute slope, by default 0.7 + peak_sign : {"neg", "pos"}, default: "neg" + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. + upsampling_factor : int, default: 10 + The upsampling factor to upsample the templates + sparsity: ChannelSparsity or None, default: None + If None, template metrics are computed on the extremum channel only. + If sparsity is given, template metrics are computed on all sparse channels of each unit. + For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. + include_multi_channel_metrics: bool, default: False + Whether to compute multi-channel metrics + metrics_kwargs: dict + Additional arguments to pass to the metric functions. Including: + * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 + * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 + * peak_width_ms: the width in ms used to detect peaks, default: 0.1 + * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) + * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 + * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.5 + * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 + * spread_threshold: the threshold to compute the spread, default: 0.2 + * spread_smooth_um: the smoothing in um to compute the spread, default: 20 + * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None + - If None, all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used Returns ------- - tempalte_metrics : pd.DataFrame + template_metrics : pd.DataFrame Dataframe with the computed template metrics. If 'sparsity' is None, the index is the unit_id. If 'sparsity' is given, the index is a multi-index (unit_id, channel_id) + + Notes + ----- + If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, + so that one metric value will be computed per unit. + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". 
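+
+    Examples
+    --------
+    Sketch of intended usage (illustrative; assumes an existing ``recording`` and ``sorting``
+    and a hypothetical waveform folder):
+
+    >>> from spikeinterface.core import extract_waveforms
+    >>> from spikeinterface.postprocessing import compute_template_metrics
+    >>> we = extract_waveforms(recording, sorting, folder="waveforms")
+    >>> tm = compute_template_metrics(we, include_multi_channel_metrics=True,
+    ...                               metrics_kwargs={"recovery_window_ms": 0.7, "spread_threshold": 0.2})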
""" + if debug_plots: + global DEBUG + DEBUG = True if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: tmc = TemplateMetricsCalculator(waveform_extractor) + # For 2D metrics, external sparsity must be None, so that one metric value will be computed per unit. + if include_multi_channel_metrics or ( + metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names]) + ): + assert sparsity is None, ( + "If multi-channel metrics are computed, sparsity must be None, " + "so that each unit will correspond to 1 row of the output dataframe." + ) + assert ( + waveform_extractor.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." + default_kwargs = _default_function_kwargs.copy() + if metrics_kwargs is None: + metrics_kwargs = default_kwargs + else: + default_kwargs.update(metrics_kwargs) + metrics_kwargs = default_kwargs tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, - window_slope_ms=window_slope_ms, + include_multi_channel_metrics=include_multi_channel_metrics, + metrics_kwargs=metrics_kwargs, ) tmc.run() @@ -197,7 +330,19 @@ def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak + Assumes negative trough and positive peak. + + Parameters + ---------- + template: numpy.ndarray + The 1D template waveform + + Returns + ------- + trough_idx: int + The index of the trough + peak_idx: int + The index of the peak """ assert template.ndim == 1 trough_idx = np.argmin(template) @@ -205,41 +350,92 @@ def get_trough_and_peak_idx(template): return trough_idx, peak_idx -def get_peak_to_valley(template, **kwargs): +######################################################################################### +# Single-channel metrics +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ - Time between trough and peak in s + Return the peak to valley duration in seconds of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + ptv: float + The peak to valley duration in seconds """ - sampling_frequency = kwargs["sampling_frequency"] - trough_idx, peak_idx = get_trough_and_peak_idx(template) + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template, **kwargs): +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs): """ - Ratio between peak heigth and trough depth + Return the peak to trough ratio of input waveforms. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + ptratio: float + The peak to trough ratio """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - ptratio = template[peak_idx] / template[trough_idx] + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + ptratio = template_single[peak_idx] / template_single[trough_idx] return ptratio -def get_half_width(template, **kwargs): +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ - Width of waveform at its half of amplitude in s + Return the half width of input waveforms in seconds. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + + Returns + ------- + hw: float + The half width in seconds """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - sampling_frequency = kwargs["sampling_frequency"] + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) if peak_idx == 0: return np.nan - trough_val = template[trough_idx] + trough_val = template_single[trough_idx] # threshold is half of peak heigth (assuming baseline is 0) threshold = 0.5 * trough_val - (cpre_idx,) = np.where(template[:trough_idx] < threshold) - (cpost_idx,) = np.where(template[trough_idx:] < threshold) + (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) + (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) if len(cpre_idx) == 0 or len(cpost_idx) == 0: hw = np.nan @@ -254,7 +450,7 @@ def get_half_width(template, **kwargs): return hw -def get_repolarization_slope(template, **kwargs): +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -264,17 +460,25 @@ def get_repolarization_slope(template, **kwargs): Optionally the function returns also the indices per waveform where the potential crosses baseline. 
- """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - sampling_frequency = kwargs["sampling_frequency"] + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + """ + if trough_idx is None: + trough_idx = get_trough_and_peak_idx(template_single) - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if trough_idx == 0: return np.nan - (rtrn_idx,) = np.nonzero(template[trough_idx:] >= 0) + (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) if len(rtrn_idx) == 0: return np.nan # first time after trough, where template is at baseline @@ -285,11 +489,11 @@ def get_repolarization_slope(template, **kwargs): import scipy.stats - res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template[trough_idx:return_to_base_idx]) + res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) return res.slope -def get_recovery_slope(template, window_ms=0.7, **kwargs): +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -299,41 +503,450 @@ def get_recovery_slope(template, window_ms=0.7, **kwargs): Takes a numpy array of waveforms and returns an array with recovery slopes per waveform. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ + import scipy.stats - trough_idx, peak_idx = get_trough_and_peak_idx(template) - sampling_frequency = kwargs["sampling_frequency"] + assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" + recovery_window_ms = kwargs["recovery_window_ms"] + if peak_idx is None: + _, peak_idx = get_trough_and_peak_idx(template_single) - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if peak_idx == 0: return np.nan - max_idx = int(peak_idx + ((window_ms / 1000) * sampling_frequency)) - max_idx = np.min([max_idx, template.shape[0]]) + max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) + max_idx = np.min([max_idx, template_single.shape[0]]) - import scipy.stats - - res = scipy.stats.linregress(times[peak_idx:max_idx], template[peak_idx:max_idx]) + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) return res.slope -_metric_name_to_func = { +def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): + """ + Count the number of positive peaks in the template. 
+ + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) + + pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(pos_peaks[0]) + + +def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): + """ + Count the number of negative peaks in the template. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) + + neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(neg_peaks[0]) + + +_single_channel_metric_name_to_func = { "peak_to_valley": get_peak_to_valley, "peak_trough_ratio": get_peak_trough_ratio, "half_width": get_half_width, "repolarization_slope": get_repolarization_slope, "recovery_slope": get_recovery_slope, + "num_positive_peaks": get_num_positive_peaks, + "num_negative_peaks": get_num_negative_peaks, } -# back-compatibility -def calculate_template_metrics(*args, **kwargs): - warnings.warn( - "The 'calculate_template_metrics' function is deprecated. " "Use 'compute_template_metrics' instead", - DeprecationWarning, - stacklevel=2, - ) - return compute_template_metrics(*args, **kwargs) +######################################################################################### +# Multi-channel metrics + + +def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + """ + Transform template anch channel locations based on column range. + """ + column_dim = 0 if depth_direction == "y" else 1 + if column_range is None: + template_column_range = template + channel_locations_column_range = channel_locations + else: + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range + template_column_range = template[:, column_mask] + channel_locations_column_range = channel_locations[column_mask] + return template_column_range, channel_locations_column_range + +def sort_template_and_locations(template, channel_locations, depth_direction="y"): + """ + Sort template and locations. 
+ """ + depth_dim = 1 if depth_direction == "y" else 0 + sort_indices = np.argsort(channel_locations[:, depth_dim]) + return template[:, sort_indices], channel_locations[sort_indices, :] + + +def fit_velocity(peak_times, channel_dist): + """ + Fit velocity from peak times and channel distances using ribust Theilsen estimator. + """ + # from scipy.stats import linregress + # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) + + from sklearn.linear_model import TheilSenRegressor + + theil = TheilSenRegressor() + theil.fit(peak_times.reshape(-1, 1), channel_dist) + slope = theil.coef_[0] + intercept = theil.intercept_ + score = theil.score(peak_times.reshape(-1, 1), channel_dist) + return slope, intercept, score + + +def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): + """ + Compute the velocity above the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 + max_channel_location = channel_locations[max_channel_idx] + + channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] + + # we only consider samples forward in time with respect to the max channel + # TODO: not sure + # template_above = template[max_sample_idx:, channels_above] + template_above = template[:, channels_above] + channel_locations_above = channel_locations[channels_above] + + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time + distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) + velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_above,) = np.nonzero(channels_above) + for i, single_template in 
enumerate(template.T): + color = "r" if i in channel_indices_above else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_above, distances_um_above, "o") + x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + axs[1].plot(x, intercept + x * velocity_above) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" + ) + plt.show() + + if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity: + velocity_above = np.nan + + return velocity_above + + +def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): + """ + Compute the velocity below the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 + max_channel_location = channel_locations[max_channel_idx] + + channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] + + # we only consider samples forward in time with respect to the max channel + # template_below = template[max_sample_idx:, channels_below] + template_below = template[:, channels_below] + channel_locations_below = channel_locations[channels_below] + + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time + distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) + velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_below,) = np.nonzero(channels_below) + for i, single_template in enumerate(template.T): + color = "r" if i 
in channel_indices_below else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_below, distances_um_below, "o") + x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + axs[1].plot(x, intercept + x * velocity_below) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" + ) + plt.show() + + if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity: + velocity_below = np.nan + + return velocity_below + + +def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): + """ + Compute the exponential decay of the template amplitude over distance. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2_exp_decay: the minimum r2 to accept the exp decay fit + """ + from scipy.optimize import curve_fit + from sklearn.metrics import r2_score + + def exp_decay(x, decay, amp0, offset): + return amp0 * np.exp(-decay * x) + offset + + assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" + exp_peak_function = kwargs["exp_peak_function"] + assert "min_r2_exp_decay" in kwargs, "min_r2_exp_decay must be given as kwarg" + min_r2_exp_decay = kwargs["min_r2_exp_decay"] + # exp decay fit + if exp_peak_function == "ptp": + fun = np.ptp + elif exp_peak_function == "min": + fun = np.min + peak_amplitudes = np.abs(fun(template, axis=0)) + max_channel_location = channel_locations[np.argmax(peak_amplitudes)] + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + # np.float128 avoids overflow error + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.float128) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.float128) + try: + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] + + if r2 < min_r2_exp_decay: + exp_decay_value = np.nan + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + ax.plot(x, exp_decay(x, *popt)) + ax.set_xlabel("Distance from max channel (um)") + ax.set_ylabel("Peak amplitude") + ax.set_title( + f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " + f"R2: {np.round(r2, 4)}" + ) + fig.suptitle("Exp decay") + plt.show() + except: + exp_decay_value = np.nan + + return exp_decay_value + + +def get_spread(template, channel_locations, sampling_frequency, **kwargs): + """ + 
Compute the spread of the template amplitude over distance. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - spread_threshold: the threshold to compute the spread + - column_range: the range in um in the x-direction to consider channels for velocity + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + depth_direction = kwargs["depth_direction"] + assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" + spread_threshold = kwargs["spread_threshold"] + assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" + spread_smooth_um = kwargs["spread_smooth_um"] + assert "column_range" in kwargs, "column_range must be given as kwarg" + column_range = kwargs["column_range"] + + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + MM = np.ptp(template, 0) + MM = MM / np.max(MM) + channel_depths = channel_locations[:, depth_dim] + + if spread_smooth_um is not None and spread_smooth_um > 0: + from scipy.ndimage import gaussian_filter1d + + spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) + MM = gaussian_filter1d(MM, spread_sigma) + + channel_locations_above_theshold = channel_locations[MM > spread_threshold] + channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim] + spread = np.ptp(channel_depth_above_theshold) + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + axs[0].imshow( + template.T, + aspect="auto", + origin="lower", + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], + ) + axs[1].plot(channel_depths, MM, "o-") + axs[1].axhline(spread_threshold, ls="--", color="r") + axs[1].set_xlabel("Depth (um)") + axs[1].set_ylabel("Amplitude") + axs[1].set_title(f"Spread: {np.round(spread, 3)} um") + fig.suptitle("Spread") + plt.show() + + return spread + + +_multi_channel_metric_name_to_func = { + "velocity_above": get_velocity_above, + "velocity_below": get_velocity_below, + "exp_decay": get_exp_decay, + "spread": get_spread, +} -calculate_template_metrics.__doc__ = compute_template_metrics.__doc__ +_metric_name_to_func = {**_single_channel_metric_name_to_func, **_multi_channel_metric_name_to_func} diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 9895e2ec4c..a27ccc77f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -17,9 +17,13 @@ def test_sparse_metrics(self): tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) print(tm_sparse) + def test_multi_channel_metrics(self): + tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) + print(tm_multi) + if __name__ == "__main__": test = TemplateMetricsExtensionTest() test.setUp() - test.test_extension() - 
test.test_sparse_metrics() + # test.test_extension() + test.test_multi_channel_metrics() diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 0c3b9f95d1..a16b642dd5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -67,6 +67,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") + noise_levels = np.ones(num_channels, dtype=np.float32) ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() @@ -87,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - noise_levels = np.ones(num_channels, dtype=np.float32) selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks( peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params @@ -107,6 +107,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params.update({"noise_levels": noise_levels}) labels, peak_labels = find_cluster_from_peaks( recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params @@ -114,7 +115,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We get the labels for our peaks mask = peak_labels > -1 - sorting = NumpySorting.from_times_labels(selected_peaks["sample_index"][mask], peak_labels[mask], sampling_rate) + sorting = NumpySorting.from_times_labels( + selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate + ) clustering_folder = sorter_output_folder / "clustering" if clustering_folder.exists(): shutil.rmtree(clustering_folder) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 32cf27ceb1..e256915fa6 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -29,10 +29,12 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "waveforms": { "ms_before": 0.5, "ms_after": 1.5, + "radius_um": 120.0, }, - "filtering": {"freq_min": 300.0, "freq_max": 8000.0}, + "filtering": {"freq_min": 300.0, "freq_max": 12000.0}, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, + "features": {}, "svd": {"n_components": 6}, "clustering": { "split_radius_um": 40.0, @@ -154,13 +156,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # upsampling_um=5.0, # ) + radius_um = params["waveforms"]["radius_um"] node3 = ExtractSparseWaveforms( recording, parents=[node0], return_output=True, - ms_before=0.5, - ms_after=1.5, - radius_um=100.0, + ms_before=ms_before, + ms_after=ms_after, + radius_um=radius_um, ) model_folder_path = sorter_output_folder / "tsvd_model" @@ -193,6 +196,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): original_labels = peaks["channel_index"] + min_cluster_size 
= 50 + post_split_label, split_count = split_clusters( original_labels, recording, @@ -205,8 +210,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # feature_name="sparse_wfs", neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - min_size_split=50, - min_cluster_size=50, + min_size_split=min_cluster_size, + min_cluster_size=min_cluster_size, min_samples=50, n_pca_features=3, ), @@ -224,22 +229,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording, features_folder, radius_um=merge_radius_um, - method="project_distribution", + # method="project_distribution", + # method_kwargs=dict( + # waveforms_sparse_mask=sparse_mask, + # feature_name="sparse_wfs", + # projection="centroid", + # criteria="distrib_overlap", + # threshold_overlap=0.3, + # min_cluster_size=min_cluster_size + 1, + # num_shift=5, + # ), + method="normalized_template_diff", method_kwargs=dict( - # neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - # feature_name="sparse_tsvd", - feature_name="sparse_wfs", - # projection='lda', - projection="centroid", - # criteria='diptest', - # threshold_diptest=0.5, - # criteria="percentile", - # threshold_percentile=80., - criteria="distrib_overlap", - threshold_overlap=0.4, - # num_shift=0 - num_shift=2, + threshold_diff=0.2, + min_cluster_size=min_cluster_size + 1, + num_shift=5, ), **job_kwargs, ) @@ -249,10 +254,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): new_peaks = peaks.copy() new_peaks["sample_index"] -= peak_shifts + # clean very small cluster before peeler + minimum_cluster_size = 25 + labels_set, count = np.unique(post_merge_label, return_counts=True) + to_remove = labels_set[count < minimum_cluster_size] + + mask = np.isin(post_merge_label, to_remove) + post_merge_label[mask] = -1 + + # final label sets labels_set = np.unique(post_merge_label) labels_set = labels_set[labels_set >= 0] - mask = post_merge_label >= 0 + mask = post_merge_label >= 0 sorting_temp = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], post_merge_label[mask], diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index a5e29c8fd9..fdadf533f5 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -233,15 +233,15 @@ def test_run_sorters_with_dict(): if __name__ == "__main__": - # setup_module() + setup_module() job_list = get_job_list() - # test_run_sorter_jobs_loop(job_list) + test_run_sorter_jobs_loop(job_list) # test_run_sorter_jobs_joblib(job_list) # test_run_sorter_jobs_processpoolexecutor(job_list) # test_run_sorter_jobs_multiprocessing(job_list) # test_run_sorter_jobs_dask(job_list) - test_run_sorter_jobs_slurm(job_list) + # test_run_sorter_jobs_slurm(job_list) # test_run_sorter_by_property() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index d68b8e5449..bd413417bf 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -524,7 +524,7 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) template_real = template_real.reshape(template_real.size, 1).T if metric == "cosine": - dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real, metric).flatten().tolist() + dist = 
sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist() else: dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist() res += dist diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index d892d0723a..d35b562298 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -256,7 +256,7 @@ def find_merge_pairs( sparse_wfs, sparse_mask, radius_um=70, - method="waveforms_lda", + method="project_distribution", method_kwargs={}, **job_kwargs # n_jobs=1, @@ -308,7 +308,16 @@ def find_merge_pairs( max_workers=n_jobs, initializer=find_pair_worker_init, mp_context=get_context(mp_context), - initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + initargs=( + recording, + features_dict_or_folder, + peak_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, + ), ) as pool: jobs = [] for ind0, ind1 in zip(indices0, indices1): @@ -338,13 +347,22 @@ def find_merge_pairs( def find_pair_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process + recording, + features_dict_or_folder, + original_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, ): global _ctx _ctx = {} _ctx["recording"] = recording _ctx["original_labels"] = original_labels + _ctx["labels_set"] = labels_set + _ctx["templates"] = templates _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] @@ -364,8 +382,16 @@ def find_pair_function_wrapper(label0, label1): global _ctx with threadpool_limits(limits=_ctx["max_threads_per_process"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( - label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + label0, + label1, + _ctx["labels_set"], + _ctx["templates"], + _ctx["original_labels"], + _ctx["peaks"], + _ctx["features"], + **_ctx["method_kwargs"], ) + return is_merge, label0, label1, shift, merge_value @@ -388,6 +414,8 @@ class ProjectDistribution: def merge( label0, label1, + labels_set, + templates, original_labels, peaks, features, @@ -398,6 +426,7 @@ def merge( threshold_diptest=0.5, threshold_percentile=80.0, threshold_overlap=0.4, + min_cluster_size=50, num_shift=2, ): if num_shift > 0: @@ -414,7 +443,7 @@ def merge( chans1 = np.unique(peaks["channel_index"][inds1]) target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) - if inds0.size < 40 or inds1.size < 40: + if inds0.size < min_cluster_size or inds1.size < min_cluster_size: is_merge = False merge_value = 0 final_shift = 0 @@ -525,7 +554,9 @@ def merge( # DEBUG = True DEBUG = False - if DEBUG and is_merge: + # if DEBUG and is_merge: + # if DEBUG and (overlap > 0.1 and overlap <0.3): + if DEBUG: # if DEBUG and not is_merge: # if DEBUG and (overlap > 0.05 and overlap <0.25): # if label0 == 49 and label1== 65: @@ -575,7 +606,100 @@ def merge( return is_merge, label0, label1, final_shift, merge_value +class NormalizedTemplateDiff: + """ + Compute the normalized (some kind of) template differences. + And merge if below a threhold. + Do this at several shift. 
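+
+    Concretely (matching the implementation below): for two templates T0 and T1 restricted to
+    their common sparse channels, the difference at a given shift is
+    mean(|T0 - T1_shifted|) / mean(|T0| + |T1|), and the two clusters are merged when the
+    minimum of this normalized difference over all shifts is below ``threshold_diff``.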
+ + """ + + name = "normalized_template_diff" + + @staticmethod + def merge( + label0, + label1, + labels_set, + templates, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + threshold_diff=0.05, + min_cluster_size=50, + num_shift=5, + ): + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + # if inds0.size < min_cluster_size or inds1.size < min_cluster_size: + # is_merge = False + # merge_value = 0 + # final_shift = 0 + # return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + union_chans = np.union1d(target_chans0, target_chans1) + + ind0 = list(labels_set).index(label0) + template0 = templates[ind0][:, target_chans] + + ind1 = list(labels_set).index(label1) + template1 = templates[ind1][:, target_chans] + + num_samples = template0.shape[0] + # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) + norm = np.mean(np.abs(template0) + np.abs(template1)) + all_shift_diff = [] + for shift in range(-num_shift, num_shift + 1): + temp0 = template0[num_shift : num_samples - num_shift, :] + temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + d = np.mean(np.abs(temp0 - temp1)) / (norm) + all_shift_diff.append(d) + normed_diff = np.min(all_shift_diff) + + is_merge = normed_diff < threshold_diff + if is_merge: + merge_value = normed_diff + final_shift = np.argmin(all_shift_diff) - num_shift + else: + final_shift = 0 + merge_value = np.nan + + # DEBUG = False + DEBUG = True + if DEBUG and normed_diff < 0.2: + # if DEBUG: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + + m0 = template0.flatten() + m1 = template1.flatten() + + ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") + + ax.set_title( + f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}" + ) + ax.legend() + plt.show() + + return is_merge, label0, label1, final_shift, merge_value + + find_pair_method_list = [ ProjectDistribution, + NormalizedTemplateDiff, ] find_pair_method_dict = {e.name: e for e in find_pair_method_list} diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index a81458d7a8..72acd49f4f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -20,7 +20,12 @@ from spikeinterface.core import extract_waveforms from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature -from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) class RandomProjectionClustering: @@ -43,7 +48,8 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "smoothing_kwargs": {"window_length_ms": 1}, + 
"noise_levels": None, + "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, @@ -72,7 +78,10 @@ def main_function(cls, recording, peaks, params): num_samples = nbefore + nafter num_chans = recording.get_num_channels() - noise_levels = get_noise_levels(recording, return_scaled=False) + if d["noise_levels"] is None: + noise_levels = get_noise_levels(recording, return_scaled=False) + else: + noise_levels = d["noise_levels"] np.random.seed(d["random_seed"]) @@ -82,10 +91,16 @@ def main_function(cls, recording, peaks, params): else: tmp_folder = Path(params["tmp_folder"]).absolute() - ### Then we extract the SVD features + tmp_folder.mkdir(parents=True, exist_ok=True) + node0 = PeakRetriever(recording, peaks) - node1 = ExtractDenseWaveforms( - recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"] + node1 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=False, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + radius_um=params["radius_um"], ) node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) @@ -123,6 +138,8 @@ def sigmoid(x, L, x0, k, b): return_output=True, projections=projections, radius_um=params["radius_um"], + sigmoid=None, + sparse=True, ) pipeline_nodes = [node0, node1, node2, node3] @@ -136,6 +153,18 @@ def sigmoid(x, L, x0, k, b): clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) peak_labels = clustering[0] + # peak_labels = -1 * np.ones(len(peaks), dtype=int) + # nb_clusters = 0 + # for c in np.unique(peaks['channel_index']): + # mask = peaks['channel_index'] == c + # clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs']) + # local_labels = clustering[0] + # valid_clusters = local_labels > -1 + # if np.sum(valid_clusters) > 0: + # local_labels[valid_clusters] += nb_clusters + # peak_labels[mask] = local_labels + # nb_clusters += len(np.unique(local_labels[valid_clusters])) + labels = np.unique(peak_labels) labels = labels[labels >= 0] @@ -174,15 +203,6 @@ def sigmoid(x, L, x0, k, b): if verbose: print("We found %d raw clusters, starting to clean with matching..." 
% (len(labels))) - # create a tmp folder - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]) - - tmp_folder.mkdir(parents=True, exist_ok=True) - sorting_folder = tmp_folder / "sorting" unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 9836e9110f..b433a2d16d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -192,6 +192,7 @@ def split( # target channel subset is done intersect local channels + neighbours local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) # TODO fix this a better way, this when cluster have too few overlapping channels @@ -204,6 +205,7 @@ def split( local_labels[dont_have_channels] = -2 kept = np.flatnonzero(~dont_have_channels) + if kept.size < min_size_split: return False, None diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index b534c2356d..06d22181cb 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -186,6 +186,7 @@ def __init__( projections=None, sigmoid=None, radius_um=None, + sparse=True, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -195,7 +196,8 @@ def __init__( self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance < radius_um self.radius_um = radius_um - self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um)) + self.sparse = sparse + self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -213,7 +215,10 @@ def compute(self, traces, peaks, waveforms): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) + if self.sparse: + wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) + else: + wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) if self.sigmoid is not None: wf_ptp *= self._sigmoid(wf_ptp) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index c2bed8fe41..b68efc3f8a 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -162,10 +162,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1]) n = len(traces_widget.data_plot["channel_ids"]) - order = traces_widget.data_plot["order"] - - if order is None: - order = np.arange(n) if ax.get_legend() is not None: ax.get_legend().remove() @@ -221,7 +217,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # discontinuity times[:, -1] = np.nan times_r = times.reshape(times.shape[0] * times.shape[1]) - waveforms = traces[waveform_idxs] # [:, :, order] + waveforms = traces[waveform_idxs] waveforms_r = 
waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]): diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index cec4b5ce53..ff3d1436ba 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -120,7 +120,7 @@ plot_study_run_times = StudyRunTimesWidget plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances -plot_stufy_performances_vs_metrics = StudyPerformancesVsMetrics +plot_study_performances_vs_metrics = StudyPerformancesVsMetrics def plot_timeseries(*args, **kwargs):
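
The sorter's `_run_from_folder` above now discards very small clusters before handing labels to the peeler, relabelling them as noise (-1). A minimal, self-contained sketch of that cleanup step, assuming a NumPy label array; the helper name `remove_small_clusters` is illustrative and not part of the diff:

import numpy as np

def remove_small_clusters(labels, minimum_cluster_size=25):
    # Relabel clusters with fewer than `minimum_cluster_size` members as noise (-1),
    # mirroring the pre-peeler cleanup added in the sorter: tiny clusters rarely
    # yield usable templates for the matching stage.
    labels = labels.copy()
    labels_set, counts = np.unique(labels, return_counts=True)
    to_remove = labels_set[counts < minimum_cluster_size]
    labels[np.isin(labels, to_remove)] = -1
    return labels

# toy usage: cluster 2 has only 3 spikes, so it is folded into the noise label
labels = np.array([0] * 40 + [1] * 30 + [2] * 3)
print(np.unique(remove_small_clusters(labels), return_counts=True))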
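
In benchmark_clustering.py, the cosine branch previously passed `metric` as an extra positional argument to `sklearn.metrics.pairwise.cosine_similarity`, which takes no metric parameter (the stray argument would have been bound to `dense_output`); the patch drops it. A small sketch of the corrected branching, with made-up toy templates standing in for the benchmark data:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances

# templates flattened to (1, n_features) rows, as in the benchmark code
rng = np.random.default_rng(0)
template = rng.normal(size=(1, 120))
template_real = template + 0.05 * rng.normal(size=(1, 120))

metric = "cosine"
if metric == "cosine":
    # cosine_similarity(X, Y) has no metric argument; only X and Y are passed
    dist = cosine_similarity(template, template_real).flatten().tolist()
else:
    dist = pairwise_distances(template, template_real, metric=metric).flatten().tolist()
print(dist)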
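
The new `NormalizedTemplateDiff` criterion in clustering/merge.py merges a pair of clusters when the mean absolute difference between their templates, normalised by the pair's mean absolute amplitude and minimised over a small range of temporal shifts, falls below `threshold_diff`. A minimal sketch of that computation as a standalone function; the function name and the toy data are illustrative, not part of the diff:

import numpy as np

def normalized_template_diff(template0, template1, num_shift=5, threshold_diff=0.2):
    # Both templates are (num_samples, num_channels) arrays restricted to their
    # common channels. The normalised mean absolute difference is evaluated over
    # shifts in [-num_shift, num_shift]; the pair merges when the best value is
    # below threshold_diff.
    num_samples = template0.shape[0]
    norm = np.mean(np.abs(template0) + np.abs(template1))
    all_shift_diff = []
    for shift in range(-num_shift, num_shift + 1):
        temp0 = template0[num_shift : num_samples - num_shift, :]
        temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
        all_shift_diff.append(np.mean(np.abs(temp0 - temp1)) / norm)
    normed_diff = np.min(all_shift_diff)
    final_shift = np.argmin(all_shift_diff) - num_shift
    return normed_diff < threshold_diff, normed_diff, final_shift

# toy usage: random noise stands in for real templates; a copy shifted by
# 2 samples still matches perfectly at the corresponding shift
rng = np.random.default_rng(42)
t0 = rng.normal(size=(90, 4))
t1 = np.roll(t0, 2, axis=0)
print(normalized_template_diff(t0, t1))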
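
RandomProjectionsFeature gains a `sparse` flag because the clustering pipeline now feeds it `ExtractSparseWaveforms` output, where only the channels in the peak's neighbourhood are present, rather than dense waveforms that must be indexed by channel. A toy illustration of the two indexing paths, assuming the sparse buffer holds the neighbourhood channels in its first `len(chan_inds)` columns (shapes and data are made up):

import numpy as np

# 3 peaks, 40 samples, 8 channels in the recording, 5 of them inside the radius
chan_inds = np.array([0, 1, 2, 3, 4])
dense_wfs = np.random.default_rng(0).normal(size=(3, 40, 8))

# dense waveforms: select the neighbourhood channels explicitly
ptp_dense = np.ptp(dense_wfs[:, :, chan_inds], axis=1)

# sparse waveforms: the neighbourhood channels are already the leading columns
sparse_wfs = dense_wfs[:, :, chan_inds]
ptp_sparse = np.ptp(sparse_wfs[:, :, : len(chan_inds)], axis=1)

assert np.allclose(ptp_dense, ptp_sparse)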