From 4ea5d9fa658c1cf52a482460681cd12275da3a14 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 4 Dec 2024 10:34:03 -0500 Subject: [PATCH] Adding few changes for metadata --- pynapple/core/time_series.py | 16 +++++++++------- pynapple/core/ts_group.py | 16 +++++++++++----- pynapple/io/interface_nwb.py | 2 +- pynapple/process/decoding.py | 1 + 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 8f3faecc..3b9dbde0 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1218,11 +1218,15 @@ def __getitem__(self, key, *args, **kwargs): "When indexing with a Tsd, it must contain boolean values" ) key = key.d - elif ( - isinstance(key, str) - or hasattr(key, "__iter__") - and all([isinstance(k, str) for k in key]) - ): + elif isinstance(key, str): + if key in self.columns: + with warnings.catch_warnings(): + # ignore deprecated warning for loc + warnings.simplefilter("ignore") + return self.loc[key] + else: + return _MetadataMixin.__getitem__(self, key) + elif hasattr(key, "__iter__") and all([isinstance(k, str) for k in key]): if all(k in self.columns for k in key): with warnings.catch_warnings(): # ignore deprecated warning for loc @@ -1272,8 +1276,6 @@ def __getitem__(self, key, *args, **kwargs): return _get_class(output)( t=index, d=output, time_support=self.time_support, **kwargs ) - # else: - # return output else: return output diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 47e4307b..608b9a97 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -345,7 +345,7 @@ def __getitem__(self, key): return _MetadataMixin.__getitem__(self, key) else: raise KeyError(r"Key {} not in group index.".format(key)) - elif isinstance(key, list) and all(isinstance(k, str) for k in key): + elif isinstance(key, list) and len(key) and all(isinstance(k, str) for k in key): # index multiple metadata columns return _MetadataMixin.__getitem__(self, key) @@ -389,11 +389,17 @@ def __repr__(self): max_rows = np.maximum(rows - 10, 2) # By default, the first three columns should always show. - col_names = self._metadata.columns.drop("rate") + col_names = self._metadata.columns + if "rate" in col_names: + col_names = col_names.drop("rate") + headers = ["Index", "rate"] + [c for c in col_names][0:max_cols] end = ["..."] if len(headers) > max_cols else [] headers += end + if len(self) == 0: + return tabulate(tabular_data=[], headers=headers) + if len(self) > max_rows: n_rows = max_rows // 2 ends = np.array([end] * n_rows) @@ -530,7 +536,7 @@ def restrict(self, ep): cols = self._metadata.columns.drop("rate") return TsGroup( - newgr, time_support=ep, bypass_check=True, **self._metadata[cols] + newgr, time_support=ep, bypass_check=True, metadata=self._metadata[cols] ) def value_from(self, tsd, ep=None): @@ -576,7 +582,7 @@ def value_from(self, tsd, ep=None): newgr[k] = self.data[k].value_from(tsd, ep) cols = self._metadata.columns.drop("rate") - return TsGroup(newgr, time_support=ep, **self._metadata[cols]) + return TsGroup(newgr, time_support=ep, metadata=self._metadata[cols]) def count(self, *args, dtype=None, **kwargs): """ @@ -867,7 +873,7 @@ def get(self, start, end=None, time_units="s"): newgr, time_support=self.time_support, bypass_check=True, - **self._metadata[cols], + metadata=self._metadata[cols], ) ################################# diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index 7fc30f33..e2388920 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -279,7 +279,7 @@ def _make_tsgroup(obj, **kwargs): else: pass - tsgroup = nap.TsGroup(tsgroup, **metainfo) + tsgroup = nap.TsGroup(tsgroup, metadata=metainfo) return tsgroup diff --git a/pynapple/process/decoding.py b/pynapple/process/decoding.py index 13900d3e..2dbdb1d6 100644 --- a/pynapple/process/decoding.py +++ b/pynapple/process/decoding.py @@ -63,6 +63,7 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None): # Bin spikes count = newgroup.count(bin_size, ep, time_units) + count = count.smooth(bin_size*3) # Occupancy if feature is None: