Skip to content

Commit

Permalink
Update to keep track of refs
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Feb 15, 2024
1 parent babb1ee commit 97029db
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class Expr:
# NOTE(review): presumably the operand names accepted by this expression type;
# not exercised in the visible code — confirm against subclasses.
_parameters = []
# NOTE(review): presumably default values for entries of `_parameters`;
# not exercised in the visible code — confirm against subclasses.
_defaults = {}
# Registry of constructed expressions keyed by `_name` (filled in `__new__`).
# Values are held weakly, so an entry vanishes once nothing else references
# the instance.
_instances = weakref.WeakValueDictionary()
# Maps a dependency's `_name` to the list of expressions that depend on it.
# `__new__` appends to it; `_dependent_graph` replaces it with a fresh mapping.
_dependents = defaultdict(list)
# `_name`s already registered in `_dependents`; `__new__` checks this to avoid
# registering the same expression as a dependent twice.
_seen = set()

def __new__(cls, *args, **kwargs):
operands = list(args)
Expand Down Expand Up @@ -68,16 +70,41 @@ def __new__(cls, *args, **kwargs):
inst._graph[_name].update(children)
# Probably a bad idea to have a self ref
inst._graph_instances[_name] = inst

else:
Expr._instances[_name] = inst
inst._graph_instances = merge(_graph_instances, *_subgraph_instances)
inst._graph = merge(*_subgraphs)
inst._graph[_name] = children
# Probably a bad idea to have a self ref
inst._graph_instances[_name] = inst

if inst._name in Expr._seen:
# We already registered inst as a dependent of all its
# dependencies, so we don't need to do it again
return inst

Expr._instances[_name] = inst
inst._graph_instances = merge(_graph_instances, *_subgraph_instances)
inst._graph = merge(*_subgraphs)
inst._graph[_name] = children
# Probably a bad idea to have a self ref
inst._graph_instances[_name] = inst
Expr._seen.add(inst._name)
for dep in inst.dependencies():
Expr._dependents[dep._name].append(inst)

return inst

@functools.cached_property
def _dependent_graph(self):
    """Rebuild and return the dependency-name -> dependents mapping.

    The class-level tracking state (``Expr._dependents`` and ``Expr._seen``)
    is discarded and replaced; the fresh ``Expr._dependents`` mapping is then
    filled from this expression's ``_graph`` and returned. Each value is a
    list of the instances (looked up in ``_graph_instances``) that depend on
    the key.
    """
    # Throw away whatever was tracked so far and start from empty containers.
    dependents = Expr._dependents = defaultdict(list)
    Expr._seen = set()

    # Invert the graph edges in a single pass: O(E).
    reverse_edges = defaultdict(set)
    for node_name, node_deps in self._graph.items():
        for dep_name in node_deps:
            reverse_edges[dep_name].add(node_name)

    # Swap names for the actual expression instances.
    for dep_name, node_names in reverse_edges.items():
        resolved = []
        for node_name in node_names:
            resolved.append(self._graph_instances[node_name])
        dependents[dep_name] = resolved
    return dependents

def __hash__(self):
raise TypeError(
"Expr objects can't be used in sets or dicts or similar, use the _name instead"
Expand Down Expand Up @@ -186,18 +213,6 @@ def dependencies(self):
# Dependencies are `Expr` operands only
return [operand for operand in self.operands if isinstance(operand, Expr)]

@functools.cached_property
def _dependent_graph(self):
    """Return a mapping from each node name to the set of its dependents.

    Every name that appears as a key of ``_graph`` gets an entry, even when
    nothing depends on it (its set is then empty). Names are resolved to the
    objects stored in ``_graph_instances`` before the mapping is returned.
    """
    reverse = defaultdict(set)
    # Invert the graph edges in a single pass: O(E).
    for node_name, node_deps in self._graph.items():
        # Make sure the node itself is present, dependents or not.
        if node_name not in reverse:
            reverse[node_name] = set()
        for dep_name in node_deps:
            reverse[dep_name].add(node_name)

    # Swap names for instances; only existing keys are reassigned, so the
    # mapping does not change size while it is being iterated.
    for key, member_names in reverse.items():
        resolved = set()
        for member in member_names:
            resolved.add(self._graph_instances[member])
        reverse[key] = resolved
    return reverse

def dependents(self):
    """Return the reverse-dependency mapping held in ``_dependent_graph``."""
    graph = self._dependent_graph
    return graph

Expand Down Expand Up @@ -369,8 +384,6 @@ def simplify_once(self, dependents: defaultdict, simplified: dict):
changed = False
for operand in expr.operands:
if isinstance(operand, Expr):
# # Bandaid for now, waiting for Singleton
dependents[operand._name].add(expr)
new = operand.simplify_once(
dependents=dependents, simplified=simplified
)
Expand Down

0 comments on commit 97029db

Please sign in to comment.