Skip to content

Commit

Permalink
Adding few changes for metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed Dec 4, 2024
1 parent 2eea6bc commit 4ea5d9f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
16 changes: 9 additions & 7 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,11 +1218,15 @@ def __getitem__(self, key, *args, **kwargs):
"When indexing with a Tsd, it must contain boolean values"
)
key = key.d
elif (
isinstance(key, str)
or hasattr(key, "__iter__")
and all([isinstance(k, str) for k in key])
):
elif isinstance(key, str):
if key in self.columns:
with warnings.catch_warnings():
# ignore deprecated warning for loc
warnings.simplefilter("ignore")
return self.loc[key]
else:
return _MetadataMixin.__getitem__(self, key)
elif hasattr(key, "__iter__") and all([isinstance(k, str) for k in key]):
if all(k in self.columns for k in key):
with warnings.catch_warnings():
# ignore deprecated warning for loc
Expand Down Expand Up @@ -1272,8 +1276,6 @@ def __getitem__(self, key, *args, **kwargs):
return _get_class(output)(
t=index, d=output, time_support=self.time_support, **kwargs
)
# else:
# return output
else:
return output

Expand Down
16 changes: 11 additions & 5 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def __getitem__(self, key):
return _MetadataMixin.__getitem__(self, key)
else:
raise KeyError(r"Key {} not in group index.".format(key))
elif isinstance(key, list) and all(isinstance(k, str) for k in key):
elif isinstance(key, list) and len(key) and all(isinstance(k, str) for k in key):
# index multiple metadata columns
return _MetadataMixin.__getitem__(self, key)

Expand Down Expand Up @@ -389,11 +389,17 @@ def __repr__(self):
max_rows = np.maximum(rows - 10, 2)

# By default, the first three columns should always show.
col_names = self._metadata.columns.drop("rate")
col_names = self._metadata.columns
if "rate" in col_names:
col_names = col_names.drop("rate")

headers = ["Index", "rate"] + [c for c in col_names][0:max_cols]
end = ["..."] if len(headers) > max_cols else []
headers += end

if len(self) == 0:
return tabulate(tabular_data=[], headers=headers)

if len(self) > max_rows:
n_rows = max_rows // 2
ends = np.array([end] * n_rows)
Expand Down Expand Up @@ -530,7 +536,7 @@ def restrict(self, ep):
cols = self._metadata.columns.drop("rate")

return TsGroup(
newgr, time_support=ep, bypass_check=True, **self._metadata[cols]
newgr, time_support=ep, bypass_check=True, metadata=self._metadata[cols]
)

def value_from(self, tsd, ep=None):
Expand Down Expand Up @@ -576,7 +582,7 @@ def value_from(self, tsd, ep=None):
newgr[k] = self.data[k].value_from(tsd, ep)

cols = self._metadata.columns.drop("rate")
return TsGroup(newgr, time_support=ep, **self._metadata[cols])
return TsGroup(newgr, time_support=ep, metadata=self._metadata[cols])

def count(self, *args, dtype=None, **kwargs):
"""
Expand Down Expand Up @@ -867,7 +873,7 @@ def get(self, start, end=None, time_units="s"):
newgr,
time_support=self.time_support,
bypass_check=True,
**self._metadata[cols],
metadata=self._metadata[cols],
)

#################################
Expand Down
2 changes: 1 addition & 1 deletion pynapple/io/interface_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def _make_tsgroup(obj, **kwargs):
else:
pass

tsgroup = nap.TsGroup(tsgroup, **metainfo)
tsgroup = nap.TsGroup(tsgroup, metadata=metainfo)

return tsgroup

Expand Down
1 change: 1 addition & 0 deletions pynapple/process/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def decode_1d(tuning_curves, group, ep, bin_size, time_units="s", feature=None):

# Bin spikes
count = newgroup.count(bin_size, ep, time_units)
count = count.smooth(bin_size*3)

# Occupancy
if feature is None:
Expand Down

0 comments on commit 4ea5d9f

Please sign in to comment.