diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index bd6d1eff2a..a359e2a814 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -15,8 +15,7 @@ from ..core.waveform_extractor import BaseWaveformExtractorExtension -global DEBUG -DEBUG = False +# DEBUG = False def get_single_channel_template_metric_names(): @@ -223,7 +222,6 @@ def compute_template_metrics( sparsity: Optional[ChannelSparsity] = None, include_multi_channel_metrics: bool = False, metrics_kwargs: dict = None, - debug_plots: bool = False, ): """ Compute template metrics including: @@ -289,9 +287,6 @@ def compute_template_metrics( so that one metric value will be computed per unit. For multi-channel metrics, 3D channel locations are not supported. By default, the depth direction is "y". """ - if debug_plots: - global DEBUG - DEBUG = True if load_if_exists and waveform_extractor.is_extension(TemplateMetricsCalculator.extension_name): tmc = waveform_extractor.load_extension(TemplateMetricsCalculator.extension_name) else: @@ -697,27 +692,26 @@ def get_velocity_above(template, channel_locations, sampling_frequency, **kwargs distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) velocity_above, intercept, score = fit_velocity(peak_times_ms_above, distances_um_above) - global DEBUG - if DEBUG: - import matplotlib.pyplot as plt - - fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) - offset = 1.2 * np.max(np.ptp(template, axis=0)) - ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time - (channel_indices_above,) = np.nonzero(channels_above) - for i, single_template in enumerate(template.T): - color = "r" if i in channel_indices_above else "k" - axs[0].plot(ts, single_template + i * offset, color=color) - axs[0].axvline(0, color="g", ls="--") - axs[1].plot(peak_times_ms_above, distances_um_above, "o") - x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) - axs[1].plot(x, intercept + x * velocity_above) - axs[1].set_xlabel("Peak time (ms)") - axs[1].set_ylabel("Distance from max channel (um)") - fig.suptitle( - f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" - ) - plt.show() + # if DEBUG: + # import matplotlib.pyplot as plt + + # fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + # offset = 1.2 * np.max(np.ptp(template, axis=0)) + # ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + # (channel_indices_above,) = np.nonzero(channels_above) + # for i, single_template in enumerate(template.T): + # color = "r" if i in channel_indices_above else "k" + # axs[0].plot(ts, single_template + i * offset, color=color) + # axs[0].axvline(0, color="g", ls="--") + # axs[1].plot(peak_times_ms_above, distances_um_above, "o") + # x = np.linspace(peak_times_ms_above.min(), peak_times_ms_above.max(), 20) + # axs[1].plot(x, intercept + x * velocity_above) + # axs[1].set_xlabel("Peak time (ms)") + # axs[1].set_ylabel("Distance from max channel (um)") + # fig.suptitle( + # f"Velocity above: {velocity_above:.2f} um/ms - score {score:.2f} - channels: {np.sum(channels_above)}" + # ) + # plt.show() if np.sum(channels_above) < min_channels_for_velocity or score < min_r2_velocity: velocity_above = np.nan @@ -773,27 +767,26 @@ def get_velocity_below(template, channel_locations, sampling_frequency, **kwargs distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) velocity_below, intercept, score = fit_velocity(peak_times_ms_below, distances_um_below) - global DEBUG - if DEBUG: - import matplotlib.pyplot as plt - - fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) - offset = 1.2 * np.max(np.ptp(template, axis=0)) - ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time - (channel_indices_below,) = np.nonzero(channels_below) - for i, single_template in enumerate(template.T): - color = "r" if i in channel_indices_below else "k" - axs[0].plot(ts, single_template + i * offset, color=color) - axs[0].axvline(0, color="g", ls="--") - axs[1].plot(peak_times_ms_below, distances_um_below, "o") - x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) - axs[1].plot(x, intercept + x * velocity_below) - axs[1].set_xlabel("Peak time (ms)") - axs[1].set_ylabel("Distance from max channel (um)") - fig.suptitle( - f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" - ) - plt.show() + # if DEBUG: + # import matplotlib.pyplot as plt + + # fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + # offset = 1.2 * np.max(np.ptp(template, axis=0)) + # ts = np.arange(template.shape[0]) / sampling_frequency * 1000 - max_peak_time + # (channel_indices_below,) = np.nonzero(channels_below) + # for i, single_template in enumerate(template.T): + # color = "r" if i in channel_indices_below else "k" + # axs[0].plot(ts, single_template + i * offset, color=color) + # axs[0].axvline(0, color="g", ls="--") + # axs[1].plot(peak_times_ms_below, distances_um_below, "o") + # x = np.linspace(peak_times_ms_below.min(), peak_times_ms_below.max(), 20) + # axs[1].plot(x, intercept + x * velocity_below) + # axs[1].set_xlabel("Peak time (ms)") + # axs[1].set_ylabel("Distance from max channel (um)") + # fig.suptitle( + # f"Velocity below: {np.round(velocity_below, 3)} um/ms - score {score:.2f} - channels: {np.sum(channels_below)}" + # ) + # plt.show() if np.sum(channels_below) < min_channels_for_velocity or score < min_r2_velocity: velocity_below = np.nan @@ -856,22 +849,21 @@ def exp_decay(x, decay, amp0, offset): if r2 < min_r2_exp_decay: exp_decay_value = np.nan - global DEBUG - if DEBUG: - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") - x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) - ax.plot(x, exp_decay(x, *popt)) - ax.set_xlabel("Distance from max channel (um)") - ax.set_ylabel("Peak amplitude") - ax.set_title( - f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " - f"R2: {np.round(r2, 4)}" - ) - fig.suptitle("Exp decay") - plt.show() + # if DEBUG: + # import matplotlib.pyplot as plt + + # fig, ax = plt.subplots() + # ax.plot(channel_distances_sorted, peak_amplitudes_sorted, "o") + # x = np.arange(channel_distances_sorted.min(), channel_distances_sorted.max()) + # ax.plot(x, exp_decay(x, *popt)) + # ax.set_xlabel("Distance from max channel (um)") + # ax.set_ylabel("Peak amplitude") + # ax.set_title( + # f"Exp decay: {np.round(exp_decay_value, 3)} - Amp: {np.round(popt[1], 3)} - Offset: {np.round(popt[2], 3)} - " + # f"R2: {np.round(r2, 4)}" + # ) + # fig.suptitle("Exp decay") + # plt.show() except: exp_decay_value = np.nan @@ -922,24 +914,23 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs): channel_depth_above_theshold = channel_locations_above_theshold[:, depth_dim] spread = np.ptp(channel_depth_above_theshold) - global DEBUG - if DEBUG: - import matplotlib.pyplot as plt - - fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) - axs[0].imshow( - template.T, - aspect="auto", - origin="lower", - extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], - ) - axs[1].plot(channel_depths, MM, "o-") - axs[1].axhline(spread_threshold, ls="--", color="r") - axs[1].set_xlabel("Depth (um)") - axs[1].set_ylabel("Amplitude") - axs[1].set_title(f"Spread: {np.round(spread, 3)} um") - fig.suptitle("Spread") - plt.show() + # if DEBUG: + # import matplotlib.pyplot as plt + + # fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) + # axs[0].imshow( + # template.T, + # aspect="auto", + # origin="lower", + # extent=[0, template.shape[0] / sampling_frequency, channel_depths[0], channel_depths[-1]], + # ) + # axs[1].plot(channel_depths, MM, "o-") + # axs[1].axhline(spread_threshold, ls="--", color="r") + # axs[1].set_xlabel("Depth (um)") + # axs[1].set_ylabel("Amplitude") + # axs[1].set_title(f"Spread: {np.round(spread, 3)} um") + # fig.suptitle("Spread") + # plt.show() return spread