Commit 6b6a576

fix
fjetter committed Nov 11, 2024
1 parent 2c9e45b commit 6b6a576
Showing 3 changed files with 37 additions and 14 deletions.
39 changes: 27 additions & 12 deletions dask_expr/_expr.py
@@ -3766,28 +3766,43 @@ def _broadcast_dep(self, dep: Expr):
         return dep.npartitions == 1
 
     def _task(self, name: Key, index: int) -> Task:
-        subgraphs = {}
+        internal_tasks = []
+        seen_keys = set()
+        external_deps = set()
         for _expr in self.exprs:
             if self._broadcast_dep(_expr):
                 subname = (_expr._name, 0)
             else:
                 subname = (_expr._name, index)
-            subgraphs[subname] = _expr._task(subname, subname[1])
-
-        for i, dep in enumerate(self.dependencies()):
-            subgraphs[self._blockwise_arg(dep, index)] = "_" + str(i)
-
-        return Task(
+            t = _expr._task(subname, subname[1])
+            assert t.key == subname
+            internal_tasks.append(t)
+            seen_keys.add(subname)
+            external_deps.update(t.dependencies)
+        external_deps -= seen_keys
+        dependencies = {
+            dep: TaskRef(dep)
+            # Note: the method dependencies isn't strictly needed here. We could
+            # also get the list of external dependencies when iterating over the
+            # subgraphs above.
+            for dep in external_deps
+        }
+        t = Task(
             name,
-            Fused._execute_subgraph,
-            DataNode(None, subgraphs),
+            Fused._execute_internal_graph,
+            # Wrap the actual subgraph as a data node such that the tasks are
+            # not erroneously parsed. The external task would otherwise carry
+            # the internal keys as dependencies, which is not satisfiable.
+            DataNode(None, internal_tasks),
+            dependencies,
             (self.exprs[0]._name, index),
         )
+        return t
 
     @staticmethod
-    def _execute_subgraph(dsk, outkey):
-        dsk = dict(dsk)
-        res = execute_graph(dsk, keys=[outkey])
+    def _execute_internal_graph(internal_tasks, dependencies, outkey):
+        cache = dict(dependencies)
+        res = execute_graph(internal_tasks, cache=cache, keys=[outkey])
         return res[outkey]


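For context, here is a minimal sketch of the execution model that the new Fused._execute_internal_graph relies on. It is not part of the commit: the keys, lambdas, and values are made up, and it only assumes the Task/TaskRef/execute_graph behaviour from dask._task_spec that the diff itself uses (external dependencies are satisfied through the cache argument; internal results never enter the outer graph).

from dask._task_spec import Task, TaskRef, execute_graph

# Two "internal" tasks of a hypothetical fused block: ("add", 0) depends on
# ("inc", 0) plus one key, ("external", 0), that is produced outside the block.
internal_tasks = [
    Task(("inc", 0), lambda x: x + 1, TaskRef(("external", 0))),
    Task(("add", 0), lambda a, b: a + b, TaskRef(("inc", 0)), TaskRef(("external", 0))),
]

# External dependencies arrive through the cache, mirroring
# Fused._execute_internal_graph above; the internal key ("inc", 0) stays private.
cache = {("external", 0): 10}
res = execute_graph(internal_tasks, cache=cache, keys=[("add", 0)])
assert res[("add", 0)] == 21  # (10 + 1) + 10

Because the internal tasks travel inside a DataNode payload and only the external keys are advertised via TaskRef, the scheduler sees a single fused task whose dependencies are all resolvable.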
7 changes: 6 additions & 1 deletion dask_expr/_indexing.py
@@ -173,7 +173,12 @@ def _layer_cache(self):
         return convert_legacy_graph(self._layer())
 
     def _task(self, name: Key, index: int) -> Task:
-        return self._layer_cache[(self._name, index)]
+        t = self._layer_cache[(self._name, index)]
+        if isinstance(t, Alias):
+            return Alias(name, t.target)
+        elif t.key != name:
+            return Task(name, lambda x: x, t)
+        return t
 
 
 class LocUnknown(Blockwise):
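For illustration, a minimal sketch of the Alias branch added above, with made-up keys (not from the commit): an Alias is only a pointer to another key, so renaming it means re-creating it with the requested key and the same target.

from dask._task_spec import Alias

# The cached layer may publish its result under a stale key ...
cached = Alias(("loc-old", 0), ("getitem", 0))
# ... so it is re-pointed to the same target under the requested name,
# mirroring the isinstance(t, Alias) branch above.
renamed = Alias(("loc-new", 0), cached.target)
assert renamed.key == ("loc-new", 0)
assert renamed.target == cached.target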
5 changes: 4 additions & 1 deletion dask_expr/io/csv.py
@@ -127,7 +127,10 @@ def _tasks(self):
     def _filtered_task(self, name: Key, index: int) -> Task:
         if self._series:
             return Task(name, operator.getitem, self._tasks[index], self.columns[0])
-        return self._tasks[index]
+        t = self._tasks[index]
+        if t.key != name:
+            return Task(name, lambda x: x, t)
+        return t
 
 
 class ReadTable(ReadCSV):
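The _indexing.py and io/csv.py hunks share the same fallback: when the pre-built task is keyed under a different name than the one requested, it is wrapped in an identity task so its result is re-published under the requested key. A minimal sketch with made-up keys (not from the commit), assuming nested tasks execute inline when the wrapping task runs, which is what the diff relies on:

import operator

from dask._task_spec import Task, execute_graph

inner = Task(("read-block", 3), operator.add, 1, 2)   # built under its own key
requested = ("filtered-read-block", 3)                # key the caller asked for

# Wrap only when the keys disagree; the lambda is a plain identity function.
wrapped = inner if inner.key == requested else Task(requested, lambda x: x, inner)
assert wrapped.key == requested

res = execute_graph([wrapped], keys=[requested])
assert res[requested] == 3  # the nested task ran inline and its value passed through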
