diff --git a/CHANGELOG.md b/CHANGELOG.md index 7488d5834..e40382bf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ * Use pytest format for dandi tests to avoid window permission error on teardown [PR #1151](https://github.com/catalystneuro/neuroconv/pull/1151) * Added many docstrings for public functions [PR #1063](https://github.com/catalystneuro/neuroconv/pull/1063) * Clean up with warnings and deprecations in the testing framework [PR #1158](https://github.com/catalystneuro/neuroconv/pull/1158) +* Detect mismatch errors between group and group names when writing ElectrodeGroups [PR #1165](https://github.com/catalystneuro/neuroconv/pull/1165) # v0.6.5 (November 1, 2024) diff --git a/src/neuroconv/tools/spikeinterface/spikeinterface.py b/src/neuroconv/tools/spikeinterface/spikeinterface.py index d31e6032c..7df25c703 100644 --- a/src/neuroconv/tools/spikeinterface/spikeinterface.py +++ b/src/neuroconv/tools/spikeinterface/spikeinterface.py @@ -245,11 +245,19 @@ def _get_group_name(recording: BaseRecording) -> np.ndarray: An array containing the group names. If the `group_name` property is not available, the channel groups will be returned. If the group names are empty, a default value 'ElectrodeGroup' will be used. + + Raises + ------ + ValueError + If the number of unique group names doesn't match the number of unique groups, + or if the mapping between group names and group numbers is inconsistent. """ default_value = "ElectrodeGroup" group_names = recording.get_property("group_name") + groups = recording.get_channel_groups() + if group_names is None: - group_names = recording.get_channel_groups() + group_names = groups if group_names is None: group_names = np.full(recording.get_num_channels(), fill_value=default_value) @@ -259,6 +267,23 @@ def _get_group_name(recording: BaseRecording) -> np.ndarray: # If for any reason the group names are empty, fill them with the default group_names[group_names == ""] = default_value + # Validate group names against groups + if groups is not None: + unique_groups = set(groups) + unique_names = set(group_names) + + if len(unique_names) != len(unique_groups): + raise ValueError("The number of group names must match the number of groups") + + # Check consistency of group name to group number mapping + group_to_name_map = {} + for group, name in zip(groups, group_names): + if group in group_to_name_map: + if group_to_name_map[group] != name: + raise ValueError("Inconsistent mapping between group numbers and group names") + else: + group_to_name_map[group] = name + return group_names diff --git a/src/neuroconv/tools/testing/mock_interfaces.py b/src/neuroconv/tools/testing/mock_interfaces.py index 38cc750ab..42bf699df 100644 --- a/src/neuroconv/tools/testing/mock_interfaces.py +++ b/src/neuroconv/tools/testing/mock_interfaces.py @@ -220,6 +220,12 @@ def __init__( es_key=es_key, ) + # Adding this as a safeguard before the spikeinterface changes are merged: + # https://github.com/SpikeInterface/spikeinterface/pull/3588 + channel_ids = self.recording_extractor.get_channel_ids() + channel_ids_as_strings = [str(id) for id in channel_ids] + self.recording_extractor = self.recording_extractor.rename_channels(new_channel_ids=channel_ids_as_strings) + def get_metadata(self) -> dict: """ Returns the metadata dictionary for the current object. @@ -272,6 +278,7 @@ def __init__( ) # Sorting extractor to have string unit ids until is changed in SpikeInterface + # https://github.com/SpikeInterface/spikeinterface/pull/3588 string_unit_ids = [str(id) for id in self.sorting_extractor.unit_ids] self.sorting_extractor = self.sorting_extractor.rename_units(new_unit_ids=string_unit_ids) diff --git a/tests/test_ecephys/test_tools_spikeinterface.py b/tests/test_ecephys/test_tools_spikeinterface.py index 3436a2e70..68c681f46 100644 --- a/tests/test_ecephys/test_tools_spikeinterface.py +++ b/tests/test_ecephys/test_tools_spikeinterface.py @@ -24,6 +24,7 @@ from neuroconv.tools.nwb_helpers import get_module from neuroconv.tools.spikeinterface import ( add_electrical_series_to_nwbfile, + add_electrode_groups_to_nwbfile, add_electrodes_to_nwbfile, add_recording_to_nwbfile, add_sorting_to_nwbfile, @@ -1071,6 +1072,29 @@ def test_missing_bool_values(self): assert np.array_equal(extracted_incomplete_property, expected_incomplete_property) +class TestAddElectrodeGroups: + def test_group_naming_not_matching_group_number(self): + recording = generate_recording(num_channels=4) + recording.set_channel_groups(groups=[0, 1, 2, 3]) + recording.set_property(key="group_name", values=["A", "A", "A", "A"]) + + nwbfile = mock_NWBFile() + with pytest.raises(ValueError, match="The number of group names must match the number of groups"): + add_electrode_groups_to_nwbfile(nwbfile=nwbfile, recording=recording) + + def test_inconsistent_group_name_mapping(self): + recording = generate_recording(num_channels=3) + # Set up groups where the same group name is used for different group numbers + recording.set_channel_groups(groups=[0, 1, 0]) + recording.set_property( + key="group_name", values=["A", "B", "B"] # Inconsistent: group 0 maps to names "A" and "B" + ) + + nwbfile = mock_NWBFile() + with pytest.raises(ValueError, match="Inconsistent mapping between group numbers and group names"): + add_electrode_groups_to_nwbfile(nwbfile=nwbfile, recording=recording) + + class TestAddUnitsTable(TestCase): @classmethod def setUpClass(cls):