Merge pull request #2753 from yger/curation_improvements
Enhancing curation: get_potential_auto_merge()
samuelgarcia authored Apr 30, 2024
2 parents 0498c51 + 138496b commit c0bfb49
Showing 1 changed file with 67 additions and 19 deletions.
86 changes: 67 additions & 19 deletions src/spikeinterface/curation/auto_merge.py
@@ -3,6 +3,7 @@
import numpy as np

from ..core import create_sorting_analyzer
+from ..core.template import Templates
from ..core.template_tools import get_template_extremum_channel
from ..postprocessing import compute_correlograms
from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates
@@ -30,6 +31,7 @@ def get_potential_auto_merge(
firing_contamination_balance=1.5,
extra_outputs=False,
steps=None,
+template_metric="l1",
):
"""
Algorithm to find and check potential merges between units.
@@ -63,7 +65,7 @@
Minimum number of spikes for each unit to consider a potential merge.
Enough spikes are needed to estimate the correlogram
maximum_distance_um: float, default: 150
-Minimum distance between units for considering a merge
+Maximum distance between units for considering a merge
peak_sign: "neg" | "pos" | "both", default: "neg"
Peak sign used to estimate the maximum channel of a template
bin_ms: float, default: 0.25
@@ -101,6 +103,8 @@
If None all steps are done.
Potential steps: "min_spikes", "remove_contaminated", "unit_positions", "correlogram", "template_similarity",
"check_increase_score". Please check steps explanations above!
+template_metric: 'l1', 'l2' or 'cosine'
+The metric to consider when measuring the distances between templates. Default is l1
Returns
-------
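As a quick illustration of the new keyword, here is a minimal usage sketch. It assumes a SortingAnalyzer with the required extensions (correlograms, templates) already computed; the argument values are illustrative, not recommendations:

from spikeinterface.curation import get_potential_auto_merge

# Compare templates with the L2 metric instead of the default L1;
# every other argument keeps its documented default.
potential_merges = get_potential_auto_merge(
    sorting_analyzer,
    maximum_distance_um=150,
    template_metric="l2",  # new in this PR: "l1", "l2" or "cosine"
)
print(potential_merges)  # list of unit-id pairs suggested for merging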
@@ -114,6 +118,7 @@
import scipy

sorting = sorting_analyzer.sorting
+recording = sorting_analyzer.recording
unit_ids = sorting.unit_ids

# to get fast computation we will not analyse pairs when:
@@ -154,12 +159,17 @@

# STEP 3 : unit positions are estimated roughly with channel
if "unit_positions" in steps:
-chan_loc = sorting_analyzer.get_channel_locations()
-unit_max_chan = get_template_extremum_channel(
-    sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index"
-)
-unit_max_chan = list(unit_max_chan.values())
-unit_locations = chan_loc[unit_max_chan, :]
+positions_ext = sorting_analyzer.get_extension("unit_locations")
+if positions_ext is not None:
+    unit_locations = positions_ext.get_data()[:, :2]
+else:
+    chan_loc = sorting_analyzer.get_channel_locations()
+    unit_max_chan = get_template_extremum_channel(
+        sorting_analyzer, peak_sign=peak_sign, mode="extremum", outputs="index"
+    )
+    unit_max_chan = list(unit_max_chan.values())
+    unit_locations = chan_loc[unit_max_chan, :]

unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")
pair_mask = pair_mask & (unit_distances <= maximum_distance_um)

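The reworked STEP 3 prefers the precomputed `unit_locations` extension and only falls back to the extremum-channel estimate when it is absent. A sketch of how a caller can opt into the better path (assuming a standard SortingAnalyzer; the method name is one of SpikeInterface's location estimators, shown only as an example):

# Computing the extension up front means the distance filter uses
# estimated unit positions rather than raw extremum-channel locations.
sorting_analyzer.compute("unit_locations", method="monopolar_triangulation")
merges = get_potential_auto_merge(sorting_analyzer, maximum_distance_um=150)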
@@ -194,10 +204,18 @@
templates_ext is not None
), "auto_merge with template_similarity requires a SortingAnalyzer with extension templates"

-templates = templates_ext.get_templates(operator="average")
+templates_array = templates_ext.get_data(outputs="numpy")

templates_diff = compute_templates_diff(
-    sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask
+    sorting,
+    templates_array,
+    num_channels=num_channels,
+    num_shift=num_shift,
+    pair_mask=pair_mask,
+    template_metric=template_metric,
+    sparsity=sorting_analyzer.sparsity,
)

pair_mask = pair_mask & (templates_diff < template_diff_thresh)

# STEP 6 : validate the potential merges with CC increase and the contamination quality metrics
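Since `compute_templates_diff` now takes the raw templates array plus an optional sparsity, it can also be exercised on its own. A sketch under the assumption that the analyzer's `templates` extension has been computed (variable names are illustrative):

from spikeinterface.curation.auto_merge import compute_templates_diff

templates_ext = sorting_analyzer.get_extension("templates")
templates_array = templates_ext.get_data(outputs="numpy")  # (num_units, num_samples, num_channels)

templates_diff = compute_templates_diff(
    sorting_analyzer.sorting,
    templates_array,
    num_channels=5,
    num_shift=5,
    template_metric="cosine",
    sparsity=sorting_analyzer.sparsity,
)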
@@ -378,23 +396,29 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float):
return win_size


-def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None):
+def compute_templates_diff(
+    sorting, templates_array, num_channels=5, num_shift=5, pair_mask=None, template_metric="l1", sparsity=None
+):
"""
-Computes normalilzed template differences.
+Computes normalized template differences.
Parameters
----------
sorting : BaseSorting
The sorting object
-templates : np.array
-The templates array (num_units, num_samples, num_channels)
+templates_array : np.array
+The templates array (num_units, num_samples, num_channels).
num_channels: int, default: 5
Number of channels to use for template similarity computation
num_shift: int, default: 5
Number of shifts in samples to be explored for template similarity computation
pair_mask: None or boolean array
A bool matrix of size (num_units, num_units) to select
which pair to compute.
+template_metric: 'l1', 'l2' or 'cosine'
+The metric to consider when measuring the distances between templates. Default is l1
+sparsity: None or ChannelSparsity
+Optionally a ChannelSparsity object.
Returns
-------
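For intuition about the three metrics documented above, here is a self-contained toy sketch (synthetic arrays, not SpikeInterface data) that mirrors the normalizations used in the loop below; near-identical templates give distances close to 0 under all three:

import numpy as np

rng = np.random.default_rng(0)
t1 = rng.normal(size=(90, 4))               # (num_samples, num_channels)
t2 = t1 + 0.05 * rng.normal(size=t1.shape)  # slightly perturbed copy

# Same normalizations as the l1 / l2 / cosine branches below.
d_l1 = np.sum(np.abs(t1 - t2)) / (np.sum(np.abs(t1)) + np.sum(np.abs(t2)))
d_l2 = np.linalg.norm(t1 - t2) / (np.sum(t1**2) + np.sum(t2**2))
d_cos = 1 - np.sum(t1 * t2) / (np.linalg.norm(t1) * np.linalg.norm(t2))
print(d_l1, d_l2, d_cos)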
@@ -403,30 +427,54 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair
"""
unit_ids = sorting.unit_ids
n = len(unit_ids)
+assert template_metric in ["l1", "l2", "cosine"], "Not a valid metric!"

if pair_mask is None:
pair_mask = np.ones((n, n), dtype="bool")

+if sparsity is None:
+    adaptative_masks = False
+    sparsity_mask = None
+else:
+    adaptative_masks = num_channels == None
+    sparsity_mask = sparsity.mask

templates_diff = np.full((n, n), np.nan, dtype="float64")
for unit_ind1 in range(n):
for unit_ind2 in range(unit_ind1 + 1, n):
if not pair_mask[unit_ind1, unit_ind2]:
continue

-template1 = templates[unit_ind1]
-template2 = templates[unit_ind2]
+template1 = templates_array[unit_ind1]
+template2 = templates_array[unit_ind2]
# take best channels
-chan_inds = np.argsort(np.max(np.abs(template1 + template2), axis=0))[::-1][:num_channels]
+if not adaptative_masks:
+    chan_inds = np.argsort(np.max(np.abs(template1 + template2), axis=0))[::-1][:num_channels]
+else:
+    chan_inds = np.intersect1d(
+        np.flatnonzero(sparsity_mask[unit_ind1]), np.flatnonzero(sparsity_mask[unit_ind2])
+    )

template1 = template1[:, chan_inds]
template2 = template2[:, chan_inds]

num_samples = template1.shape[0]
-norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
+if template_metric == "l1":
+    norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
+elif template_metric == "l2":
+    norm = np.sum(template1**2) + np.sum(template2**2)
+elif template_metric == "cosine":
+    norm = np.linalg.norm(template1) * np.linalg.norm(template2)
all_shift_diff = []
for shift in range(-num_shift, num_shift + 1):
temp1 = template1[num_shift : num_samples - num_shift, :]
temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :]
-d = np.sum(np.abs(temp1 - temp2)) / (norm)
+if template_metric == "l1":
+    d = np.sum(np.abs(temp1 - temp2)) / norm
+elif template_metric == "l2":
+    d = np.linalg.norm(temp1 - temp2) / norm
+elif template_metric == "cosine":
+    d = 1 - np.sum(temp1 * temp2) / norm
all_shift_diff.append(d)
templates_diff[unit_ind1, unit_ind2] = np.min(all_shift_diff)

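A note on the adaptive branch above: when `sparsity` is passed and `num_channels` is None, the channels compared are the intersection of the two units' sparsity masks rather than a fixed top-k set. A toy sketch of that selection (synthetic mask, purely illustrative):

import numpy as np

# Hypothetical (num_units, num_channels) boolean sparsity mask.
sparsity_mask = np.array([[True, True, True, False],
                          [False, True, True, True]])

# Channels shared by unit 0 and unit 1, as in the adaptive branch.
chan_inds = np.intersect1d(
    np.flatnonzero(sparsity_mask[0]), np.flatnonzero(sparsity_mask[1])
)
print(chan_inds)  # -> [1 2]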
Expand All @@ -437,7 +485,7 @@ def check_improve_contaminations_score(
sorting_analyzer, pair_mask, contaminations, firing_contamination_balance, refractory_period_ms, censored_period_ms
):
"""
-Check that the score is improve afeter a potential merge
+Check that the score is improved after a potential merge
The score is a balance between:
* contamination decrease
