From 6022722c6fb9c0c8a35e9c6f5b1cbac8bd106e24 Mon Sep 17 00:00:00 2001 From: BptGrm Date: Thu, 7 Nov 2024 14:04:18 +0100 Subject: [PATCH 1/3] Implement get_unit_location --- .../extractors/herdingspikesextractors.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index de4929218b..1c69827fa8 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -43,28 +43,39 @@ def __init__(self, file_path, load_unit_info=True): spike_ids = self._rf["cluster_id"][()] unit_ids = np.unique(spike_ids) spike_times = self._rf["times"][()] + unit_locs = self._rf["centres"][()] - if load_unit_info: - self.load_unit_info() + # if load_unit_info: + # self.load_unit_info() BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids)) + self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids, unit_locs)) self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info} self.extra_requirements.append("h5py") - def load_unit_info(self): - # TODO - """ + def get_unit_location( + self, + unit_id, + segment_index=None, + ): + + segment_index = self._check_segment_index(segment_index) + segment = self._sorting_segments[segment_index] + loc = segment.get_unit_location(unit_id=unit_id) + return loc + + """ + def load_unit_info(self): + if 'centres' in self._rf.keys() and len(self._spike_times) > 0: self._unit_locs = self._rf['centres'][()] # cache for faster access - for u_i, unit_id in enumerate(self._unit_ids): - self.set_unit_property(unit_id, property_name='unit_location', value=self._unit_locs[u_i]) inds = [] # get these only once for unit_id in self._unit_ids: inds.append(np.where(self._cluster_id == unit_id)[0]) - if 'data' in self._rf.keys() and len(self._spike_times) > 0: - d = self._rf['data'][()] + if 'x' in self._rf.keys() and 'y' in self._rf.keys() and len(self._spike_times) > 0: + x = self._rf['x'][()] + y = self._rf['y'][()] for i, unit_id in enumerate(self._unit_ids): self.set_unit_spike_features(unit_id, 'spike_location', d[:, inds[i]].T) if 'ch' in self._rf.keys() and len(self._spike_times) > 0: @@ -79,12 +90,13 @@ def load_unit_info(self): class HerdingspikesSortingSegment(BaseSortingSegment): - def __init__(self, unit_ids, spike_times, spike_ids): + def __init__(self, unit_ids, spike_times, spike_ids, unit_locs): BaseSortingSegment.__init__(self) # spike_times is a dict self._unit_ids = list(unit_ids) self._spike_times = spike_times self._spike_ids = spike_ids + self._unit_locs = unit_locs def get_unit_spike_train(self, unit_id, start_frame, end_frame): mask = self._spike_ids == unit_id @@ -95,6 +107,9 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): times = times[times < end_frame] return times + def get_unit_location(self, unit_id): + return self._unit_locs[unit_id] + """ @staticmethod def write_sorting(sorting, save_path): From eaf7c5b6c69ce732e49522eff3debb4ac6e7bb6b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 13:30:12 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../extractors/herdingspikesextractors.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 1c69827fa8..2deb9fc2a8 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -55,19 +55,19 @@ def __init__(self, file_path, load_unit_info=True): self.extra_requirements.append("h5py") def get_unit_location( - self, - unit_id, - segment_index=None, - ): - + self, + unit_id, + segment_index=None, + ): + segment_index = self._check_segment_index(segment_index) segment = self._sorting_segments[segment_index] loc = segment.get_unit_location(unit_id=unit_id) return loc - """ + """ def load_unit_info(self): - + if 'centres' in self._rf.keys() and len(self._spike_times) > 0: self._unit_locs = self._rf['centres'][()] # cache for faster access inds = [] # get these only once @@ -109,7 +109,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): def get_unit_location(self, unit_id): return self._unit_locs[unit_id] - + """ @staticmethod def write_sorting(sorting, save_path): From ffcdf5a8f1eadc7cd7d8018b553163f375de7632 Mon Sep 17 00:00:00 2001 From: Baptiste Grimaud <83828302+b-grimaud@users.noreply.github.com> Date: Mon, 18 Nov 2024 18:08:23 +0100 Subject: [PATCH 3/3] Expose unit_locations as property --- src/spikeinterface/extractors/herdingspikesextractors.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 2deb9fc2a8..a83fbbb838 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -45,11 +45,13 @@ def __init__(self, file_path, load_unit_info=True): spike_times = self._rf["times"][()] unit_locs = self._rf["centres"][()] + self.unit_locations = unit_locs + # if load_unit_info: # self.load_unit_info() BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids, unit_locs)) + self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids, self.unit_locations)) self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info} self.extra_requirements.append("h5py")