diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8c4f5482..5b0e8983 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -85,7 +85,7 @@ jobs: with: directory: "./doc/_build/html" # The directory to scan - arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/" + arguments: --checks Links,Scripts --ignore-urls "https://fonts.gstatic.com,https://mkdocs-gallery.github.io,./doc/_build/html/_static/,https://www.nature.com/articles/s41593-022-01020-w" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/html\/_static\/.+/" # The arguments to pass to HTMLProofer diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 72d33d5a..fbb3d9ee 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1169,23 +1169,24 @@ def __getitem__(self, key, *args, **kwargs): index = np.array([index]) if all(is_array_like(a) for a in [index, output]): - if ( - (len(index) == 1) - and (output.ndim == 1) - and ((len(output) > 1) or isinstance(key[1], (list, np.ndarray))) - ): - # reshape output of single index to preserve column axis if there are more than one columns being indexed - # or if column key is a list or array + if isinstance(key, tuple): + if ( + len(index) == 1 + and output.ndim == 1 + and not isinstance(key[1], int) + ): + output = output[None, :] + elif ( + (output.ndim == 1) + and isinstance(key[1], (list, np.ndarray)) + and (len(columns) == 1) + ): + # reshape output of single column if column key is a list or array + output = output[:, None] + # if getting a row (1 dim implied) + elif isinstance(key, Number): output = output[None, :] - elif ( - (output.ndim == 1) - and isinstance(key[1], (list, np.ndarray)) - and (len(columns) == 1) - ): - # reshape output of single column if column key is a list or array - output = output[:, None] - kwargs["columns"] = columns kwargs["metadata"] = self._metadata.loc[columns] return _initialize_tsd_output( diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 0c5e9289..354104ed 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -954,6 +954,7 @@ def test_interpolate_with_ep(self, tsd): @pytest.mark.parametrize( "tsdframe", [ + nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 1), time_units="s"), nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), time_units="s"), nap.TsdFrame( t=np.arange(100), @@ -997,6 +998,7 @@ def test_copy(self, tsdframe): ], ) def test_horizontal_slicing(self, tsdframe, index, nap_type): + index = index if isinstance(index, int) else index[: tsdframe.shape[1]] assert isinstance(tsdframe[:, index], nap_type) np.testing.assert_array_almost_equal( tsdframe[:, index].values, tsdframe.values[:, index] @@ -1067,6 +1069,12 @@ def test_vertical_slicing(self, tsdframe, index): ], ) def test_vert_and_horz_slicing(self, tsdframe, row, col, expected): + if tsdframe.shape[1] == 1: + if isinstance(col, list) and isinstance(col[0], int): + col = [0] + elif isinstance(col, list) and isinstance(col[0], bool): + col = [col[0]] + # get details about row index row_array = isinstance(row, (list, np.ndarray)) if row_array and isinstance(row[0], (bool, np.bool_)):