Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matched filtering with both peak signs simultaneously #2914

Merged
merged 27 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
67c3e00
Handling of both peak types
yger May 28, 2024
fadd015
WIP
yger May 28, 2024
6bee02b
WIP
yger May 28, 2024
e1d87d0
WIP
yger May 29, 2024
50a562e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
49ef51e
adding tests
yger May 29, 2024
e9e2419
Merge branch 'mf_with_both_peak_signs' of github.com:yger/spikeinterf…
yger May 29, 2024
32dccf2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
563dcb8
Merge branch 'SpikeInterface:main' into mf_with_both_peak_signs
yger May 31, 2024
6d20129
Merge branch 'SpikeInterface:main' into mf_with_both_peak_signs
yger Jun 2, 2024
c6c5d3e
Debug
yger Jul 10, 2024
f702cf5
Merge branch 'SpikeInterface:main' into mf_with_both_peak_signs
yger Jul 10, 2024
43e8828
Trying to find the problem
yger Jul 11, 2024
eaee6b4
Fixing matchef filtering both peaks
yger Jul 19, 2024
7d4c675
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2024
a2d4573
Merge branch 'main' into mf_with_both_peak_signs
yger Jul 19, 2024
462d485
Merge branch 'main' into mf_with_both_peak_signs
yger Jul 30, 2024
ef21b3c
Merge branch 'SpikeInterface:main' into mf_with_both_peak_signs
yger Aug 26, 2024
085763d
Merge branch 'SpikeInterface:main' into mf_with_both_peak_signs
yger Aug 27, 2024
2ded61a
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Sep 10, 2024
5db5053
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Sep 11, 2024
62eb941
Merge branch 'SpikeInterface:main' into mf_with_both_peak_signs
yger Sep 24, 2024
2daf3df
Speeding up mf
yger Sep 27, 2024
d8e5aa7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
9eb22a2
Merge branch 'main' into mf_with_both_peak_signs
yger Sep 27, 2024
cb77c0c
Checks for prototypes
yger Oct 2, 2024
0595412
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 36 additions & 67 deletions src/spikeinterface/sortingcomponents/peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,47 +631,31 @@ def __init__(
self.conv_margin = prototype.shape[0]

assert peak_sign in ("both", "neg", "pos")
idx = np.argmax(np.abs(prototype))
self.nbefore = int(ms_before * recording.sampling_frequency / 1000)
if peak_sign == "neg":
assert prototype[idx] < 0, "Prototype should have a negative peak"
assert prototype[self.nbefore] < 0, "Prototype should have a negative peak"
peak_sign = "pos"
elif peak_sign == "pos":
assert prototype[idx] > 0, "Prototype should have a positive peak"
elif peak_sign == "both":
raise NotImplementedError("Matched filtering not working with peak_sign=both yet!")
assert prototype[self.nbefore] > 0, "Prototype should have a positive peak"

self.peak_sign = peak_sign
self.nbefore = int(ms_before * recording.sampling_frequency / 1000)
self.prototype = np.flip(prototype) / np.linalg.norm(prototype)

contact_locations = recording.get_channel_locations()
dist = np.linalg.norm(contact_locations[:, np.newaxis] - contact_locations[np.newaxis, :], axis=2)
weights, self.z_factors = get_convolution_weights(dist, **weight_method)
self.weights, self.z_factors = get_convolution_weights(dist, **weight_method)
self.num_z_factors = len(self.z_factors)
self.num_channels = recording.get_num_channels()
self.num_templates = self.num_channels
if peak_sign == "both":
self.weights = np.hstack((self.weights, self.weights))
self.weights[:, self.num_templates :, :] *= -1
self.num_templates *= 2

num_channels = recording.get_num_channels()
num_templates = num_channels * len(self.z_factors)
weights = weights.reshape(num_templates, -1)

templates = weights[:, None, :] * prototype[None, :, None]
templates -= templates.mean(axis=(1, 2))[:, None, None]
temporal, singular, spatial = np.linalg.svd(templates, full_matrices=False)
temporal = temporal[:, :, :rank]
singular = singular[:, :rank]
spatial = spatial[:, :rank, :]
templates = np.matmul(temporal * singular[:, np.newaxis, :], spatial)
norms = np.linalg.norm(templates, axis=(1, 2))
del templates

temporal /= norms[:, np.newaxis, np.newaxis]
temporal = np.flip(temporal, axis=1)
spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2])
temporal = np.moveaxis(temporal, [0, 1, 2], [1, 2, 0])
singular = singular.T[:, :, np.newaxis]

self.temporal = temporal
self.spatial = spatial
self.singular = singular
self.weights = self.weights.reshape(self.num_templates * self.num_z_factors, -1)

