Skip to content

Commit

Permalink
Fix fusion calling things multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Nov 11, 2024
1 parent 2205ad8 commit 2c9e45b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
16 changes: 13 additions & 3 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import pandas as pd
from dask._task_spec import Alias, DataNode, Task, TaskRef
from dask._task_spec import Alias, DataNode, Task, TaskRef, execute_graph
from dask.array import Array
from dask.core import flatten
from dask.dataframe import methods
Expand Down Expand Up @@ -3777,8 +3777,18 @@ def _task(self, name: Key, index: int) -> Task:
for i, dep in enumerate(self.dependencies()):
subgraphs[self._blockwise_arg(dep, index)] = "_" + str(i)

result = subgraphs.pop((self.exprs[0]._name, index))
return result.inline(subgraphs)
return Task(
name,
Fused._execute_subgraph,
DataNode(None, subgraphs),
(self.exprs[0]._name, index),
)

@staticmethod
def _execute_subgraph(dsk, outkey):
dsk = dict(dsk)
res = execute_graph(dsk, keys=[outkey])
return res[outkey]


# Used for sorting with None
Expand Down
15 changes: 15 additions & 0 deletions dask_expr/tests/test_fusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask.dataframe as dd
import pytest

from dask_expr import from_pandas, optimize
Expand Down Expand Up @@ -128,3 +129,17 @@ def test_name(df):
assert "getitem" in str(fused.expr)
assert "sub" in str(fused.expr)
assert str(fused.expr) == str(fused.expr).lower()


def test_fusion_executes_only_once():
times_called = []
import pandas as pd

def test(i):
times_called.append(i)
return pd.DataFrame({"a": [1, 2, 3], "b": 1})

df = dd.from_map(test, [1], meta=[("a", "i8"), ("b", "i8")])
df = df[df.a > 1]
df.sum().compute()
assert len(times_called) == 1

0 comments on commit 2c9e45b

Please sign in to comment.