Skip to content

Commit

Permalink
Fix multilevel groupby with task shuffle (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Dec 19, 2023
1 parent 82f58e6 commit 6d33d2d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 1 deletion.
11 changes: 11 additions & 0 deletions dask_expr/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,23 @@ def by(self):
def levels(self):
return _determine_levels(self.by)

@property
def shuffle_by_index(self):
return True


class GroupByChunk(Chunk, GroupByBase):
@functools.cached_property
def _args(self) -> list:
return [self.frame] + self.by

@functools.cached_property
def _meta(self):
args = [
meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args
]
return make_meta(self.operation(*args, **self._kwargs))


class GroupByApplyConcatApply(ApplyConcatApply, GroupByBase):
_chunk_cls = GroupByChunk
Expand Down
6 changes: 6 additions & 0 deletions dask_expr/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,13 @@ class ShuffleReduce(Expr):
"split_every",
"split_out",
"sort",
"shuffle_by_index",
]
_defaults = {
"split_every": 8,
"split_out": True,
"sort": None,
"shuffle_by_index": None,
}

@property
Expand Down Expand Up @@ -182,6 +184,8 @@ def _lower(self):
# Sort or shuffle
split_every = getattr(self, "split_every", 0) or chunked.npartitions
ignore_index = getattr(self, "ignore_index", True)
if self.shuffle_by_index is not None:
ignore_index = not self.shuffle_by_index
shuffle_npartitions = max(
chunked.npartitions // split_every,
self.split_out,
Expand All @@ -199,6 +203,7 @@ def _lower(self):
split_by,
shuffle_npartitions,
ignore_index=ignore_index,
index_shuffle=not split_by_index and self.shuffle_by_index,
)

# Unmap column names if necessary
Expand Down Expand Up @@ -464,6 +469,7 @@ def _lower(self):
split_out=self.split_out,
split_every=split_every,
sort=sort,
shuffle_by_index=getattr(self, "shuffle_by_index", None),
)


Expand Down
3 changes: 2 additions & 1 deletion dask_expr/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,9 @@ def operation(df, index, name: str, npartitions: int, assign_index):
if assign_index:
# columns take precedence over index in _select_columns_or_index, so
# circumvent that, to_frame doesn't work because it loses the index
names = index
index = df[[]]
index["_index"] = df.index
index[names] = df.index.to_frame()
else:
index = _select_columns_or_index(df, index)

Expand Down
10 changes: 10 additions & 0 deletions dask_expr/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,16 @@ def test_groupby_group_keys(group_keys, pdf):
)


def test_dataframe_aggregations_multilevel(df, pdf):
grouper = lambda df: [df["x"] > 2, df["y"] > 1]

with dask.config.set({"dataframe.shuffle.method": "tasks"}):
assert_eq(
pdf.groupby(grouper(pdf)).sum(),
df.groupby(grouper(df)).sum(split_out=2),
)


@pytest.mark.parametrize(
"spec",
[
Expand Down

0 comments on commit 6d33d2d

Please sign in to comment.