random_data = get_random_data_chunks(recording, return_scaled=False, **random_chunk_kwargs)
conv_random_data = self.get_convolved_traces(random_data, temporal, spatial, singular)
conv_random_data = self.get_convolved_traces(random_data)
medians = np.median(conv_random_data, axis=1)
medians = medians[:, None]
noise_levels = np.median(np.abs(conv_random_data - medians), axis=1) / 0.6744897501960817
Expand All @@ -688,16 +672,13 @@ def get_trace_margin(self):
def compute(self, traces, start_frame, end_frame, segment_index, max_margin):

assert HAVE_NUMBA, "You need to install numba"
conv_traces = self.get_convolved_traces(traces, self.temporal, self.spatial, self.singular)
conv_traces = self.get_convolved_traces(traces)
conv_traces /= self.abs_thresholds[:, None]
conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin]
traces_center = conv_traces[:, self.exclude_sweep_size : -self.exclude_sweep_size]

num_z_factors = len(self.z_factors)
num_templates = traces.shape[1]

traces_center = traces_center.reshape(num_z_factors, num_templates, traces_center.shape[1])
conv_traces = conv_traces.reshape(num_z_factors, num_templates, conv_traces.shape[1])
traces_center = traces_center.reshape(self.num_z_factors, self.num_templates, traces_center.shape[1])
conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1])
peak_mask = traces_center > 1

peak_mask = _numba_detect_peak_matched_filtering(
Expand All @@ -708,11 +689,13 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
self.abs_thresholds,
self.peak_sign,
self.neighbours_mask,
num_templates,
self.num_channels,
)

# Find peaks and correct for time shift
z_ind, peak_chan_ind, peak_sample_ind = np.nonzero(peak_mask)
if self.peak_sign == "both":
peak_chan_ind = peak_chan_ind % self.num_channels

# If we want to estimate z
# peak_chan_ind = peak_chan_ind % num_channels
Expand All @@ -739,16 +722,11 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
# return is always a tuple
return (local_peaks,)

def get_convolved_traces(self, traces, temporal, spatial, singular):
def get_convolved_traces(self, traces):
import scipy.signal

num_timesteps, num_templates = len(traces), temporal.shape[1]
num_peaks = num_timesteps - self.conv_margin + 1
scalar_products = np.zeros((num_templates, num_peaks), dtype=np.float32)
spatially_filtered_data = np.matmul(spatial, traces.T[np.newaxis, :, :])
scaled_filtered_data = spatially_filtered_data * singular
objective_by_rank = scipy.signal.oaconvolve(scaled_filtered_data, temporal, axes=2, mode="valid")
scalar_products += np.sum(objective_by_rank, axis=0)
tmp = scipy.signal.oaconvolve(self.prototype[None, :], traces.T, axes=1, mode="valid")
scalar_products = np.dot(self.weights, tmp)
return scalar_products


Expand Down Expand Up @@ -873,37 +851,28 @@ def _numba_detect_peak_neg(

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_matched_filtering(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_templates
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask, num_channels
):
num_z = traces_center.shape[0]
num_templates = traces_center.shape[1]
for template_ind in range(num_templates):
for z in range(num_z):
for s in range(peak_mask.shape[2]):
if not peak_mask[z, template_ind, s]:
continue
for neighbour in range(num_templates):
if not neighbours_mask[template_ind, neighbour]:
continue
for j in range(num_z):
if not neighbours_mask[template_ind % num_channels, neighbour % num_channels]:
continue
for i in range(exclude_sweep_size):
if template_ind >= neighbour:
if z >= j:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
)
else:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
elif template_ind < neighbour:
if z > j:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
else:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
if template_ind >= neighbour and z >= j:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] >= traces_center[j, neighbour, s]
)
else:
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces_center[j, neighbour, s]
)
peak_mask[z, template_ind, s] &= (
traces_center[z, template_ind, s] > traces[j, neighbour, s + i]
)
Expand Down
27 changes: 23 additions & 4 deletions src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,19 +328,38 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs)
)
assert len(peaks_local_mf_filtering) > len(peaks_by_channel_np)

peaks_local_mf_filtering_both = detect_peaks(
recording,
method="matched_filtering",
peak_sign="both",
detect_threshold=5,
exclude_sweep_ms=0.1,
prototype=prototype,
ms_before=1.0,
**job_kwargs,
)
assert len(peaks_local_mf_filtering_both) > len(peaks_local_mf_filtering)

DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

peaks = peaks_local_mf_filtering
peaks_local = peaks_by_channel_np
peaks_mf_neg = peaks_local_mf_filtering
peaks_mf_both = peaks_local_mf_filtering_both
labels = ["locally_exclusive", "mf_neg", "mf_both"]

sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
fig, ax = plt.subplots()
chan_offset = 500
traces = recording.get_traces().copy()
traces += np.arange(traces.shape[1])[None, :] * chan_offset
fig, ax = plt.subplots()
ax.plot(traces, color="k")
ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r")

for count, peaks in enumerate([peaks_local, peaks_mf_neg, peaks_mf_both]):
sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"]
ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, label=labels[count])

ax.legend()
plt.show()


Expand Down