Skip to content

Commit

Permalink
Merge pull request #3180 from h-mayorquin/sorting_aggregation_should_…
Browse files Browse the repository at this point in the history
…preserve_ids

Units aggregation preserve unit ids of aggregated sorters
  • Loading branch information
alejoe91 authored Jul 12, 2024
2 parents e2b0a34 + f7018c6 commit 53933e6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
38 changes: 38 additions & 0 deletions src/spikeinterface/core/tests/test_unitsaggregationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from spikeinterface.core import NpzSortingExtractor
from spikeinterface.core import create_sorting_npz
from spikeinterface.core import generate_sorting


def test_unitsaggregationsorting(create_cache_folder):
Expand Down Expand Up @@ -92,5 +93,42 @@ def test_unitsaggregationsorting(create_cache_folder):
print(sorting_agg_prop.get_property("brain_area"))


def test_unit_aggregation_preserve_ids():

sorting1 = generate_sorting(num_units=3)
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

sorting2 = generate_sorting(num_units=3)
sorting2 = sorting2.rename_units(new_unit_ids=["unit4", "unit5", "unit6"])

aggregated_sorting = aggregate_units([sorting1, sorting2])
assert aggregated_sorting.get_num_units() == 6
assert list(aggregated_sorting.get_unit_ids()) == ["unit1", "unit2", "unit3", "unit4", "unit5", "unit6"]


def test_unit_aggregation_does_not_preserve_ids_if_not_unique():
sorting1 = generate_sorting(num_units=3)
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

sorting2 = generate_sorting(num_units=3)
sorting2 = sorting2.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

aggregated_sorting = aggregate_units([sorting1, sorting2])
assert aggregated_sorting.get_num_units() == 6
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4", "5"]


def test_unit_aggregation_does_not_preserve_ids_not_the_same_type():
sorting1 = generate_sorting(num_units=3)
sorting1 = sorting1.rename_units(new_unit_ids=["unit1", "unit2", "unit3"])

sorting2 = generate_sorting(num_units=2)
sorting2 = sorting2.rename_units(new_unit_ids=[1, 2])

aggregated_sorting = aggregate_units([sorting1, sorting2])
assert aggregated_sorting.get_num_units() == 5
assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"]


if __name__ == "__main__":
test_unitsaggregationsorting()
16 changes: 15 additions & 1 deletion src/spikeinterface/core/unitsaggregationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,21 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
)
unit_ids = list(renamed_unit_ids)
else:
unit_ids = list(np.arange(num_all_units))
unit_ids_dtypes = [sort.get_unit_ids().dtype for sort in sorting_list]
all_ids_are_same_type = np.unique(unit_ids_dtypes).size == 1
all_units_ids_are_unique = False
if all_ids_are_same_type:
combined_ids = np.concatenate([sort.get_unit_ids() for sort in sorting_list])
all_units_ids_are_unique = np.unique(combined_ids).size == num_all_units

if all_ids_are_same_type and all_units_ids_are_unique:
unit_ids = combined_ids
else:
default_unit_ids = [str(i) for i in range(num_all_units)]
if all_ids_are_same_type and np.issubdtype(unit_ids_dtypes[0], np.integer):
unit_ids = np.arange(num_all_units, dtype=np.uint64)
else:
unit_ids = default_unit_ids

# unit map maps unit ids that are used to get spike trains
u_id = 0
Expand Down

0 comments on commit 53933e6

Please sign in to comment.