From f96e2b79df22b699cfd380b126b4c977339f3ab1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 19:28:00 +0200 Subject: [PATCH 01/35] Extend and refactor waveform metrics --- src/spikeinterface/postprocessing/__init__.py | 1 - .../postprocessing/template_metrics.py | 583 ++++++++++++++++-- .../tests/test_template_metrics.py | 8 +- 3 files changed, 534 insertions(+), 58 deletions(-) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 223bda5e30..d7e1ffac01 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -10,7 +10,6 @@ from .template_metrics import ( TemplateMetricsCalculator, compute_template_metrics, - calculate_template_metrics, get_template_metric_names, ) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 681f6f3e84..119f0dc53d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,9 +11,24 @@ from ..core.waveform_extractor import BaseWaveformExtractorExtension import warnings +# DEBUG = True + +# if DEBUG: +# import matplotlib.pyplot as plt +# plt.ion() +# plt.show() + + +def get_1d_template_metric_names(): + return deepcopy(list(_1d_metric_name_to_func.keys())) + + +def get_2d_template_metric_names(): + return deepcopy(list(_2d_metric_name_to_func.keys())) + def get_template_metric_names(): - return deepcopy(list(_metric_name_to_func.keys())) + return get_1d_template_metric_names() + get_2d_template_metric_names() class TemplateMetricsCalculator(BaseWaveformExtractorExtension): @@ -26,20 +41,31 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): """ extension_name = "template_metrics" + min_channels_for_2d_warning = 10 def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) - def _set_params(self, metric_names=None, peak_sign="neg", upsampling_factor=10, sparsity=None, window_slope_ms=0.7): + def _set_params( + self, + metric_names=None, + peak_sign="neg", + upsampling_factor=10, + sparsity=None, + functions_kwargs=None, + include_2d_metrics=False, + ): if metric_names is None: - metric_names = get_template_metric_names() - + metric_names = get_1d_template_metric_names() + if include_2d_metrics: + metric_names += get_2d_template_metric_names() + functions_kwargs = functions_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - window_slope_ms=float(window_slope_ms), + functions_kwargs=functions_kwargs, ) return params @@ -60,6 +86,9 @@ def _run(self): unit_ids = self.waveform_extractor.sorting.unit_ids sampling_frequency = self.waveform_extractor.sampling_frequency + metrics_1d = [m for m in metric_names if m in get_1d_template_metric_names()] + metrics_2d = [m for m in metric_names if m in get_2d_template_metric_names()] + if sparsity is None: extremum_channels_ids = get_template_extremum_channel( self.waveform_extractor, peak_sign=peak_sign, outputs="id" @@ -79,6 +108,8 @@ def _run(self): template_metrics = pd.DataFrame(index=multi_index, columns=metric_names) all_templates = self.waveform_extractor.get_all_templates() + channel_locations = self.waveform_extractor.get_channel_locations() + for unit_index, unit_id in enumerate(unit_ids): template_all_chans = all_templates[unit_index] chan_ids = np.array(extremum_channels_ids[unit_id]) @@ -87,6 +118,7 @@ def _run(self): chan_ind = self.waveform_extractor.channel_ids_to_indices(chan_ids) template = template_all_chans[:, chan_ind] + # compute 1d metrics for i, template_single in enumerate(template.T): if sparsity is None: index = unit_id @@ -100,15 +132,50 @@ def _run(self): template_upsampled = template_single sampling_frequency_up = sampling_frequency - for metric_name in metric_names: + trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + + for metric_name in metrics_1d: func = _metric_name_to_func[metric_name] value = func( template_upsampled, sampling_frequency=sampling_frequency_up, - window_ms=self._params["window_slope_ms"], + trough_idx=trough_idx, + peak_idx=peak_idx, + **self._params["functions_kwargs"], ) template_metrics.at[index, metric_name] = value + # compute metrics 2d + for metric_name in metrics_2d: + # retrieve template (with sparsity if waveform extractor is sparse) + template = self.waveform_extractor.get_template(unit_id=unit_id) + + if template.shape[1] < self.min_channels_for_2d_warning: + warnings.warn( + f"With less than {self.min_channels_for_2d_warning} channels, " + "2D metrics might not be reliable." + ) + if self.waveform_extractor.is_sparse(): + channel_locations_sparse = channel_locations[self.waveform_extractor.sparsity.mask[unit_index]] + else: + channel_locations_sparse = channel_locations + + if upsampling_factor > 1: + assert isinstance(upsampling_factor, (int, np.integer)), "'upsample' must be an integer" + template_upsampled = resample_poly(template, up=upsampling_factor, down=1, axis=0) + sampling_frequency_up = upsampling_factor * sampling_frequency + else: + template_upsampled = template + sampling_frequency_up = sampling_frequency + + func = _metric_name_to_func[metric_name] + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self._params["functions_kwargs"], + ) + template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics def get_data(self): @@ -139,7 +206,17 @@ def compute_template_metrics( peak_sign="neg", upsampling_factor=10, sparsity=None, - window_slope_ms=0.7, + include_2d_metrics=False, + functions_kwargs=dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.2, + depth_direction="y", + min_channels_for_velocity=5, + min_r2_for_velocity=0.5, + exp_peak_function="ptp", + spread_threshold=0.2, + ), ): """ Compute template metrics including: @@ -148,6 +225,14 @@ def compute_template_metrics( * halfwidth * repolarization_slope * recovery_slope + * num_positive_peaks + * num_negative_peaks + + Optionally, the following 2d metrics can be computed (when include_2d_metrics=True): + * velocity_above + * velocity_below + * exp_decay + * spread Parameters ---------- @@ -157,34 +242,57 @@ def compute_template_metrics( Whether to load precomputed template metrics, if they already exist. metric_names : list, optional List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None - peak_sign : str, optional - "pos" | "neg", by default 'neg' - upsampling_factor : int, optional - Upsample factor, by default 10 - sparsity: dict or None + peak_sign : {"neg", "pos"}, default: "neg" + The peak sign + upsampling_factor : int, default: 10 + The upsampling factor to upsample the templates + sparsity: dict or None, default: None Default is sparsity=None and template metric is computed on extremum channel only. If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. - window_slope_ms: float - Window in ms after the positiv peak to compute slope, by default 0.7 + include_2d_metrics: bool, default: False + Whether to compute 2d metrics + functions_kwargs: dict + Additional arguments to pass to the metric functions. Including: + * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 + * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 + * peak_width_ms: the width in samples to detect peaks, default: 0.2 + * depth_direction: the direction to compute velocity above and below, default: "y" + * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 + * min_r2_for_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * spread_threshold: the threshold to compute the spread, default: 0.2 Returns ------- - tempalte_metrics : pd.DataFrame + template_metrics : pd.DataFrame Dataframe with the computed template metrics. If 'sparsity' is None, the index is the unit_id. If 'sparsity' is given, the index is a multi-index (unit_id, channel_id) + + Notes + ----- + If any 2d metric is in the metric_names or include_2d_metrics is True, sparsity must be None, so that one metric + value will be computed per unit. """ if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: tmc = TemplateMetricsCalculator(waveform_extractor) + # For 2D metrics, external sparsity must be None, so that one metric value will be computed per unit. + if include_2d_metrics or ( + metric_names is not None and any([m in get_2d_template_metric_names() for m in metric_names]) + ): + assert ( + sparsity is None + ), "If 2D metrics are computed, sparsity must be None, so that each unit will correspond to 1 row of the output dataframe." tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, - window_slope_ms=window_slope_ms, + include_2d_metrics=include_2d_metrics, + functions_kwargs=functions_kwargs, ) tmc.run() @@ -197,7 +305,19 @@ def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak + Assumes negative trough and positive peak. + + Parameters + ---------- + template: numpy.ndarray + The 1D template waveform + + Returns + ------- + trough_idx: int + The index of the trough + peak_idx: int + The index of the peak """ assert template.ndim == 1 trough_idx = np.argmin(template) @@ -205,41 +325,94 @@ def get_trough_and_peak_idx(template): return trough_idx, peak_idx -def get_peak_to_valley(template, **kwargs): +######################################################################################### +# 1D metrics +def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs): """ - Time between trough and peak in s + Return the peak to valley duration in seconds of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + + Returns + ------- + ptv: float + The peak to valley duration in seconds """ sampling_frequency = kwargs["sampling_frequency"] - trough_idx, peak_idx = get_trough_and_peak_idx(template) + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template, **kwargs): +def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwargs): """ - Ratio between peak heigth and trough depth + Return the peak to trough ratio of input waveforms. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + + Returns + ------- + ptratio: float + The peak to trough ratio """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) - ptratio = template[peak_idx] / template[trough_idx] + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + ptratio = template_single[peak_idx] / template_single[trough_idx] return ptratio -def get_half_width(template, **kwargs): +def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): """ - Width of waveform at its half of amplitude in s + Return the half width of input waveforms in seconds. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + + Returns + ------- + hw: float + The half width in seconds """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) + if trough_idx is None or peak_idx is None: + trough_idx, peak_idx = get_trough_and_peak_idx(template_single) sampling_frequency = kwargs["sampling_frequency"] if peak_idx == 0: return np.nan - trough_val = template[trough_idx] + trough_val = template_single[trough_idx] # threshold is half of peak heigth (assuming baseline is 0) threshold = 0.5 * trough_val - (cpre_idx,) = np.where(template[:trough_idx] < threshold) - (cpost_idx,) = np.where(template[trough_idx:] < threshold) + (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) + (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) if len(cpre_idx) == 0 or len(cpost_idx) == 0: hw = np.nan @@ -254,7 +427,7 @@ def get_half_width(template, **kwargs): return hw -def get_repolarization_slope(template, **kwargs): +def get_repolarization_slope(template_single, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -264,17 +437,26 @@ def get_repolarization_slope(template, **kwargs): Optionally the function returns also the indices per waveform where the potential crosses baseline. - """ - trough_idx, peak_idx = get_trough_and_peak_idx(template) + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + trough_idx: int, default: None + The index of the trough + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + """ + if trough_idx is None: + trough_idx = get_trough_and_peak_idx(template_single) sampling_frequency = kwargs["sampling_frequency"] - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if trough_idx == 0: return np.nan - (rtrn_idx,) = np.nonzero(template[trough_idx:] >= 0) + (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) if len(rtrn_idx) == 0: return np.nan # first time after trough, where template is at baseline @@ -285,11 +467,11 @@ def get_repolarization_slope(template, **kwargs): import scipy.stats - res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template[trough_idx:return_to_base_idx]) + res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) return res.slope -def get_recovery_slope(template, window_ms=0.7, **kwargs): +def get_recovery_slope(template_single, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -299,41 +481,332 @@ def get_recovery_slope(template, window_ms=0.7, **kwargs): Takes a numpy array of waveforms and returns an array with recovery slopes per waveform. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + peak_idx: int, default: None + The index of the peak + **kwargs: Required kwargs: + - sampling_frequency: the sampling frequency + - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ + import scipy.stats - trough_idx, peak_idx = get_trough_and_peak_idx(template) + assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" + recovery_window_ms = kwargs["recovery_window_ms"] + if peak_idx is None: + _, peak_idx = get_trough_and_peak_idx(template_single) sampling_frequency = kwargs["sampling_frequency"] - times = np.arange(template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency if peak_idx == 0: return np.nan - max_idx = int(peak_idx + ((window_ms / 1000) * sampling_frequency)) - max_idx = np.min([max_idx, template.shape[0]]) - - import scipy.stats + max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) + max_idx = np.min([max_idx, template_single.shape[0]]) - res = scipy.stats.linregress(times[peak_idx:max_idx], template[peak_idx:max_idx]) + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) return res.slope -_metric_name_to_func = { +def get_num_positive_peaks(template_single, **kwargs): + """ + Count the number of positive peaks in the template. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + - sampling_frequency: the sampling frequency + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + + pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(pos_peaks[0]) + + +def get_num_negative_peaks(template_single, **kwargs): + """ + Count the number of negative peaks in the template. + + Parameters + ---------- + template_single: numpy.ndarray + The 1D template waveform + **kwargs: Required kwargs: + - peak_relative_threshold: the relative threshold to detect positive and negative peaks + - peak_width_ms: the width in samples to detect peaks + - sampling_frequency: the sampling frequency + """ + from scipy.signal import find_peaks + + assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" + assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" + peak_relative_threshold = kwargs["peak_relative_threshold"] + peak_width_ms = kwargs["peak_width_ms"] + max_value = np.max(np.abs(template_single)) + peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + + neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) + + return len(neg_peaks[0]) + + +_1d_metric_name_to_func = { "peak_to_valley": get_peak_to_valley, "peak_trough_ratio": get_peak_trough_ratio, "half_width": get_half_width, "repolarization_slope": get_repolarization_slope, "recovery_slope": get_recovery_slope, + "num_positive_peaks": get_num_positive_peaks, + "num_negative_peaks": get_num_negative_peaks, } -# back-compatibility -def calculate_template_metrics(*args, **kwargs): - warnings.warn( - "The 'calculate_template_metrics' function is deprecated. " "Use 'compute_template_metrics' instead", - DeprecationWarning, - stacklevel=2, - ) - return compute_template_metrics(*args, **kwargs) +######################################################################################### +# 2D metrics + + +def fit_velocity(peak_times, channel_dist): + # from scipy.stats import linregress + # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) + + from sklearn.linear_model import TheilSenRegressor + + theil = TheilSenRegressor() + theil.fit(peak_times.reshape(-1, 1), channel_dist) + slope = theil.coef_[0] + intercept = theil.intercept_ + score = theil.score(peak_times.reshape(-1, 1), channel_dist) + return slope, intercept, score + + +def get_velocity_above(template, channel_locations, **kwargs): + """ + Compute the velocity above the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_for_velocity: the minimum r2 to accept the velocity fit + - sampling_frequency: the sampling frequency + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + + depth_direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_for_velocity = kwargs["min_r2_for_velocity"] + + direction_index = ["x", "y", "z"].index(depth_direction) + sampling_frequency = kwargs["sampling_frequency"] + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_channel_location = channel_locations[max_channel_idx] + + channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] + + # we only consider samples forward in time with respect to the max channel + template_above = template[max_sample_idx:, channels_above] + channel_locations_above = channel_locations[channels_above] + + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 + distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) + velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) + + # if DEBUG: + # fig, ax = plt.subplots() + # ax.plot(peak_times_ms_above, distances_um_above, "o") + # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + # ax.plot(x, intercept + x * velocity_above) + # ax.set_xlabel("Peak time (ms)") + # ax.set_ylabel("Distance from max channel (um)") + # ax.set_title(f"Velocity above: {velocity_above:.2f} um/ms") + + if np.sum(channels_above) < min_channels_for_velocity: + # if DEBUG: + # ax.set_title("NaN velocity - not enough channels") + return np.nan + + if score < min_r2_for_velocity: + # if DEBUG: + # ax.set_title(f"NaN velocity - R2 is too low: {score:.2f}") + return np.nan + return velocity_above + + +def get_velocity_below(template, channel_locations, **kwargs): + """ + Compute the velocity below the max channel of the template. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - min_channels_for_velocity: the minimum number of channels above or below to compute velocity + - min_r2_for_velocity: the minimum r2 to accept the velocity fit + - sampling_frequency: the sampling frequency + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" + assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + direction = kwargs["depth_direction"] + min_channels_for_velocity = kwargs["min_channels_for_velocity"] + min_r2_for_velocity = kwargs["min_r2_for_velocity"] + direction_index = ["x", "y", "z"].index(direction) + + # find location of max channel + max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_channel_location = channel_locations[max_channel_idx] + sampling_frequency = kwargs["sampling_frequency"] + + channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] + + # we only consider samples forward in time with respect to the max channel + template_below = template[max_sample_idx:, channels_below] + channel_locations_below = channel_locations[channels_below] + + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 + distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) + velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) + + # if DEBUG: + # fig, ax = plt.subplots() + # ax.plot(peak_times_ms_below, distances_um_below, "o") + # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + # ax.plot(x, intercept + x * velocity_below) + # ax.set_xlabel("Peak time (ms)") + # ax.set_ylabel("Distance from max channel (um)") + # ax.set_title(f"Velocity below: {np.round(velocity_below, 3)} um/ms") + + if np.sum(channels_below) < min_channels_for_velocity: + # if DEBUG: + # ax.set_title("NaN velocity - not enough channels") + return np.nan + + if score < min_r2_for_velocity: + # if DEBUG: + # ax.set_title(f"NaN velocity - R2 is too low: {np.round(score, 3)}") + return np.nan + + return velocity_below + + +def get_exp_decay(template, channel_locations, **kwargs): + """ + Compute the exponential decay of the template amplitude over distance. + + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + """ + from scipy.optimize import curve_fit + + def exp_decay(x, a, b, c): + return a * np.exp(-b * x) + c + + assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" + exp_peak_function = kwargs["exp_peak_function"] + # exp decay fit + if exp_peak_function == "ptp": + fun = np.ptp + elif exp_peak_function == "min": + fun = np.min + peak_amplitudes = np.abs(fun(template, axis=0)) + max_channel_location = channel_locations[np.argmax(peak_amplitudes)] + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + channel_distances_sorted = channel_distances[distances_sort_indices] + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices] + try: + popt, _ = curve_fit(exp_decay, channel_distances_sorted, peak_amplitudes_sorted) + exp_decay_value = popt[1] + # if DEBUG: + # fig, ax = plt.subplots() + # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + # x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + # ax.plot(x, exp_decay(x, *popt)) + # ax.set_xlabel("Distance from max channel (um)") + # ax.set_ylabel("Peak amplitude") + # ax.set_title(f"Exp decay: {np.round(exp_decay_value, 3)}") + except: + exp_decay_value = np.nan + return exp_decay_value + + +def get_spread(template, channel_locations, **kwargs): + """ + Compute the spread of the template amplitude over distance. + Parameters + ---------- + template: numpy.ndarray + The template waveform (num_samples, num_channels) + channel_locations: numpy.ndarray + The channel locations (num_channels, 2) + **kwargs: Required kwargs: + - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - spread_threshold: the threshold to compute the spread + """ + assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + depth_direction = kwargs["depth_direction"] + assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" + spread_threshold = kwargs["spread_threshold"] + + direction_index = ["x", "y", "z"].index(depth_direction) + MM = np.ptp(template, 0) + MM = MM / np.max(MM) + channel_locations_above_theshold = channel_locations[MM > spread_threshold] + channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] + spread = np.ptp(channel_depth_above_theshold) + + # if DEBUG: + # fig, ax = plt.subplots() + # channel_depths = channel_locations[:, direction_index] + # sort_indices = np.argsort(channel_depths) + # ax.plot(channel_depths[sort_indices], MM[sort_indices], "o-") + # ax.axhline(spread_threshold, ls="--", color="r") + # ax.set_xlabel("Depth (um)") + # ax.set_ylabel("Amplitude") + # ax.set_title(f"Spread: {np.round(spread, 3)} um") + return spread + + +_2d_metric_name_to_func = { + "velocity_above": get_velocity_above, + "velocity_below": get_velocity_below, + "exp_decay": get_exp_decay, + "spread": get_spread, +} -calculate_template_metrics.__doc__ = compute_template_metrics.__doc__ +_metric_name_to_func = {**_1d_metric_name_to_func, **_2d_metric_name_to_func} diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 9895e2ec4c..5dcff3ffba 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -17,9 +17,13 @@ def test_sparse_metrics(self): tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) print(tm_sparse) + def test_2d_metrics(self): + tm_2d = self.extension_class.get_extension_function()(self.we1, include_2d_metrics=True) + print(tm_2d) + if __name__ == "__main__": test = TemplateMetricsExtensionTest() test.setUp() - test.test_extension() - test.test_sparse_metrics() + # test.test_extension() + test.test_2d_metrics() From 226ad852e25596c0f6072f48a72e2e3d4a84afab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 22 Sep 2023 12:32:33 +0200 Subject: [PATCH 02/35] Update tests --- .../postprocessing/tests/test_template_metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5dcff3ffba..a27ccc77f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -17,13 +17,13 @@ def test_sparse_metrics(self): tm_sparse = self.extension_class.get_extension_function()(self.we1, sparsity=self.sparsity1) print(tm_sparse) - def test_2d_metrics(self): - tm_2d = self.extension_class.get_extension_function()(self.we1, include_2d_metrics=True) - print(tm_2d) + def test_multi_channel_metrics(self): + tm_multi = self.extension_class.get_extension_function()(self.we1, include_multi_channel_metrics=True) + print(tm_multi) if __name__ == "__main__": test = TemplateMetricsExtensionTest() test.setUp() # test.test_extension() - test.test_2d_metrics() + test.test_multi_channel_metrics() From 00f91eb99de0052daf6ae67a47026e1490bcd278 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 25 Sep 2023 12:02:51 +0200 Subject: [PATCH 03/35] Do not save/overwrite params in read-only mode --- src/spikeinterface/core/waveform_extractor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6881ab3ec5..9f85603e51 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1988,6 +1988,9 @@ def set_params(self, **params): params = self._set_params(**params) self._params = params + if self.waveform_extractor.is_read_only(): + return + params_to_save = params.copy() if "sparsity" in params and params["sparsity"] is not None: assert isinstance( From 7ba84ad7d9913b4846d9d6903a13a1f441156647 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 28 Sep 2023 12:25:29 +0200 Subject: [PATCH 04/35] updates --- src/spikeinterface/core/waveform_extractor.py | 24 +- .../postprocessing/template_metrics.py | 343 +++++++++++------- 2 files changed, 239 insertions(+), 128 deletions(-) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 9f85603e51..79456a40ce 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -811,14 +811,30 @@ def select_units(self, unit_ids, new_folder=None, use_relative_path: bool = Fals sparsity = ChannelSparsity(mask, unit_ids, self.channel_ids) else: sparsity = None - we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) - we.set_params(**self._params) + if self.has_recording(): + we = WaveformExtractor.create(self.recording, sorting, folder=None, mode="memory", sparsity=sparsity) + else: + we = WaveformExtractor( + recording=None, + sorting=sorting, + folder=None, + sparsity=sparsity, + rec_attributes=self._rec_attributes, + allow_unfiltered=True, + ) + we._params = self._params # copy memory objects if self.has_waveforms(): we._memory_objects = {"wfs_arrays": {}, "sampled_indices": {}} for unit_id in unit_ids: - we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] - we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][unit_id] + if self.format == "memory": + we._memory_objects["wfs_arrays"][unit_id] = self._memory_objects["wfs_arrays"][unit_id] + we._memory_objects["sampled_indices"][unit_id] = self._memory_objects["sampled_indices"][ + unit_id + ] + else: + we._memory_objects["wfs_arrays"][unit_id] = self.get_waveforms(unit_id) + we._memory_objects["sampled_indices"][unit_id] = self.get_sampled_indices(unit_id) # finally select extensions data for ext_name in self.get_available_extension_names(): diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ea44dea9cb..090dae4567 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -11,12 +11,9 @@ from ..core.waveform_extractor import BaseWaveformExtractorExtension import warnings -# DEBUG = True -# if DEBUG: -# import matplotlib.pyplot as plt -# plt.ion() -# plt.show() +global DEBUG +DEBUG = False def get_single_channel_template_metric_names(): @@ -52,20 +49,20 @@ def _set_params( peak_sign="neg", upsampling_factor=10, sparsity=None, - functions_kwargs=None, + metrics_kwargs=None, include_multi_channel_metrics=False, ): if metric_names is None: metric_names = get_single_channel_template_metric_names() if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - functions_kwargs = functions_kwargs or dict() + metrics_kwargs = metrics_kwargs or dict() params = dict( metric_names=[str(name) for name in metric_names], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - functions_kwargs=functions_kwargs, + metrics_kwargs=metrics_kwargs, ) return params @@ -141,7 +138,7 @@ def _run(self): sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self._params["functions_kwargs"], + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value @@ -173,7 +170,7 @@ def _run(self): template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self._params["functions_kwargs"], + **self._params["metrics_kwargs"], ) template_metrics.at[index, metric_name] = value self._extension_data["metrics"] = template_metrics @@ -199,6 +196,21 @@ def get_extension_function(): WaveformExtractor.register_extension(TemplateMetricsCalculator) +_default_function_kwargs = dict( + recovery_window_ms=0.7, + peak_relative_threshold=0.2, + peak_width_ms=0.1, + depth_direction="y", + min_channels_for_velocity=5, + min_r2_velocity=0.5, + exp_peak_function="ptp", + min_r2_exp_decay=0.5, + spread_threshold=0.2, + spread_smooth_um=20, + same_x=False, +) + + def compute_template_metrics( waveform_extractor, load_if_exists=False, @@ -207,16 +219,8 @@ def compute_template_metrics( upsampling_factor=10, sparsity=None, include_multi_channel_metrics=False, - functions_kwargs=dict( - recovery_window_ms=0.7, - peak_relative_threshold=0.2, - peak_width_ms=0.2, - depth_direction="y", - min_channels_for_velocity=5, - min_r2_for_velocity=0.5, - exp_peak_function="ptp", - spread_threshold=0.2, - ), + metrics_kwargs=None, + debug_plots=False, ): """ Compute template metrics including: @@ -252,14 +256,14 @@ def compute_template_metrics( For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. include_multi_channel_metrics: bool, default: False Whether to compute multi-channel metrics - functions_kwargs: dict + metrics_kwargs: dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 * peak_width_ms: the width in samples to detect peaks, default: 0.2 * depth_direction: the direction to compute velocity above and below, default: "y" * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_for_velocity: the minimum r2 to accept the velocity fit, default: 0.7 + * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" * spread_threshold: the threshold to compute the spread, default: 0.2 @@ -275,6 +279,9 @@ def compute_template_metrics( If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. """ + if debug_plots: + global DEBUG + DEBUG = True if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: @@ -287,13 +294,19 @@ def compute_template_metrics( "If multi-channel metrics are computed, sparsity must be None, " "so that each unit will correspond to 1 row of the output dataframe." ) + default_kwargs = _default_function_kwargs.copy() + if metrics_kwargs is None: + metrics_kwargs = default_kwargs + else: + default_kwargs.update(metrics_kwargs) + metrics_kwargs = default_kwargs tmc.set_params( metric_names=metric_names, peak_sign=peak_sign, upsampling_factor=upsampling_factor, sparsity=sparsity, include_multi_channel_metrics=include_multi_channel_metrics, - functions_kwargs=functions_kwargs, + metrics_kwargs=metrics_kwargs, ) tmc.run() @@ -328,7 +341,7 @@ def get_trough_and_peak_idx(template): ######################################################################################### # Single-channel metrics -def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to valley duration in seconds of input waveforms. @@ -340,22 +353,19 @@ def get_peak_to_valley(template_single, trough_idx=None, peak_idx=None, **kwargs The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- ptv: float The peak to valley duration in seconds """ - sampling_frequency = kwargs["sampling_frequency"] if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) ptv = (peak_idx - trough_idx) / sampling_frequency return ptv -def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs): """ Return the peak to trough ratio of input waveforms. @@ -367,8 +377,6 @@ def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwa The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- @@ -381,7 +389,7 @@ def get_peak_trough_ratio(template_single, trough_idx=None, peak_idx=None, **kwa return ptratio -def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs): """ Return the half width of input waveforms in seconds. @@ -393,8 +401,6 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): The index of the trough peak_idx: int, default: None The index of the peak - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency Returns ------- @@ -403,7 +409,6 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): """ if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] if peak_idx == 0: return np.nan @@ -428,7 +433,7 @@ def get_half_width(template_single, trough_idx=None, peak_idx=None, **kwargs): return hw -def get_repolarization_slope(template_single, trough_idx=None, **kwargs): +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): """ Return slope of repolarization period between trough and baseline @@ -445,12 +450,9 @@ def get_repolarization_slope(template_single, trough_idx=None, **kwargs): The 1D template waveform trough_idx: int, default: None The index of the trough - **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency """ if trough_idx is None: trough_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] times = np.arange(template_single.shape[0]) / sampling_frequency @@ -472,7 +474,7 @@ def get_repolarization_slope(template_single, trough_idx=None, **kwargs): return res.slope -def get_recovery_slope(template_single, peak_idx=None, **kwargs): +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): """ Return the recovery slope of input waveforms. After repolarization, the neuron hyperpolarizes untill it peaks. The recovery slope is the @@ -490,7 +492,6 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): peak_idx: int, default: None The index of the peak **kwargs: Required kwargs: - - sampling_frequency: the sampling frequency - recovery_window_ms: the window in ms after the peak to compute the recovery_slope """ import scipy.stats @@ -499,7 +500,6 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): recovery_window_ms = kwargs["recovery_window_ms"] if peak_idx is None: _, peak_idx = get_trough_and_peak_idx(template_single) - sampling_frequency = kwargs["sampling_frequency"] times = np.arange(template_single.shape[0]) / sampling_frequency @@ -512,7 +512,7 @@ def get_recovery_slope(template_single, peak_idx=None, **kwargs): return res.slope -def get_num_positive_peaks(template_single, **kwargs): +def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): """ Count the number of positive peaks in the template. @@ -523,7 +523,6 @@ def get_num_positive_peaks(template_single, **kwargs): **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks - - sampling_frequency: the sampling frequency """ from scipy.signal import find_peaks @@ -532,14 +531,14 @@ def get_num_positive_peaks(template_single, **kwargs): peak_relative_threshold = kwargs["peak_relative_threshold"] peak_width_ms = kwargs["peak_width_ms"] max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) return len(pos_peaks[0]) -def get_num_negative_peaks(template_single, **kwargs): +def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): """ Count the number of negative peaks in the template. @@ -550,7 +549,6 @@ def get_num_negative_peaks(template_single, **kwargs): **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks - - sampling_frequency: the sampling frequency """ from scipy.signal import find_peaks @@ -559,7 +557,7 @@ def get_num_negative_peaks(template_single, **kwargs): peak_relative_threshold = kwargs["peak_relative_threshold"] peak_width_ms = kwargs["peak_width_ms"] max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * kwargs["sampling_frequency"]) + peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) @@ -581,6 +579,20 @@ def get_num_negative_peaks(template_single, **kwargs): # Multi-channel metrics +def transform_same_x(template, channel_locations): + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + same_x_mask = channel_locations[:, 0] == max_channel_x + channel_locations_same_x = channel_locations[same_x_mask] + template_same_x = template[:, same_x_mask] + return template_same_x, channel_locations_same_x + + +def sort_template_and_locations(template, channel_locations, depth_direction="y"): + direction_index = ["x", "y", "z"].index(depth_direction) + sort_indices = np.argsort(channel_locations[:, direction_index]) + return template[:, sort_indices], channel_locations[sort_indices, :] + + def fit_velocity(peak_times, channel_dist): # from scipy.stats import linregress # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) @@ -595,7 +607,7 @@ def fit_velocity(peak_times, channel_dist): return slope, intercept, score -def get_velocity_above(template, channel_locations, **kwargs): +def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs): """ Compute the velocity above the max channel of the template. @@ -608,56 +620,70 @@ def get_velocity_above(template, channel_locations, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_for_velocity: the minimum r2 to accept the velocity fit - - sampling_frequency: the sampling frequency + - min_r2_velocity: the minimum r2 to accept the velocity fit """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "same_x" in kwargs, "same_x must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_for_velocity = kwargs["min_r2_for_velocity"] + min_r2_velocity = kwargs["min_r2_velocity"] + same_x = kwargs["same_x"] direction_index = ["x", "y", "z"].index(depth_direction) - sampling_frequency = kwargs["sampling_frequency"] + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] # we only consider samples forward in time with respect to the max channel - template_above = template[max_sample_idx:, channels_above] + # TODO: not sure + # template_above = template[max_sample_idx:, channels_above] + template_above = template[:, channels_above] channel_locations_above = channel_locations[channels_above] - peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 + peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(peak_times_ms_above, distances_um_above, "o") - # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) - # ax.plot(x, intercept + x * velocity_above) - # ax.set_xlabel("Peak time (ms)") - # ax.set_ylabel("Distance from max channel (um)") - # ax.set_title(f"Velocity above: {velocity_above:.2f} um/ms") - - if np.sum(channels_above) < min_channels_for_velocity: - # if DEBUG: - # ax.set_title("NaN velocity - not enough channels") - return np.nan + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_above,) = np.nonzero(channels_above) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_above else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_above, distances_um_above, "o") + x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + axs[1].plot(x, intercept + x * velocity_above) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" + ) + plt.show() + + if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity: + velocity_above = np.nan - if score < min_r2_for_velocity: - # if DEBUG: - # ax.set_title(f"NaN velocity - R2 is too low: {score:.2f}") - return np.nan return velocity_above -def get_velocity_below(template, channel_locations, **kwargs): +def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs): """ Compute the velocity below the max channel of the template. @@ -670,55 +696,70 @@ def get_velocity_below(template, channel_locations, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - - min_r2_for_velocity: the minimum r2 to accept the velocity fit - - sampling_frequency: the sampling frequency + - min_r2_velocity: the minimum r2 to accept the velocity fit + - same_x: whether to transform the template and channel locations to have the same x coordinate """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" - assert "min_r2_for_velocity" in kwargs, "min_r2_for_velocity must be given as kwarg" - direction = kwargs["depth_direction"] + assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" + assert "same_x" in kwargs, "same_x must be given as kwarg" + + depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] - min_r2_for_velocity = kwargs["min_r2_for_velocity"] - direction_index = ["x", "y", "z"].index(direction) + min_r2_velocity = kwargs["min_r2_velocity"] + same_x = kwargs["same_x"] + + direction_index = ["x", "y", "z"].index(depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) + max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - sampling_frequency = kwargs["sampling_frequency"] channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] # we only consider samples forward in time with respect to the max channel - template_below = template[max_sample_idx:, channels_below] + # template_below = template[max_sample_idx:, channels_below] + template_below = template[:, channels_below] channel_locations_below = channel_locations[channels_below] - peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 + peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(peak_times_ms_below, distances_um_below, "o") - # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) - # ax.plot(x, intercept + x * velocity_below) - # ax.set_xlabel("Peak time (ms)") - # ax.set_ylabel("Distance from max channel (um)") - # ax.set_title(f"Velocity below: {np.round(velocity_below, 3)} um/ms") - - if np.sum(channels_below) < min_channels_for_velocity: - # if DEBUG: - # ax.set_title("NaN velocity - not enough channels") - return np.nan + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + offset = 1.2 * np.max(np.ptp(template, axis=0)) + ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + (channel_indices_below,) = np.nonzero(channels_below) + for i, single_template in enumerate(template.T): + color = "r" if i in channel_indices_below else "k" + axs[0].plot(ts, single_template + i * offset, color=color) + axs[0].axvline(0, color="g", ls="--") + axs[1].plot(peak_times_ms_below, distances_um_below, "o") + x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + axs[1].plot(x, intercept + x * velocity_below) + axs[1].set_xlabel("Peak time (ms)") + axs[1].set_ylabel("Distance from max channel (um)") + fig.suptitle( + f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" + ) + plt.show() - if score < min_r2_for_velocity: - # if DEBUG: - # ax.set_title(f"NaN velocity - R2 is too low: {np.round(score, 3)}") - return np.nan + if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity: + velocity_below = np.nan return velocity_below -def get_exp_decay(template, channel_locations, **kwargs): +def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ Compute the exponential decay of the template amplitude over distance. @@ -730,14 +771,18 @@ def get_exp_decay(template, channel_locations, **kwargs): The channel locations (num_channels, 2) **kwargs: Required kwargs: - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2_exp_decay: the minimum r2 to accept the exp decay fit """ from scipy.optimize import curve_fit + from sklearn.metrics import r2_score - def exp_decay(x, a, b, c): - return a * np.exp(-b * x) + c + def exp_decay(x, decay, amp0, offset): + return amp0 * np.exp(-decay * x) + offset assert "exp_peak_function" in kwargs, "exp_peak_function must be given as kwarg" exp_peak_function = kwargs["exp_peak_function"] + assert "min_r2_exp_decay" in kwargs, "min_r2_exp_decay must be given as kwarg" + min_r2_exp_decay = kwargs["min_r2_exp_decay"] # exp decay fit if exp_peak_function == "ptp": fun = np.ptp @@ -747,25 +792,49 @@ def exp_decay(x, a, b, c): max_channel_location = channel_locations[np.argmax(peak_amplitudes)] channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) distances_sort_indices = np.argsort(channel_distances) - channel_distances_sorted = channel_distances[distances_sort_indices] - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices] + # np.float128 avoids overflow error + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.float128) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.float128) try: - popt, _ = curve_fit(exp_decay, channel_distances_sorted, peak_amplitudes_sorted) - exp_decay_value = popt[1] - # if DEBUG: - # fig, ax = plt.subplots() - # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") - # x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) - # ax.plot(x, exp_decay(x, *popt)) - # ax.set_xlabel("Distance from max channel (um)") - # ax.set_ylabel("Peak amplitude") - # ax.set_title(f"Exp decay: {np.round(exp_decay_value, 3)}") + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] + + if r2 < min_r2_exp_decay: + exp_decay_value = np.nan + + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + ax.plot(x, exp_decay(x, *popt)) + ax.set_xlabel("Distance from max channel (um)") + ax.set_ylabel("Peak amplitude") + ax.set_title( + f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " + f"R2: {np.round(r2, 4)}" + ) + fig.suptitle("Exp decay") + plt.show() except: exp_decay_value = np.nan + return exp_decay_value -def get_spread(template, channel_locations, **kwargs): +def get_spread(template, channel_locations, sampling_frequency, **kwargs): """ Compute the spread of the template amplitude over distance. @@ -783,23 +852,49 @@ def get_spread(template, channel_locations, **kwargs): depth_direction = kwargs["depth_direction"] assert "spread_threshold" in kwargs, "spread_threshold must be given as kwarg" spread_threshold = kwargs["spread_threshold"] + assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" + spread_smooth_um = kwargs["spread_smooth_um"] + assert "same_x" in kwargs, "same_x must be given as kwarg" + same_x = kwargs["same_x"] direction_index = ["x", "y", "z"].index(depth_direction) + template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + + if same_x: + template, channel_locations = transform_same_x(template, channel_locations) MM = np.ptp(template, 0) MM = MM / np.max(MM) + channel_depths = channel_locations[:, direction_index] + + if spread_smooth_um is not None and spread_smooth_um > 0: + from scipy.ndimage import gaussian_filter1d + + spread_sigma = spread_smooth_um / np.median(np.diff(np.unique(channel_depths))) + MM = gaussian_filter1d(MM, spread_sigma) + channel_locations_above_theshold = channel_locations[MM > spread_threshold] channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] spread = np.ptp(channel_depth_above_theshold) - # if DEBUG: - # fig, ax = plt.subplots() - # channel_depths = channel_locations[:, direction_index] - # sort_indices = np.argsort(channel_depths) - # ax.plot(channel_depths[sort_indices], MM[sort_indices], "o-") - # ax.axhline(spread_threshold, ls="--", color="r") - # ax.set_xlabel("Depth (um)") - # ax.set_ylabel("Amplitude") - # ax.set_title(f"Spread: {np.round(spread, 3)} um") + global DEBUG + if DEBUG: + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + axs[0].imshow( + template.T, + aspect="auto", + origin="lower", + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]], + ) + axs[1].plot(channel_depths, MM, "o-") + axs[1].axhline(spread_threshold, ls="--", color="r") + axs[1].set_xlabel("Depth (um)") + axs[1].set_ylabel("Amplitude") + axs[1].set_title(f"Spread: {np.round(spread, 3)} um") + fig.suptitle("Spread") + plt.show() + return spread From c1cd889beacca66f43262f95e18033100f98d59d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 29 Sep 2023 13:19:35 +0200 Subject: [PATCH 05/35] Add 'column_range' and simplify dimension handling --- .../postprocessing/template_metrics.py | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 090dae4567..774ebab4a9 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -207,7 +207,7 @@ def get_extension_function(): min_r2_exp_decay=0.5, spread_threshold=0.2, spread_smooth_um=20, - same_x=False, + column_range=None, ) @@ -265,7 +265,13 @@ def compute_template_metrics( * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" + * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 * spread_threshold: the threshold to compute the spread, default: 0.2 + * spread_smooth_um: the smoothing in um to compute the spread, default: 20 + * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None + - If None, all channels all channels are considered + - If 0 or 1, only the "column" that includes the max channel is considered + - If > 1, only channels within range (+/-) um from the max channel horizontal position are used Returns ------- @@ -278,6 +284,7 @@ def compute_template_metrics( ----- If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. + For multi-channel metrocs, 3D channel locations are not supported. By default, the depth direction is "y". """ if debug_plots: global DEBUG @@ -294,6 +301,9 @@ def compute_template_metrics( "If multi-channel metrics are computed, sparsity must be None, " "so that each unit will correspond to 1 row of the output dataframe." ) + assert ( + waveform_extractor.get_channel_locations().shape[1] == 2 + ), "If multi-channel metrics are computed, channel locations must be 2D." default_kwargs = _default_function_kwargs.copy() if metrics_kwargs is None: metrics_kwargs = default_kwargs @@ -579,17 +589,22 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): # Multi-channel metrics -def transform_same_x(template, channel_locations): - max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] - same_x_mask = channel_locations[:, 0] == max_channel_x - channel_locations_same_x = channel_locations[same_x_mask] - template_same_x = template[:, same_x_mask] - return template_same_x, channel_locations_same_x +def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + column_dim = 0 if depth_direction == "y" else 1 + if column_range is None: + template_column_range = template + channel_locations_column_range = channel_locations + else: + max_channel_x = channel_locations[np.argmax(np.ptp(template, axis=0)), 0] + column_mask = np.abs(channel_locations[:, column_dim] - max_channel_x) <= column_range + template_column_range = template[:, column_mask] + channel_locations_column_range = channel_locations[column_mask] + return template_column_range, channel_locations_column_range def sort_template_and_locations(template, channel_locations, depth_direction="y"): - direction_index = ["x", "y", "z"].index(depth_direction) - sort_indices = np.argsort(channel_locations[:, direction_index]) + depth_dim = 1 if depth_direction == "y" else 0 + sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] @@ -621,29 +636,28 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - min_r2_velocity: the minimum r2 to accept the velocity fit + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "same_x" in kwargs, "same_x must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] min_r2_velocity = kwargs["min_r2_velocity"] - same_x = kwargs["same_x"] + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) - # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - channels_above = channel_locations[:, direction_index] >= max_channel_location[direction_index] + channels_above = channel_locations[:, depth_dim] >= max_channel_location[depth_dim] # we only consider samples forward in time with respect to the max channel # TODO: not sure @@ -697,30 +711,28 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity - min_r2_velocity: the minimum r2 to accept the velocity fit - - same_x: whether to transform the template and channel locations to have the same x coordinate + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" assert "min_channels_for_velocity" in kwargs, "min_channels_for_velocity must be given as kwarg" assert "min_r2_velocity" in kwargs, "min_r2_velocity must be given as kwarg" - assert "same_x" in kwargs, "same_x must be given as kwarg" + assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] min_channels_for_velocity = kwargs["min_channels_for_velocity"] min_r2_velocity = kwargs["min_r2_velocity"] - same_x = kwargs["same_x"] + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) - # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 max_channel_location = channel_locations[max_channel_idx] - channels_below = channel_locations[:, direction_index] <= max_channel_location[direction_index] + channels_below = channel_locations[:, depth_dim] <= max_channel_location[depth_dim] # we only consider samples forward in time with respect to the max channel # template_below = template[max_sample_idx:, channels_below] @@ -847,6 +859,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - spread_threshold: the threshold to compute the spread + - column_range: the range in um in the x-direction to consider channels for velocity """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" depth_direction = kwargs["depth_direction"] @@ -854,17 +867,16 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): spread_threshold = kwargs["spread_threshold"] assert "spread_smooth_um" in kwargs, "spread_smooth_um must be given as kwarg" spread_smooth_um = kwargs["spread_smooth_um"] - assert "same_x" in kwargs, "same_x must be given as kwarg" - same_x = kwargs["same_x"] + assert "column_range" in kwargs, "column_range must be given as kwarg" + column_range = kwargs["column_range"] - direction_index = ["x", "y", "z"].index(depth_direction) + depth_dim = 1 if depth_direction == "y" else 0 + template, channel_locations = transform_column_range(template, channel_locations, column_range) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) - if same_x: - template, channel_locations = transform_same_x(template, channel_locations) MM = np.ptp(template, 0) MM = MM / np.max(MM) - channel_depths = channel_locations[:, direction_index] + channel_depths = channel_locations[:, depth_dim] if spread_smooth_um is not None and spread_smooth_um > 0: from scipy.ndimage import gaussian_filter1d @@ -873,7 +885,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): MM = gaussian_filter1d(MM, spread_sigma) channel_locations_above_theshold = channel_locations[MM > spread_threshold] - channel_depth_above_theshold = channel_locations_above_theshold[:, direction_index] + channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim] spread = np.ptp(channel_depth_above_theshold) global DEBUG @@ -885,7 +897,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): template.T, aspect="auto", origin="lower", - extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[1]], + extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], ) axs[1].plot(channel_depths, MM, "o-") axs[1].axhline(spread_threshold, ls="--", color="r") From ac84b25530b04e30c80eba7c474be61279a7dd1f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 1 Oct 2023 15:11:30 +0200 Subject: [PATCH 06/35] Fix docstrings --- .../postprocessing/template_metrics.py | 67 ++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 774ebab4a9..82f55483b4 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -4,12 +4,13 @@ 22/04/2020 """ import numpy as np +import warnings +from typing import Optional from copy import deepcopy -from ..core import WaveformExtractor +from ..core import WaveformExtractor, ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import BaseWaveformExtractorExtension -import warnings global DEBUG @@ -211,16 +212,17 @@ def get_extension_function(): ) +# TODO: add typing def compute_template_metrics( waveform_extractor, - load_if_exists=False, - metric_names=None, - peak_sign="neg", - upsampling_factor=10, - sparsity=None, - include_multi_channel_metrics=False, - metrics_kwargs=None, - debug_plots=False, + load_if_exists: bool = False, + metric_names: Optional[list[str]] = None, + peak_sign: Optional[str] = "neg", + upsampling_factor: int = 10, + sparsity: Optional[ChannelSparsity] = None, + include_multi_channel_metrics: bool = False, + metrics_kwargs: dict = None, + debug_plots: bool = False, ): """ Compute template metrics including: @@ -247,13 +249,13 @@ def compute_template_metrics( metric_names : list, optional List of metrics to compute (see si.postprocessing.get_template_metric_names()), by default None peak_sign : {"neg", "pos"}, default: "neg" - The peak sign + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. upsampling_factor : int, default: 10 The upsampling factor to upsample the templates - sparsity: dict or None, default: None - Default is sparsity=None and template metric is computed on extremum channel only. - If given, the dictionary should contain a unit ids as keys and a channel id or a list of channel ids as values. - For more generating a sparsity dict, see the postprocessing.compute_sparsity() function. + sparsity: ChannelSparsity or None, default: None + If None, template metrics are computed on the extremum channel only. + If sparsity is given, template metrics are computed on all sparse channels of each unit. + For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function. include_multi_channel_metrics: bool, default: False Whether to compute multi-channel metrics metrics_kwargs: dict @@ -261,7 +263,7 @@ def compute_template_metrics( * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" + * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" @@ -284,7 +286,7 @@ def compute_template_metrics( ----- If any multi-channel metric is in the metric_names or include_multi_channel_metrics is True, sparsity must be None, so that one metric value will be computed per unit. - For multi-channel metrocs, 3D channel locations are not supported. By default, the depth direction is "y". + For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". """ if debug_plots: global DEBUG @@ -359,6 +361,8 @@ def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, pea ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough peak_idx: int, default: None @@ -383,6 +387,8 @@ def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=N ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough peak_idx: int, default: None @@ -407,6 +413,8 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough peak_idx: int, default: None @@ -458,6 +466,8 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template trough_idx: int, default: None The index of the trough """ @@ -499,6 +509,8 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template peak_idx: int, default: None The index of the peak **kwargs: Required kwargs: @@ -530,6 +542,8 @@ def get_num_positive_peaks(template_single, sampling_frequency, **kwargs): ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks @@ -556,6 +570,8 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): ---------- template_single: numpy.ndarray The 1D template waveform + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - peak_relative_threshold: the relative threshold to detect positive and negative peaks - peak_width_ms: the width in samples to detect peaks @@ -590,6 +606,9 @@ def get_num_negative_peaks(template_single, sampling_frequency, **kwargs): def transform_column_range(template, channel_locations, column_range, depth_direction="y"): + """ + Transform template anch channel locations based on column range. + """ column_dim = 0 if depth_direction == "y" else 1 if column_range is None: template_column_range = template @@ -603,12 +622,18 @@ def transform_column_range(template, channel_locations, column_range, depth_dire def sort_template_and_locations(template, channel_locations, depth_direction="y"): + """ + Sort template and locations. + """ depth_dim = 1 if depth_direction == "y" else 0 sort_indices = np.argsort(channel_locations[:, depth_dim]) return template[:, sort_indices], channel_locations[sort_indices, :] def fit_velocity(peak_times, channel_dist): + """ + Fit velocity from peak times and channel distances using ribust Theilsen estimator. + """ # from scipy.stats import linregress # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) @@ -632,6 +657,8 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity @@ -707,6 +734,8 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels_for_velocity: the minimum number of channels above or below to compute velocity @@ -781,6 +810,8 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - exp_peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") - min_r2_exp_decay: the minimum r2 to accept the exp decay fit @@ -856,6 +887,8 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): The template waveform (num_samples, num_channels) channel_locations: numpy.ndarray The channel locations (num_channels, 2) + sampling_frequency : float + The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - spread_threshold: the threshold to compute the spread From 4e3140f58cec52b42563b02a5bfb2d0fdda498c3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 2 Oct 2023 10:19:09 +0200 Subject: [PATCH 07/35] Remove comment --- src/spikeinterface/postprocessing/template_metrics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 82f55483b4..3f47c505ad 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -41,7 +41,7 @@ class TemplateMetricsCalculator(BaseWaveformExtractorExtension): extension_name = "template_metrics" min_channels_for_multi_channel_warning = 10 - def __init__(self, waveform_extractor): + def __init__(self, waveform_extractor: WaveformExtractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) def _set_params( @@ -212,7 +212,6 @@ def get_extension_function(): ) -# TODO: add typing def compute_template_metrics( waveform_extractor, load_if_exists: bool = False, From d373c05673b04354749e9d4ed9fc207f00824de3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 12:49:21 +0200 Subject: [PATCH 08/35] wip --- .../sorters/internal/tridesclous2.py | 16 +++++++++++----- .../sortingcomponents/clustering/split.py | 2 ++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index e2f4812222..5a2664a45e 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -20,9 +20,11 @@ class Tridesclous2Sorter(ComponentsBasedSorter): _default_params = { "apply_preprocessing": True, "waveforms" : {"ms_before": 0.5, "ms_after": 1.5, }, - "filtering": {"freq_min": 300., "freq_max": 8000.0}, - "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.}, + "filtering": {"freq_min": 300., "freq_max": 12000.0}, + "detection": {"peak_sign": "neg", "detect_threshold": 5, + "exclude_sweep_ms": 1.5, "radius_um": 150.}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, + "features": {"radius_um": 120}, "svd": {"n_components": 6}, "clustering": { "split_radius_um": 40., @@ -35,7 +37,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter): }, "matching": { "peak_shift_ms": 0.2, - "radius_um": 100. + # "radius_um": 100. + "num_peeler_loop": 3, + "num_template_try": 3, + }, "job_kwargs": {"n_jobs":-1}, "save_array": True, @@ -143,9 +148,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # upsampling_um=5.0, # ) + radius = params["features"]["radius_um"] node3 = ExtractSparseWaveforms(recording, parents=[node0], return_output=True, - ms_before=0.5, - ms_after=1.5, + ms_before=ms_before, + ms_after=ms_after, radius_um=100.0, ) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 9836e9110f..b433a2d16d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -192,6 +192,7 @@ def split( # target channel subset is done intersect local channels + neighbours local_chans = np.unique(peaks["channel_index"][peak_indices]) + target_channels = np.flatnonzero(np.all(neighbours_mask[local_chans, :], axis=0)) # TODO fix this a better way, this when cluster have too few overlapping channels @@ -204,6 +205,7 @@ def split( local_labels[dont_have_channels] = -2 kept = np.flatnonzero(~dont_have_channels) + if kept.size < min_size_split: return False, None From f5a42e7c51d5983738191d10896cf9fb500847c7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 13:41:13 +0200 Subject: [PATCH 09/35] wip --- src/spikeinterface/sorters/internal/tridesclous2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index bfc01b897f..909a2d1cb3 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -29,11 +29,12 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "waveforms": { "ms_before": 0.5, "ms_after": 1.5, + "radius_um": 120.0, }, "filtering": {"freq_min": 300.0, "freq_max": 12000.0}, "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 1.5, "radius_um": 150.0}, "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000}, - "features": {"radius_um": 120}, + "features": {}, "svd": {"n_components": 6}, "clustering": { "split_radius_um": 40.0, @@ -155,13 +156,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # upsampling_um=5.0, # ) + radius_um = params["waveforms"]["radius_um"] node3 = ExtractSparseWaveforms( recording, parents=[node0], return_output=True, - ms_beforems_before, + ms_before=ms_before, ms_after=ms_after, - radius_um=radius, + radius_um=radius_um, ) model_folder_path = sorter_output_folder / "tsvd_model" From 446dcc7114f8d275cc55ce557ed11ebe0fb1160e Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 6 Oct 2023 13:42:07 +0200 Subject: [PATCH 10/35] Patch --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../sortingcomponents/benchmark/benchmark_clustering.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 0c3b9f95d1..3681a1fbc5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -114,7 +114,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We get the labels for our peaks mask = peak_labels > -1 - sorting = NumpySorting.from_times_labels(selected_peaks["sample_index"][mask], peak_labels[mask], sampling_rate) + sorting = NumpySorting.from_times_labels(selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate) clustering_folder = sorter_output_folder / "clustering" if clustering_folder.exists(): shutil.rmtree(clustering_folder) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index d68b8e5449..bd413417bf 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -524,7 +524,7 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) template_real = template_real.reshape(template_real.size, 1).T if metric == "cosine": - dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real, metric).flatten().tolist() + dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist() else: dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist() res += dist From f56780db0ed2240a94e697dcdc4040e65e706918 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:42:58 +0000 Subject: [PATCH 11/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 3681a1fbc5..2c297662f4 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -114,7 +114,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We get the labels for our peaks mask = peak_labels > -1 - sorting = NumpySorting.from_times_labels(selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate) + sorting = NumpySorting.from_times_labels( + selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate + ) clustering_folder = sorter_output_folder / "clustering" if clustering_folder.exists(): shutil.rmtree(clustering_folder) From 4a778069329529afa7eccc01d7805642d9ff93e5 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 13:46:35 +0200 Subject: [PATCH 12/35] small fix in gtstudy --- src/spikeinterface/comparison/groundtruthstudy.py | 3 +++ src/spikeinterface/widgets/widget_list.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index df0b5296c0..a1814d3527 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -184,6 +184,9 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" if log_file.exists(): log_file.unlink() + + if sorter_folder_exists: + shutil.rmtree(sorter_folder) params = self.cases[key]["run_sorter_params"].copy() # this ensure that sorter_name is given diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ed77de6128..51e7208080 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -114,7 +114,7 @@ plot_study_run_times = StudyRunTimesWidget plot_study_unit_counts = StudyUnitCountsWidget plot_study_performances = StudyPerformances -plot_stufy_performances_vs_metrics = StudyPerformancesVsMetrics +plot_study_performances_vs_metrics = StudyPerformancesVsMetrics def plot_timeseries(*args, **kwargs): From 59e617fdfca4898b131c4cbf84b1a2e4eccd1eb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:47:21 +0000 Subject: [PATCH 13/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/groundtruthstudy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index a1814d3527..8d8b255336 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -184,9 +184,9 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" if log_file.exists(): log_file.unlink() - + if sorter_folder_exists: - shutil.rmtree(sorter_folder) + shutil.rmtree(sorter_folder) params = self.cases[key]["run_sorter_params"].copy() # this ensure that sorter_name is given From 8c64a9f74b035211b4f4623becbdfd010683d402 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 14:19:38 +0200 Subject: [PATCH 14/35] oups --- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index a1814d3527..0133f57e4d 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -165,7 +165,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True sorting_exists = sorting_folder.exists() sorter_folder = self.folder / "sorters" / self.key_to_str(key) - sorter_folder_exists = sorting_folder.exists() + sorter_folder_exists = sorter_folder.exists() if keep: if sorting_exists: From f465e815c7de66958c659880b21085edd38f4216 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 21:22:31 +0200 Subject: [PATCH 15/35] wip --- src/spikeinterface/sorters/internal/tridesclous2.py | 11 +++++++---- .../sortingcomponents/clustering/merge.py | 7 +++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 909a2d1cb3..ca1dfa1854 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -196,6 +196,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): original_labels = peaks["channel_index"] + min_cluster_size = 50 + post_split_label, split_count = split_clusters( original_labels, recording, @@ -208,8 +210,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # feature_name="sparse_wfs", neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - min_size_split=50, - min_cluster_size=50, + min_size_split=min_cluster_size, + min_cluster_size=min_cluster_size, min_samples=50, n_pca_features=3, ), @@ -240,9 +242,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # criteria="percentile", # threshold_percentile=80., criteria="distrib_overlap", - threshold_overlap=0.4, + threshold_overlap=0.3, + min_cluster_size=min_cluster_size+1, # num_shift=0 - num_shift=2, + num_shift=5, ), **job_kwargs, ) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index d892d0723a..45090452dc 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -398,6 +398,7 @@ def merge( threshold_diptest=0.5, threshold_percentile=80.0, threshold_overlap=0.4, + min_cluster_size=50, num_shift=2, ): if num_shift > 0: @@ -414,7 +415,7 @@ def merge( chans1 = np.unique(peaks["channel_index"][inds1]) target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) - if inds0.size < 40 or inds1.size < 40: + if inds0.size < min_cluster_size or inds1.size < min_cluster_size: is_merge = False merge_value = 0 final_shift = 0 @@ -525,7 +526,9 @@ def merge( # DEBUG = True DEBUG = False - if DEBUG and is_merge: + # if DEBUG and is_merge: + # if DEBUG and (overlap > 0.1 and overlap <0.3): + if DEBUG: # if DEBUG and not is_merge: # if DEBUG and (overlap > 0.05 and overlap <0.25): # if label0 == 49 and label1== 65: From 022c55d3c960dfa570a8737c54937b694fde5f2e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 6 Oct 2023 21:26:26 +0200 Subject: [PATCH 16/35] Add noise_level in kwargs NoiseGeneratorRecording and change some parameters in generate templates --- src/spikeinterface/core/generate.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 0c67404069..dc84d31987 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -654,6 +654,7 @@ def __init__( "num_channels": num_channels, "durations": durations, "sampling_frequency": sampling_frequency, + "noise_level": noise_level, "dtype": dtype, "seed": seed, "strategy": strategy, @@ -876,13 +877,13 @@ def generate_single_fake_waveform( default_unit_params_range = dict( - alpha=(5_000.0, 15_000.0), + alpha=(6_000.0, 9_000.0), depolarization_ms=(0.09, 0.14), repolarization_ms=(0.5, 0.8), recovery_ms=(1.0, 1.5), positive_amplitude=(0.05, 0.15), smooth_ms=(0.03, 0.07), - decay_power=(1.2, 1.8), + decay_power=(1.4, 1.8), ) From 1bbfe1622baf3250b362dae28c11b03c5dc712cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 20:24:58 +0000 Subject: [PATCH 17/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index ca1dfa1854..054596e9b3 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -243,7 +243,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # threshold_percentile=80., criteria="distrib_overlap", threshold_overlap=0.3, - min_cluster_size=min_cluster_size+1, + min_cluster_size=min_cluster_size + 1, # num_shift=0 num_shift=5, ), From f68da6a82b3f8efec5aa1d5f80ba170b5ad6d4e0 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 9 Oct 2023 07:55:42 +0200 Subject: [PATCH 18/35] Updates --- .../sorters/internal/spyking_circus2.py | 4 +++- .../clustering/random_projections.py | 21 ++++++++----------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 2c297662f4..780e6a14aa 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -67,6 +67,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # recording_f = whiten(recording_f, dtype="float32") recording_f = zscore(recording_f, dtype="float32") + noise_levels = np.ones(num_channels, dtype=np.float32) ## Then, we are detecting peaks with a locally_exclusive method detection_params = params["detection"].copy() @@ -87,7 +88,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - noise_levels = np.ones(num_channels, dtype=np.float32) + selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks( peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params @@ -107,6 +108,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params.update(dict(shared_memory=params["shared_memory"])) clustering_params["job_kwargs"] = job_kwargs clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params.update({"noise_levels": noise_levels}) labels, peak_labels = find_cluster_from_peaks( recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index a81458d7a8..a6d69f74aa 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -43,7 +43,8 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "smoothing_kwargs": {"window_length_ms": 1}, + "noise_levels" : None, + "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, "job_kwargs": {"n_jobs": os.cpu_count(), "chunk_memory": "100M", "verbose": True, "progress_bar": True}, @@ -72,7 +73,10 @@ def main_function(cls, recording, peaks, params): num_samples = nbefore + nafter num_chans = recording.get_num_channels() - noise_levels = get_noise_levels(recording, return_scaled=False) + if d["noise_levels"] is None: + noise_levels = get_noise_levels(recording, return_scaled=False) + else: + noise_levels = d["noise_levels"] np.random.seed(d["random_seed"]) @@ -82,7 +86,9 @@ def main_function(cls, recording, peaks, params): else: tmp_folder = Path(params["tmp_folder"]).absolute() - ### Then we extract the SVD features + + tmp_folder.mkdir(parents=True, exist_ok=True) + node0 = PeakRetriever(recording, peaks) node1 = ExtractDenseWaveforms( recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"] @@ -174,15 +180,6 @@ def sigmoid(x, L, x0, k, b): if verbose: print("We found %d raw clusters, starting to clean with matching..." % (len(labels))) - # create a tmp folder - if params["tmp_folder"] is None: - name = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - tmp_folder = get_global_tmp_folder() / name - else: - tmp_folder = Path(params["tmp_folder"]) - - tmp_folder.mkdir(parents=True, exist_ok=True) - sorting_folder = tmp_folder / "sorting" unit_ids = np.arange(len(np.unique(spikes["unit_index"]))) sorting = NumpySorting(spikes, fs, unit_ids=unit_ids) From ed44aaf68fc6c614373a130041b93d1ce0d9ffc8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 9 Oct 2023 10:04:37 +0200 Subject: [PATCH 19/35] Extracting sparse waveforms --- .../clustering/random_projections.py | 21 ++++++++++++++++--- .../sortingcomponents/features_from_peaks.py | 9 ++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index a6d69f74aa..b1dab9b27c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -20,7 +20,7 @@ from spikeinterface.core import extract_waveforms from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature -from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, PeakRetriever +from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever class RandomProjectionClustering: @@ -90,8 +90,9 @@ def main_function(cls, recording, peaks, params): tmp_folder.mkdir(parents=True, exist_ok=True) node0 = PeakRetriever(recording, peaks) - node1 = ExtractDenseWaveforms( - recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"] + node1 = ExtractSparseWaveforms( + recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"], + radius_um=params['radius_um'] ) node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) @@ -129,6 +130,8 @@ def sigmoid(x, L, x0, k, b): return_output=True, projections=projections, radius_um=params["radius_um"], + sigmoid=None, + sparse=True ) pipeline_nodes = [node0, node1, node2, node3] @@ -142,6 +145,18 @@ def sigmoid(x, L, x0, k, b): clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"]) peak_labels = clustering[0] + # peak_labels = -1 * np.ones(len(peaks), dtype=int) + # nb_clusters = 0 + # for c in np.unique(peaks['channel_index']): + # mask = peaks['channel_index'] == c + # clustering = hdbscan.hdbscan(hdbscan_data[mask], **d['hdbscan_kwargs']) + # local_labels = clustering[0] + # valid_clusters = local_labels > -1 + # if np.sum(valid_clusters) > 0: + # local_labels[valid_clusters] += nb_clusters + # peak_labels[mask] = local_labels + # nb_clusters += len(np.unique(local_labels[valid_clusters])) + labels = np.unique(peak_labels) labels = labels[labels >= 0] diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index b534c2356d..3ca53b05fb 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -186,6 +186,7 @@ def __init__( projections=None, sigmoid=None, radius_um=None, + sparse=True ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -195,7 +196,8 @@ def __init__( self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance < radius_um self.radius_um = radius_um - self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um)) + self.sparse = sparse + self._kwargs.update(dict(projections=projections, sigmoid=sigmoid, radius_um=radius_um, sparse=sparse)) self._dtype = recording.get_dtype() def get_dtype(self): @@ -213,7 +215,10 @@ def compute(self, traces, peaks, waveforms): (idx,) = np.nonzero(peaks["channel_index"] == main_chan) (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] - wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) + if self.sparse: + wf_ptp = np.ptp(waveforms[idx][:, :, :len(chan_inds)], axis=1) + else: + wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) if self.sigmoid is not None: wf_ptp *= self._sigmoid(wf_ptp) From fa82f108a1a15e4eeb347a9c86294a65960bbd6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 08:20:50 +0000 Subject: [PATCH 20/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/spyking_circus2.py | 1 - .../clustering/random_projections.py | 20 +++++++++++++------ .../sortingcomponents/features_from_peaks.py | 4 ++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 780e6a14aa..a16b642dd5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -88,7 +88,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): selection_params["n_peaks"] = params["selection"]["n_peaks_per_channel"] * num_channels selection_params["n_peaks"] = max(selection_params["min_n_peaks"], selection_params["n_peaks"]) - selection_params.update({"noise_levels": noise_levels}) selected_peaks = select_peaks( peaks, method="smart_sampling_amplitudes", select_per_channel=False, **selection_params diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index b1dab9b27c..72acd49f4f 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -20,7 +20,12 @@ from spikeinterface.core import extract_waveforms from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser from spikeinterface.sortingcomponents.features_from_peaks import RandomProjectionsFeature -from spikeinterface.core.node_pipeline import run_node_pipeline, ExtractDenseWaveforms, ExtractSparseWaveforms, PeakRetriever +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) class RandomProjectionClustering: @@ -43,7 +48,7 @@ class RandomProjectionClustering: "ms_before": 1, "ms_after": 1, "random_seed": 42, - "noise_levels" : None, + "noise_levels": None, "smoothing_kwargs": {"window_length_ms": 0.25}, "shared_memory": True, "tmp_folder": None, @@ -86,13 +91,16 @@ def main_function(cls, recording, peaks, params): else: tmp_folder = Path(params["tmp_folder"]).absolute() - tmp_folder.mkdir(parents=True, exist_ok=True) node0 = PeakRetriever(recording, peaks) node1 = ExtractSparseWaveforms( - recording, parents=[node0], return_output=False, ms_before=params["ms_before"], ms_after=params["ms_after"], - radius_um=params['radius_um'] + recording, + parents=[node0], + return_output=False, + ms_before=params["ms_before"], + ms_after=params["ms_after"], + radius_um=params["radius_um"], ) node2 = SavGolDenoiser(recording, parents=[node0, node1], return_output=False, **params["smoothing_kwargs"]) @@ -131,7 +139,7 @@ def sigmoid(x, L, x0, k, b): projections=projections, radius_um=params["radius_um"], sigmoid=None, - sparse=True + sparse=True, ) pipeline_nodes = [node0, node1, node2, node3] diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py index 3ca53b05fb..06d22181cb 100644 --- a/src/spikeinterface/sortingcomponents/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py @@ -186,7 +186,7 @@ def __init__( projections=None, sigmoid=None, radius_um=None, - sparse=True + sparse=True, ): PipelineNode.__init__(self, recording, return_output=return_output, parents=parents) @@ -216,7 +216,7 @@ def compute(self, traces, peaks, waveforms): (chan_inds,) = np.nonzero(self.neighbours_mask[main_chan]) local_projections = self.projections[chan_inds, :] if self.sparse: - wf_ptp = np.ptp(waveforms[idx][:, :, :len(chan_inds)], axis=1) + wf_ptp = np.ptp(waveforms[idx][:, :, : len(chan_inds)], axis=1) else: wf_ptp = np.ptp(waveforms[idx][:, :, chan_inds], axis=1) From 1d21b68619605151d1571402fa89d5c71bcc1c05 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 9 Oct 2023 11:16:20 +0200 Subject: [PATCH 21/35] wip tdc2 merge with template. --- .../sorters/internal/tridesclous2.py | 39 +++--- .../sortingcomponents/clustering/merge.py | 115 +++++++++++++++++- 2 files changed, 135 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 054596e9b3..11be2c3580 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -229,24 +229,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording, features_folder, radius_um=merge_radius_um, - method="project_distribution", + # method="project_distribution", + # method_kwargs=dict( + # waveforms_sparse_mask=sparse_mask, + # feature_name="sparse_wfs", + # projection="centroid", + # criteria="distrib_overlap", + # threshold_overlap=0.3, + # min_cluster_size=min_cluster_size + 1, + # num_shift=5, + # ), + method="normalized_template_diff", method_kwargs=dict( - # neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - # feature_name="sparse_tsvd", - feature_name="sparse_wfs", - # projection='lda', - projection="centroid", - # criteria='diptest', - # threshold_diptest=0.5, - # criteria="percentile", - # threshold_percentile=80., - criteria="distrib_overlap", - threshold_overlap=0.3, + threshold_diff=0.2, min_cluster_size=min_cluster_size + 1, - # num_shift=0 num_shift=5, - ), + ), **job_kwargs, ) @@ -255,10 +254,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): new_peaks = peaks.copy() new_peaks["sample_index"] -= peak_shifts + # clean very small cluster before peeler + minimum_cluster_size = 25 + labels_set, count = np.unique(post_merge_label, return_counts=True) + to_remove = labels_set[count < minimum_cluster_size] + print(to_remove) + mask = np.isin(post_merge_label, to_remove) + post_merge_label[mask] = -1 + + # final label sets labels_set = np.unique(post_merge_label) labels_set = labels_set[labels_set >= 0] - mask = post_merge_label >= 0 + + mask = post_merge_label >= 0 sorting_temp = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], post_merge_label[mask], diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 45090452dc..24cbedfb8c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -256,7 +256,7 @@ def find_merge_pairs( sparse_wfs, sparse_mask, radius_um=70, - method="waveforms_lda", + method="project_distribution", method_kwargs={}, **job_kwargs # n_jobs=1, @@ -308,7 +308,8 @@ def find_merge_pairs( max_workers=n_jobs, initializer=find_pair_worker_init, mp_context=get_context(mp_context), - initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + initargs=(recording, features_dict_or_folder, peak_labels, labels_set, templates, + method, method_kwargs, max_threads_per_process), ) as pool: jobs = [] for ind0, ind1 in zip(indices0, indices1): @@ -338,13 +339,16 @@ def find_merge_pairs( def find_pair_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process + recording, features_dict_or_folder, original_labels, + labels_set, templates, method, method_kwargs, max_threads_per_process ): global _ctx _ctx = {} _ctx["recording"] = recording _ctx["original_labels"] = original_labels + _ctx["labels_set"] = labels_set + _ctx["templates"] = templates _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] @@ -364,8 +368,10 @@ def find_pair_function_wrapper(label0, label1): global _ctx with threadpool_limits(limits=_ctx["max_threads_per_process"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( - label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + label0, label1, _ctx["labels_set"], _ctx["templates"], + _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] ) + return is_merge, label0, label1, shift, merge_value @@ -388,6 +394,8 @@ class ProjectDistribution: def merge( label0, label1, + labels_set, + templates, original_labels, peaks, features, @@ -578,7 +586,106 @@ def merge( return is_merge, label0, label1, final_shift, merge_value +class NormalizedTemplateDiff: + """ + Compute the normalized (some kind of) template differences. + And merge if below a threhold. + Do this at several shift. + + """ + + name = "normalized_template_diff" + + @staticmethod + def merge( + label0, + label1, + labels_set, + templates, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + threshold_diff=0.05, + min_cluster_size=50, + num_shift=5, + ): + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + # if inds0.size < min_cluster_size or inds1.size < min_cluster_size: + # is_merge = False + # merge_value = 0 + # final_shift = 0 + # return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + union_chans = np.union1d(target_chans0, target_chans1) + + ind0 = list(labels_set).index(label0) + template0 = templates[ind0, :, target_chans] + + ind1 = list(labels_set).index(label1) + template1 = templates[ind1, :, target_chans] + + + num_samples = template0.shape[0] + # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) + norm = np.mean(np.abs(template0) + np.abs(template1)) + all_shift_diff = [] + for shift in range(-num_shift, num_shift + 1): + temp0 = template0[num_shift : num_samples - num_shift, :] + temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + d = np.mean(np.abs(temp0 - temp1)) / (norm) + all_shift_diff.append(d) + normed_diff = np.min(all_shift_diff) + + is_merge = normed_diff < threshold_diff + if is_merge: + merge_value = normed_diff + final_shift = np.argmin(all_shift_diff) - num_shift + else: + final_shift = 0 + merge_value = np.nan + + + # DEBUG = False + DEBUG = True + if DEBUG and normed_diff < 0.2: + # if DEBUG: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + + m0 = template0.flatten() + m1 = template1.flatten() + + ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") + + ax.set_title(f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}") + ax.legend() + plt.show() + + + + + + return is_merge, label0, label1, final_shift, merge_value + + + find_pair_method_list = [ ProjectDistribution, + NormalizedTemplateDiff, ] find_pair_method_dict = {e.name: e for e in find_pair_method_list} From 7803413c2f7c64fac4619837fb1ab6cd5cf0d68e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 9 Oct 2023 16:28:49 +0200 Subject: [PATCH 22/35] Add SPIKEINTERFACE_DEV_PATH to aws gu tests --- .github/workflows/test_containers_singularity_gpu.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test_containers_singularity_gpu.yml b/.github/workflows/test_containers_singularity_gpu.yml index e74fbeb4a5..d075f5a6ef 100644 --- a/.github/workflows/test_containers_singularity_gpu.yml +++ b/.github/workflows/test_containers_singularity_gpu.yml @@ -46,5 +46,6 @@ jobs: - name: Run test singularity containers with GPU env: REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} + SPIKEINTERFACE_DEV_PATH: ${{ github.workspace }} run: | pytest -vv --capture=tee-sys -rA src/spikeinterface/sorters/external/tests/test_singularity_containers_gpu.py From 2753b49c4e4bfd76a5ac6971e52b3604e5ea4617 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 17:40:54 +0000 Subject: [PATCH 23/35] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/pre-commit-hooks: v4.4.0 → v4.5.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.4.0...v4.5.0) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07601cd208..7153a7dfc0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer From 64d507c7374a609955c69ef61df4e9cde5a7a04d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 9 Oct 2023 22:19:45 +0200 Subject: [PATCH 24/35] remove print --- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 11be2c3580..ddabd46657 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -258,7 +258,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): minimum_cluster_size = 25 labels_set, count = np.unique(post_merge_label, return_counts=True) to_remove = labels_set[count < minimum_cluster_size] - print(to_remove) + mask = np.isin(post_merge_label, to_remove) post_merge_label[mask] = -1 From 8e0575838b6177a134d7f89c95f499465975978d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 10 Oct 2023 08:11:15 +0200 Subject: [PATCH 25/35] fix plot_spike_on_trace --- src/spikeinterface/widgets/spikes_on_traces.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index c2bed8fe41..b68efc3f8a 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -162,10 +162,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): max_y = np.max(traces_widget.data_plot["channel_locations"][:, 1]) n = len(traces_widget.data_plot["channel_ids"]) - order = traces_widget.data_plot["order"] - - if order is None: - order = np.arange(n) if ax.get_legend() is not None: ax.get_legend().remove() @@ -221,7 +217,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # discontinuity times[:, -1] = np.nan times_r = times.reshape(times.shape[0] * times.shape[1]) - waveforms = traces[waveform_idxs] # [:, :, order] + waveforms = traces[waveform_idxs] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) for i, chan_id in enumerate(traces_widget.data_plot["channel_ids"]): From 0fd84922dd9d4ae54bcc0183a98d7a50a1e9f50c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Oct 2023 07:21:34 +0000 Subject: [PATCH 26/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/tridesclous2.py | 3 +- .../sortingcomponents/clustering/merge.py | 46 ++++++++++++------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index ddabd46657..e256915fa6 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -245,7 +245,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): threshold_diff=0.2, min_cluster_size=min_cluster_size + 1, num_shift=5, - ), + ), **job_kwargs, ) @@ -266,7 +266,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): labels_set = np.unique(post_merge_label) labels_set = labels_set[labels_set >= 0] - mask = post_merge_label >= 0 sorting_temp = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 24cbedfb8c..c46f214192 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -308,8 +308,16 @@ def find_merge_pairs( max_workers=n_jobs, initializer=find_pair_worker_init, mp_context=get_context(mp_context), - initargs=(recording, features_dict_or_folder, peak_labels, labels_set, templates, - method, method_kwargs, max_threads_per_process), + initargs=( + recording, + features_dict_or_folder, + peak_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, + ), ) as pool: jobs = [] for ind0, ind1 in zip(indices0, indices1): @@ -339,8 +347,14 @@ def find_merge_pairs( def find_pair_worker_init( - recording, features_dict_or_folder, original_labels, - labels_set, templates, method, method_kwargs, max_threads_per_process + recording, + features_dict_or_folder, + original_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, ): global _ctx _ctx = {} @@ -368,8 +382,14 @@ def find_pair_function_wrapper(label0, label1): global _ctx with threadpool_limits(limits=_ctx["max_threads_per_process"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( - label0, label1, _ctx["labels_set"], _ctx["templates"], - _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + label0, + label1, + _ctx["labels_set"], + _ctx["templates"], + _ctx["original_labels"], + _ctx["peaks"], + _ctx["features"], + **_ctx["method_kwargs"], ) return is_merge, label0, label1, shift, merge_value @@ -610,7 +630,6 @@ def merge( min_cluster_size=50, num_shift=5, ): - assert waveforms_sparse_mask is not None (inds0,) = np.nonzero(original_labels == label0) @@ -636,7 +655,6 @@ def merge( ind1 = list(labels_set).index(label1) template1 = templates[ind1, :, target_chans] - num_samples = template0.shape[0] # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) norm = np.mean(np.abs(template0) + np.abs(template1)) @@ -656,11 +674,10 @@ def merge( final_shift = 0 merge_value = np.nan - # DEBUG = False DEBUG = True if DEBUG and normed_diff < 0.2: - # if DEBUG: + # if DEBUG: import matplotlib.pyplot as plt @@ -672,18 +689,15 @@ def merge( ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") - ax.set_title(f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}") + ax.set_title( + f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}" + ) ax.legend() plt.show() - - - - return is_merge, label0, label1, final_shift, merge_value - find_pair_method_list = [ ProjectDistribution, NormalizedTemplateDiff, From a92b83732ac52d11cf5bc193da24e8b8e5be01a8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 11 Oct 2023 06:20:00 +0200 Subject: [PATCH 27/35] Forgot extra params --- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 4f9a0b2a14..e5f4ce8b31 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -286,7 +286,7 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): # the waveforms depend on the dataset key wf_folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - we = extract_waveforms(recording, gt_sorting, folder=wf_folder) + we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs) def get_waveform_extractor(self, key): # some recording are not dumpable to json and the waveforms extactor need it! From d2ba3fa5200dbb042ddc8471256a5979642eb7b3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Oct 2023 09:32:41 -0400 Subject: [PATCH 28/35] Attempt to fix failing CI --- .github/actions/build-test-environment/action.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 004fe31203..7241f60a8b 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -37,6 +37,11 @@ runs: - name: git-annex install run: | wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz + mkdir /home/runner/work/installation + mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/ + workdir=$(pwd) + cd /home/runner/work/installation tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH + cd $workdir shell: bash From 00b208c04d39980c7d67e45d0afb445456b2824b Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 13 Oct 2023 15:34:49 -0400 Subject: [PATCH 29/35] fix for waveform parameter change --- .../modules_gallery/core/plot_4_waveform_extractor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/modules_gallery/core/plot_4_waveform_extractor.py b/examples/modules_gallery/core/plot_4_waveform_extractor.py index 6c886c1eb0..bee8f4061b 100644 --- a/examples/modules_gallery/core/plot_4_waveform_extractor.py +++ b/examples/modules_gallery/core/plot_4_waveform_extractor.py @@ -49,7 +49,8 @@ ############################################################################### # A :py:class:`~spikeinterface.core.WaveformExtractor` object can be created with the -# :py:func:`~spikeinterface.core.extract_waveforms` function: +# :py:func:`~spikeinterface.core.extract_waveforms` function (this defaults to a sparse +# representation of the waveforms): folder = 'waveform_folder' we = extract_waveforms( @@ -87,6 +88,7 @@ recording, sorting, folder, + sparse=False, ms_before=3., ms_after=4., max_spikes_per_unit=500, @@ -149,7 +151,7 @@ # # Option 1) Save a dense waveform extractor to sparse: # -# In this case, from an existing waveform extractor, we can first estimate a +# In this case, from an existing (dense) waveform extractor, we can first estimate a # sparsity (which channels each unit is defined on) and then save to a new # folder in sparse mode: @@ -173,7 +175,7 @@ ############################################################################### -# Option 2) Directly extract sparse waveforms: +# Option 2) Directly extract sparse waveforms (current spikeinterface default): # # We can also directly extract sparse waveforms. To do so, dense waveforms are # extracted first using a small number of spikes (:code:`'num_spikes_for_sparsity'`) From 21e2e974a8c56c091300b77e76c3ac5f9d98b103 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 13 Oct 2023 15:41:18 -0400 Subject: [PATCH 30/35] fix sparsity here as well. --- .../modules_gallery/qualitymetrics/plot_3_quality_mertics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 209f357457..986680e798 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -30,7 +30,7 @@ # because it contains a reference to the "Recording" and the "Sorting" objects: folder = 'waveforms_mearec' -we = si.extract_waveforms(recording, sorting, folder, +we = si.extract_waveforms(recording, sorting, folder, sparsity=False, ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_durations='1s') print(we) From 9addc5769dc7beb019a98715a87fcce681f6e3cb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Oct 2023 16:04:00 -0400 Subject: [PATCH 31/35] Fix grouping of OpenEphys NPIX --- examples/modules_gallery/qualitymetrics/plot_4_curation.py | 2 +- src/spikeinterface/extractors/neoextractors/openephys.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index c66f55f221..8953a5a835 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -61,4 +61,4 @@ curated_sorting = sorting.select_units(keep_unit_ids) print(curated_sorting) -se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.pnz') +se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.npz') diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index cd2b6fb941..bb3ae3435a 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -183,7 +183,10 @@ def __init__( probe = None if probe is not None: - self = self.set_probe(probe, in_place=True) + if probe.shank_ids is not None: + self.set_probe(probe, in_place=True, group_mode="by_shank") + else: + self.set_probe(probe, in_place=True) probe_name = probe.annotations["probe_name"] # load num_channels_per_adc depending on probe type if "2.0" in probe_name: From bf8c5d1ccbe85b22e0397af562df1a21f256331f Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 13 Oct 2023 18:24:48 -0400 Subject: [PATCH 32/35] sparse=False for dense --- .../modules_gallery/qualitymetrics/plot_3_quality_mertics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py index 986680e798..7b6aae3e30 100644 --- a/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py +++ b/examples/modules_gallery/qualitymetrics/plot_3_quality_mertics.py @@ -30,7 +30,7 @@ # because it contains a reference to the "Recording" and the "Sorting" objects: folder = 'waveforms_mearec' -we = si.extract_waveforms(recording, sorting, folder, sparsity=False, +we = si.extract_waveforms(recording, sorting, folder, sparse=False, ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_durations='1s') print(we) From 0d29a422d3adc32d0cc113479ec530c05e6a20a1 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Sat, 14 Oct 2023 09:00:34 -0400 Subject: [PATCH 33/35] add kwargs and make keep_unit_ids list of strings --- .../qualitymetrics/plot_4_curation.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/modules_gallery/qualitymetrics/plot_4_curation.py b/examples/modules_gallery/qualitymetrics/plot_4_curation.py index c66f55f221..7f33e0bd8f 100644 --- a/examples/modules_gallery/qualitymetrics/plot_4_curation.py +++ b/examples/modules_gallery/qualitymetrics/plot_4_curation.py @@ -6,6 +6,8 @@ quality metrics. """ +############################################################################# +# Import the modules and/or functions necessary from spikeinterface import spikeinterface as si import spikeinterface.extractors as se @@ -15,22 +17,21 @@ ############################################################################## -# First, let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' +# Let's download a simulated dataset +# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' # # Let's imagine that the ground-truth sorting is in fact the output of a sorter. -# local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') -recording, sorting = se.read_mearec(local_path) +recording, sorting = se.read_mearec(file_path=local_path) print(recording) print(sorting) ############################################################################## -# First, we extract waveforms and compute their PC scores: +# First, we extract waveforms (to be saved in the folder 'wfs_mearec') and +# compute their PC scores: -folder = 'wfs_mearec' -we = si.extract_waveforms(recording, sorting, folder, +we = si.extract_waveforms(recording, sorting, folder='wfs_mearec', ms_before=1, ms_after=2., max_spikes_per_unit=500, n_jobs=1, chunk_size=30000) print(we) @@ -47,12 +48,15 @@ ############################################################################## # We can now threshold each quality metric and select units based on some rules. # -# The easiest and most intuitive way is to use boolean masking with dataframe: +# The easiest and most intuitive way is to use boolean masking with a dataframe. +# +# Then create a list of unit ids that we want to keep keep_mask = (metrics['snr'] > 7.5) & (metrics['isi_violations_ratio'] < 0.2) & (metrics['nn_hit_rate'] > 0.90) print(keep_mask) keep_unit_ids = keep_mask[keep_mask].index.values +keep_unit_ids = [unit_id for unit_id in keep_unit_ids] print(keep_unit_ids) ############################################################################## @@ -61,4 +65,4 @@ curated_sorting = sorting.select_units(keep_unit_ids) print(curated_sorting) -se.NpzSortingExtractor.write_sorting(curated_sorting, 'curated_sorting.pnz') +se.NpzSortingExtractor.write_sorting(sorting=curated_sorting, save_path='curated_sorting.npz') From 4180b22eedef178ca95057ee0140d279c292e9bb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 16 Oct 2023 13:06:10 +0200 Subject: [PATCH 34/35] Fix slicing in merge.py and so tridesclous2 and so test_launcher.py --- src/spikeinterface/sorters/tests/test_launcher.py | 6 +++--- src/spikeinterface/sortingcomponents/clustering/merge.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index a5e29c8fd9..fdadf533f5 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -233,15 +233,15 @@ def test_run_sorters_with_dict(): if __name__ == "__main__": - # setup_module() + setup_module() job_list = get_job_list() - # test_run_sorter_jobs_loop(job_list) + test_run_sorter_jobs_loop(job_list) # test_run_sorter_jobs_joblib(job_list) # test_run_sorter_jobs_processpoolexecutor(job_list) # test_run_sorter_jobs_multiprocessing(job_list) # test_run_sorter_jobs_dask(job_list) - test_run_sorter_jobs_slurm(job_list) + # test_run_sorter_jobs_slurm(job_list) # test_run_sorter_by_property() diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index c46f214192..a1da1ad6e9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -649,11 +649,13 @@ def merge( target_chans = np.intersect1d(target_chans0, target_chans1) union_chans = np.union1d(target_chans0, target_chans1) + + ind0 = list(labels_set).index(label0) - template0 = templates[ind0, :, target_chans] + template0 = templates[ind0][:, target_chans] ind1 = list(labels_set).index(label1) - template1 = templates[ind1, :, target_chans] + template1 = templates[ind1][:, target_chans] num_samples = template0.shape[0] # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) From a00ce05a124962d3fa410947c378082d6c1caa6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:08:30 +0000 Subject: [PATCH 35/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/clustering/merge.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index a1da1ad6e9..d35b562298 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -649,8 +649,6 @@ def merge( target_chans = np.intersect1d(target_chans0, target_chans1) union_chans = np.union1d(target_chans0, target_chans1) - - ind0 = list(labels_set).index(label0) template0 = templates[ind0][:, target_chans]