Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 6, 2023
1 parent 75200fb commit 2aff647
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 19 deletions.
17 changes: 5 additions & 12 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,9 @@ def _wrap_unary_expr_op(self, op=None):
#
# Collection classes
#
from dask.typing import NewDaskCollection
# Note: subclassing isn't required. It is done here only so that the prototype
# gets a check for abstract methods; the runtime checks rely on duck typing (the protocol) only.
class FrameBase(DaskMethodsMixin, NewDaskCollection):


class FrameBase(DaskMethodsMixin):
"""Base class for Expr-backed Collections"""

__dask_scheduler__ = staticmethod(
Expand Down Expand Up @@ -198,14 +197,8 @@ def optimize(self, combine_similar: bool = True, fuse: bool = True):
def dask(self):
return self.__dask_graph__()

def finalize_compute(self) -> FrameBase:
from ._repartition import RepartitionToFewer
if self.npartitions > 1:
return new_collection(RepartitionToFewer(self.expr, 1))
return self

def postpersist(self, futures: dict) -> NewDaskCollection:
return from_graph(futures, self._meta, self.divisions, self._name)
def finalize_compute(self):
return new_collection(Repartition(self.expr, 1))

def __dask_postcompute__(self):
state = new_collection(self.expr.lower_completely())
Expand Down
9 changes: 2 additions & 7 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@
replacement_rules = []


from dask.typing import DaskGraph
# Note: subclassing isn't required. It is done here only so that the prototype
# gets a check for abstract methods; the runtime checks rely on duck typing (the protocol) only.
class Expr(DaskGraph):
class Expr:
"""Primary class for all Expressions
This mostly includes Dask protocols and various Pandas-like method
Expand Down Expand Up @@ -815,10 +812,8 @@ def dtypes(self):
def _meta(self):
raise NotImplementedError()

def _materialize(self):
def materialize(self):
"""Traverse expression tree, collect layers"""
from distributed.scheduler import ensure_materialization_allowed
ensure_materialization_allowed()
stack = [self]
seen = set()
layers = []
Expand Down

0 comments on commit 2aff647

Please sign in to comment.