Skip to content

Commit

Permalink
Remove debug plots from template metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Oct 19, 2023
1 parent a733fd0 commit 65b2496
Showing 1 changed file with 73 additions and 82 deletions.
155 changes: 73 additions & 82 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from ..core.waveform_extractor import BaseWaveformExtractorExtension


global DEBUG
DEBUG = False
# DEBUG = False


def get_single_channel_template_metric_names():
Expand Down Expand Up @@ -223,7 +222,6 @@ def compute_template_metrics(
sparsity: Optional[ChannelSparsity] = None,
include_multi_channel_metrics: bool = False,
metrics_kwargs: dict = None,
debug_plots: bool = False,
):
"""
Compute template metrics including:
Expand Down Expand Up @@ -289,9 +287,6 @@ def compute_template_metrics(
so that one metric value will be computed per unit.
For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y".
"""
if debug_plots:
global DEBUG
DEBUG = True
if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name):
tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name)
else:
Expand Down Expand Up @@ -697,27 +692,26 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs
distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above])
velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above)

global DEBUG
if DEBUG:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
offset = 1.2 * np.max(np.ptp(template, axis=0))
ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time
(channel_indices_above,) = np.nonzero(channels_above)
for i, single_template in enumerate(template.T):
color = "r" if i in channel_indices_above else "k"
axs[0].plot(ts, single_template + i * offset, color=color)
axs[0].axvline(0, color="g", ls="--")
axs[1].plot(peak_times_ms_above, distances_um_above, "o")
x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20)
axs[1].plot(x, intercept + x * velocity_above)
axs[1].set_xlabel("Peak time (ms)")
axs[1].set_ylabel("Distance from max channel (um)")
fig.suptitle(
f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}"
)
plt.show()
# if DEBUG:
# import matplotlib.pyplot as plt

# fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
# offset = 1.2 * np.max(np.ptp(template, axis=0))
# ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time
# (channel_indices_above,) = np.nonzero(channels_above)
# for i, single_template in enumerate(template.T):
# color = "r" if i in channel_indices_above else "k"
# axs[0].plot(ts, single_template + i * offset, color=color)
# axs[0].axvline(0, color="g", ls="--")
# axs[1].plot(peak_times_ms_above, distances_um_above, "o")
# x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20)
# axs[1].plot(x, intercept + x * velocity_above)
# axs[1].set_xlabel("Peak time (ms)")
# axs[1].set_ylabel("Distance from max channel (um)")
# fig.suptitle(
# f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}"
# )
# plt.show()

if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity:
velocity_above = np.nan
Expand Down Expand Up @@ -773,27 +767,26 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs
distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below])
velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below)

global DEBUG
if DEBUG:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
offset = 1.2 * np.max(np.ptp(template, axis=0))
ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time
(channel_indices_below,) = np.nonzero(channels_below)
for i, single_template in enumerate(template.T):
color = "r" if i in channel_indices_below else "k"
axs[0].plot(ts, single_template + i * offset, color=color)
axs[0].axvline(0, color="g", ls="--")
axs[1].plot(peak_times_ms_below, distances_um_below, "o")
x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20)
axs[1].plot(x, intercept + x * velocity_below)
axs[1].set_xlabel("Peak time (ms)")
axs[1].set_ylabel("Distance from max channel (um)")
fig.suptitle(
f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}"
)
plt.show()
# if DEBUG:
# import matplotlib.pyplot as plt

# fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
# offset = 1.2 * np.max(np.ptp(template, axis=0))
# ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time
# (channel_indices_below,) = np.nonzero(channels_below)
# for i, single_template in enumerate(template.T):
# color = "r" if i in channel_indices_below else "k"
# axs[0].plot(ts, single_template + i * offset, color=color)
# axs[0].axvline(0, color="g", ls="--")
# axs[1].plot(peak_times_ms_below, distances_um_below, "o")
# x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20)
# axs[1].plot(x, intercept + x * velocity_below)
# axs[1].set_xlabel("Peak time (ms)")
# axs[1].set_ylabel("Distance from max channel (um)")
# fig.suptitle(
# f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}"
# )
# plt.show()

if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity:
velocity_below = np.nan
Expand Down Expand Up @@ -856,22 +849,21 @@ def exp_decay(x, decay, amp0, offset):
if r2 < min_r2_exp_decay:
exp_decay_value = np.nan

global DEBUG
if DEBUG:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o")
x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max())
ax.plot(x, exp_decay(x, *popt))
ax.set_xlabel("Distance from max channel (um)")
ax.set_ylabel("Peak amplitude")
ax.set_title(
f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - "
f"R2: {np.round(r2, 4)}"
)
fig.suptitle("Exp decay")
plt.show()
# if DEBUG:
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots()
# ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o")
# x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max())
# ax.plot(x, exp_decay(x, *popt))
# ax.set_xlabel("Distance from max channel (um)")
# ax.set_ylabel("Peak amplitude")
# ax.set_title(
# f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - "
# f"R2: {np.round(r2, 4)}"
# )
# fig.suptitle("Exp decay")
# plt.show()
except:
exp_decay_value = np.nan

Expand Down Expand Up @@ -922,24 +914,23 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs):
channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim]
spread = np.ptp(channel_depth_above_theshold)

global DEBUG
if DEBUG:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
axs[0].imshow(
template.T,
aspect="auto",
origin="lower",
extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]],
)
axs[1].plot(channel_depths, MM, "o-")
axs[1].axhline(spread_threshold, ls="--", color="r")
axs[1].set_xlabel("Depth (um)")
axs[1].set_ylabel("Amplitude")
axs[1].set_title(f"Spread: {np.round(spread, 3)} um")
fig.suptitle("Spread")
plt.show()
# if DEBUG:
# import matplotlib.pyplot as plt

# fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
# axs[0].imshow(
# template.T,
# aspect="auto",
# origin="lower",
# extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]],
# )
# axs[1].plot(channel_depths, MM, "o-")
# axs[1].axhline(spread_threshold, ls="--", color="r")
# axs[1].set_xlabel("Depth (um)")
# axs[1].set_ylabel("Amplitude")
# axs[1].set_title(f"Spread: {np.round(spread, 3)} um")
# fig.suptitle("Spread")
# plt.show()

return spread

Expand Down

0 comments on commit 65b2496

Please sign in to comment.