Skip to content

Commit

Permalink
Implement .where (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Dec 5, 2023
1 parent 37f24cb commit 9779dcc
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ API Coverage
- `map`
- `map_overlap`
- `map_partitions`
- `mask`
- `max`
- `mean`
- `memory_usage`
Expand Down Expand Up @@ -130,6 +131,7 @@ API Coverage
- `to_timestamp`
- `var`
- `visualize`
- `where`


**`dask_expr.Series`**
Expand Down Expand Up @@ -160,6 +162,7 @@ API Coverage
- `isna`
- `map`
- `map_partitions`
- `mask`
- `max`
- `mean`
- `memory_usage`
Expand Down Expand Up @@ -188,6 +191,7 @@ API Coverage
- `value_counts`
- `var`
- `visualize`
- `where`


**`dask_expr.Index`**
Expand Down
5 changes: 5 additions & 0 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,11 @@ def mask(self, cond, other=np.nan):
def round(self, decimals=0):
    """Round this collection by delegating to the underlying expression.

    ``decimals`` is forwarded unchanged to ``self.expr.round`` and the
    resulting expression is wrapped with ``new_collection``.
    """
    rounded_expr = self.expr.round(decimals)
    return new_collection(rounded_expr)

def where(self, cond, other=np.nan):
    """Build a ``where`` expression and wrap it as a new collection.

    ``cond`` and ``other`` may each be a ``FrameBase`` collection, in which
    case the wrapped ``.expr`` is used instead. Any other value is forwarded
    to ``self.expr.where`` untouched.
    """
    # Unwrap collections so the expression layer only sees expressions
    # or plain values.
    if isinstance(cond, FrameBase):
        cond = cond.expr
    if isinstance(other, FrameBase):
        other = other.expr

    where_expr = self.expr.where(cond, other)
    return new_collection(where_expr)

def apply(self, function, *args, **kwargs):
    """Apply ``function`` through the underlying expression.

    Positional and keyword arguments are passed through to
    ``self.expr.apply`` as given, and the resulting expression is wrapped
    with ``new_collection``.
    """
    applied_expr = self.expr.apply(function, *args, **kwargs)
    return new_collection(applied_expr)

Expand Down
10 changes: 10 additions & 0 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,9 @@ def mask(self, cond, other=np.nan):
def round(self, decimals=0):
    """Return a ``Round`` expression built on top of this expression.

    ``decimals`` is handed to ``Round`` as a keyword argument.
    """
    round_expr = Round(self, decimals=decimals)
    return round_expr

def where(self, cond, other=np.nan):
    """Return a ``Where`` expression built on top of this expression.

    ``cond`` and ``other`` are handed to ``Where`` as keyword arguments
    without any conversion.
    """
    where_expr = Where(self, cond=cond, other=other)
    return where_expr

def apply(self, function, *args, **kwargs):
    """Return an ``Apply`` expression built on top of this expression.

    The collected positional arguments (a tuple) and keyword arguments
    (a dict) are passed to ``Apply`` as two plain positional operands,
    not re-expanded.
    """
    apply_expr = Apply(self, function, args, kwargs)
    return apply_expr

Expand Down Expand Up @@ -1766,6 +1769,13 @@ class Round(Elemwise):
operation = M.round


class Where(Elemwise):
    """Expression node for ``where``, declared as an ``Elemwise`` subclass.

    Operands are ``frame``, ``cond`` and ``other``; ``other`` falls back to
    ``np.nan`` when it is not supplied.
    """

    # NOTE(review): presumably lets column projections pass through this
    # node to its input; confirm against the Elemwise base class.
    _projection_passthrough = True
    _parameters = [
        "frame",
        "cond",
        "other",
    ]
    _defaults = dict(other=np.nan)
    # NOTE(review): ``M.where`` is presumably a method-caller that invokes
    # ``.where`` on the first operand; verify where ``M`` is imported.
    operation = M.where


class Abs(Elemwise):
_projection_passthrough = True
_parameters = ["frame"]
Expand Down
6 changes: 6 additions & 0 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,12 @@ def test_to_timestamp(pdf, how):
lambda df: df.x.mask(df.x == 10, 42),
lambda df: df.abs(),
lambda df: df.x.abs(),
lambda df: df.where(df.x == 10, 42),
lambda df: df.where(df.x == 10),
lambda df: df.where(lambda df: df.x % 2 == 0, 42),
lambda df: df.where(df.x == 10, df + 2),
lambda df: df.where(df.x == 10, lambda df: df + 2),
lambda df: df.x.where(df.x == 10, 42),
lambda df: df.rename(columns={"x": "xx"}),
lambda df: df.rename(columns={"x": "xx"}).xx,
lambda df: df.rename(columns={"x": "xx"})[["xx"]],
Expand Down

0 comments on commit 9779dcc

Please sign in to comment.