Skip to content

Commit

Permalink
Allow expressions to be shipped to the scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 6, 2023
1 parent ded3cb8 commit 75200fb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
28 changes: 17 additions & 11 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,20 @@ def _wrap_unary_expr_op(self, op=None):
#
# Collection classes
#


class FrameBase(DaskMethodsMixin):
from dask.typing import NewDaskCollection
# Note: subclassing isn't required. This is just for the prototype to have a
# check for abstractmethods but the runtime checks for duck-typing/protocol only
class FrameBase(DaskMethodsMixin, NewDaskCollection):
"""Base class for Expr-backed Collections"""

__dask_scheduler__ = staticmethod(
named_schedulers.get("threads", named_schedulers["sync"])
)
__dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk)

def __dask_tokenize__(self):
return self.expr._name

def __init__(self, expr):
self._expr = expr

Expand Down Expand Up @@ -177,14 +181,7 @@ def compute(self, fuse=True, combine_similar=True, **kwargs):
return DaskMethodsMixin.compute(out, **kwargs)

def __dask_graph__(self):
out = self.expr
out = out.lower_completely()
return out.__dask_graph__()

def __dask_keys__(self):
out = self.expr
out = out.lower_completely()
return out.__dask_keys__()
return self.expr

def simplify(self):
return new_collection(self.expr.simplify())
Expand All @@ -201,6 +198,15 @@ def optimize(self, combine_similar: bool = True, fuse: bool = True):
def dask(self):
return self.__dask_graph__()

def finalize_compute(self) -> FrameBase:
from ._repartition import RepartitionToFewer
if self.npartitions > 1:
return new_collection(RepartitionToFewer(self.expr, 1))
return self

def postpersist(self, futures: dict) -> NewDaskCollection:
return from_graph(futures, self._meta, self.divisions, self._name)

def __dask_postcompute__(self):
state = new_collection(self.expr.lower_completely())
if type(self) != type(state):
Expand Down
9 changes: 7 additions & 2 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
replacement_rules = []


class Expr:
from dask.typing import DaskGraph
# Note: subclassing isn't required. This is just for the prototype to have a
# check for abstractmethods but the runtime checks for duck-typing/protocol only
class Expr(DaskGraph):
"""Primary class for all Expressions
This mostly includes Dask protocols and various Pandas-like method
Expand Down Expand Up @@ -812,8 +815,10 @@ def dtypes(self):
def _meta(self):
raise NotImplementedError()

def __dask_graph__(self):
def _materialize(self):
"""Traverse expression tree, collect layers"""
from distributed.scheduler import ensure_materialization_allowed
ensure_materialization_allowed()
stack = [self]
seen = set()
layers = []
Expand Down

0 comments on commit 75200fb

Please sign in to comment.