diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 99f637e2e..93b2bbc1a 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -110,9 +110,10 @@ def _wrap_unary_expr_op(self, op=None): # # Collection classes # - - -class FrameBase(DaskMethodsMixin): +from dask.typing import NewDaskCollection +# Note: subclassing isn't required. This is just for the prototype to have a +# check for abstractmethods but the runtime checks for duck-typing/protocol only +class FrameBase(DaskMethodsMixin, NewDaskCollection): """Base class for Expr-backed Collections""" __dask_scheduler__ = staticmethod( @@ -120,6 +121,9 @@ class FrameBase(DaskMethodsMixin): ) __dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk) + def __dask_tokenize__(self): + return self.expr._name + def __init__(self, expr): self._expr = expr @@ -177,14 +181,7 @@ def compute(self, fuse=True, combine_similar=True, **kwargs): return DaskMethodsMixin.compute(out, **kwargs) def __dask_graph__(self): - out = self.expr - out = out.lower_completely() - return out.__dask_graph__() - - def __dask_keys__(self): - out = self.expr - out = out.lower_completely() - return out.__dask_keys__() + return self.expr def simplify(self): return new_collection(self.expr.simplify()) @@ -201,6 +198,15 @@ def optimize(self, combine_similar: bool = True, fuse: bool = True): def dask(self): return self.__dask_graph__() + def finalize_compute(self) -> FrameBase: + from ._repartition import RepartitionToFewer + if self.npartitions > 1: + return new_collection(RepartitionToFewer(self.expr, 1)) + return self + + def postpersist(self, futures: dict) -> NewDaskCollection: + return from_graph(futures, self._meta, self.divisions, self._name) + def __dask_postcompute__(self): state = new_collection(self.expr.lower_completely()) if type(self) != type(state): diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index a737fa677..ac70517f3 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -36,7 +36,10 @@ replacement_rules = [] -class Expr: +from dask.typing import DaskGraph +# Note: subclassing isn't required. This is just for the prototype to have a +# check for abstractmethods but the runtime checks for duck-typing/protocol only +class Expr(DaskGraph): """Primary class for all Expressions This mostly includes Dask protocols and various Pandas-like method @@ -812,8 +815,10 @@ def dtypes(self): def _meta(self): raise NotImplementedError() - def __dask_graph__(self): + def _materialize(self): """Traverse expression tree, collect layers""" + from distributed.scheduler import ensure_materialization_allowed + ensure_materialization_allowed() stack = [self] seen = set() layers = []