Skip to content

Commit

Permalink
Merge pull request #3319 from rkim48/fix-template-similarity-deprecat…
Browse files Browse the repository at this point in the history
…ion-warning

Fix DeprecationWarnings from distance calculations
  • Loading branch information
alejoe91 authored Sep 5, 2024
2 parents 0e36a3a + 2f87919 commit b3f9793
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def _get_data(self):
def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):

import sklearn.metrics.pairwise

if method == "cosine_similarity":
Expand Down Expand Up @@ -223,15 +222,17 @@ def compute_similarity_with_templates_array(
if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1")
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item()
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2")
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item()
distances[count, i, j] /= norm_i + norm_j
else:
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine")
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(
src, tgt, metric="cosine"
).item()

if same_array:
distances[count, j, i] = distances[count, i, j]
Expand Down

0 comments on commit b3f9793

Please sign in to comment.