Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix on TsdFrame getitem #382

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
with:
directory: "./doc/_build/html"
# The directory to scan
arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/"
arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/"
# The arguments to pass to HTMLProofer


Expand Down
31 changes: 16 additions & 15 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,23 +1169,24 @@ def __getitem__(self, key, *args, **kwargs):
index = np.array([index])

if all(is_array_like(a) for a in [index, output]):
if (
(len(index) == 1)
and (output.ndim == 1)
and ((len(output) > 1) or isinstance(key[1], (list, np.ndarray)))
):
# reshape output of single index to preserve column axis if there are more than one columns being indexed
# or if column key is a list or array
if isinstance(key, tuple):
if (
len(index) == 1
and output.ndim == 1
and not isinstance(key[1], int)
):
output = output[None, :]
elif (
(output.ndim == 1)
and isinstance(key[1], (list, np.ndarray))
and (len(columns) == 1)
):
# reshape output of single column if column key is a list or array
output = output[:, None]
# if getting a row (1 dim implied)
elif isinstance(key, Number):
output = output[None, :]

elif (
(output.ndim == 1)
and isinstance(key[1], (list, np.ndarray))
and (len(columns) == 1)
):
# reshape output of single column if column key is a list or array
output = output[:, None]

kwargs["columns"] = columns
kwargs["metadata"] = self._metadata.loc[columns]
return _initialize_tsd_output(
Expand Down
8 changes: 8 additions & 0 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,7 @@ def test_interpolate_with_ep(self, tsd):
@pytest.mark.parametrize(
"tsdframe",
[
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 1), time_units="s"),
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), time_units="s"),
nap.TsdFrame(
t=np.arange(100),
Expand Down Expand Up @@ -997,6 +998,7 @@ def test_copy(self, tsdframe):
],
)
def test_horizontal_slicing(self, tsdframe, index, nap_type):
index = index if isinstance(index, int) else index[: tsdframe.shape[1]]
assert isinstance(tsdframe[:, index], nap_type)
np.testing.assert_array_almost_equal(
tsdframe[:, index].values, tsdframe.values[:, index]
Expand Down Expand Up @@ -1067,6 +1069,12 @@ def test_vertical_slicing(self, tsdframe, index):
],
)
def test_vert_and_horz_slicing(self, tsdframe, row, col, expected):
if tsdframe.shape[1] == 1:
if isinstance(col, list) and isinstance(col[0], int):
col = [0]
elif isinstance(col, list) and isinstance(col[0], bool):
col = [col[0]]

# get details about row index
row_array = isinstance(row, (list, np.ndarray))
if row_array and isinstance(row[0], (bool, np.bool_)):
Expand Down
Loading