diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index f24d6a75..47bba418 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -2197,6 +2197,7 @@ def optimize(expr: Expr, combine_similar: bool = True, fuse: bool = True) -> Exp
         result = result.combine_similar()
 
     if fuse:
+        result = optimize_io_fusion(result)
         result = optimize_blockwise_fusion(result)
 
     return result
@@ -2239,6 +2240,40 @@ def are_co_aligned(*exprs):
 ## Utilites for Expr fusion
 
 
+def optimize_io_fusion(expr):
+    """Traverse the expression graph and apply fusion to the I/O layer that squashes
+    partitions together if possible."""
+
+    def _fusion_pass(expr):
+        new_operands = []
+        changed = False
+        for operand in expr.operands:
+            if isinstance(operand, Expr):
+                if (
+                    isinstance(operand, BlockwiseIO)
+                    and operand._fusion_compression_factor < 1
+                ):
+                    new = FusedIO(operand)
+                elif isinstance(operand, BlockwiseIO):
+                    new = operand
+                else:
+                    new = _fusion_pass(operand)
+
+                if new._name != operand._name:
+                    changed = True
+            else:
+                new = operand
+            new_operands.append(new)
+
+        if changed:
+            expr = type(expr)(*new_operands)
+
+        return expr
+
+    expr = _fusion_pass(expr)
+    return expr
+
+
 def optimize_blockwise_fusion(expr):
     """Traverse the expression graph and apply fusion"""
 
@@ -2472,4 +2507,4 @@ def _execute_task(graph, name, *deps):
     Sum,
     Var,
 )
-from dask_expr.io import IO, BlockwiseIO
+from dask_expr.io import IO, BlockwiseIO, FusedIO
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 954da39f..876a2278 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -3,6 +3,7 @@
 import functools
 import math
 
+from dask.dataframe import methods
 from dask.dataframe.core import is_dataframe_like
 from dask.dataframe.io.io import sorted_division_locations
 
@@ -50,6 +51,10 @@ def _layer(self):
 class BlockwiseIO(Blockwise, IO):
     _absorb_projections = False
 
+    @functools.cached_property
+    def _fusion_compression_factor(self):
+        return 1
+
     def _simplify_up(self, parent):
         if (
             self._absorb_projections
@@ -121,6 +126,37 @@ def _combine_similar(self, root: Expr):
         return
 
 
+class FusedIO(BlockwiseIO):
+    _parameters = ["expr"]
+
+    @functools.cached_property
+    def _meta(self):
+        return self.operand("expr")._meta
+
+    @functools.cached_property
+    def npartitions(self):
+        return len(self._fusion_buckets)
+
+    def _divisions(self):
+        divisions = self.operand("expr")._divisions()
+        new_divisions = [divisions[b[0]] for b in self._fusion_buckets]
+        new_divisions.append(divisions[-1])
+        return tuple(new_divisions)
+
+    def _task(self, index: int):
+        expr = self.operand("expr")
+        bucket = self._fusion_buckets[index]
+        return (methods.concat, [expr._filtered_task(i) for i in bucket])
+
+    @functools.cached_property
+    def _fusion_buckets(self):
+        step = math.ceil(1 / self.operand("expr")._fusion_compression_factor)
+        partitions = self.operand("expr")._partitions
+        npartitions = len(partitions)
+        buckets = [partitions[i : i + step] for i in range(0, npartitions, step)]
+        return buckets
+
+
 class FromPandas(PartitionsFiltered, BlockwiseIO):
     """The only way today to get a real dataframe"""
 
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 4bdbd2c6..0ffb0ecf 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import contextlib
+import functools
 import itertools
 import operator
 import warnings
@@ -642,6 +643,13 @@ def _update_length_statistics(self):
             stat["num-rows"] for stat in _collect_pq_statistics(self)
         )
 
+    @functools.cached_property
+    def _fusion_compression_factor(self):
+        if self.operand("columns") is None:
+            return 1
+        nr_original_columns = len(self._dataset_info["schema"].names) - 1
+        return len(_convert_to_list(self.operand("columns"))) / nr_original_columns
+
 
 #
 # Helper functions
diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py
index ca00d906..539c6ec3 100644
--- a/dask_expr/io/tests/test_io.py
+++ b/dask_expr/io/tests/test_io.py
@@ -164,6 +164,13 @@ def test_predicate_pushdown_compound(tmpdir):
     assert_eq(y, z)
 
 
+def test_io_fusion_blockwise(tmpdir):
+    pdf = lib.DataFrame({c: range(10) for c in "abcdefghijklmn"})
+    dd.from_pandas(pdf, 2).to_parquet(tmpdir)
+    df = read_parquet(tmpdir)["a"].fillna(10).optimize()
+    assert df.npartitions == 1
+
+
 @pytest.mark.parametrize("fmt", ["parquet", "csv", "pandas"])
 def test_io_culling(tmpdir, fmt):
     pdf = lib.DataFrame({c: range(10) for c in "abcde"})
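
Note on the fusion arithmetic, as a minimal standalone sketch (the `fusion_buckets` helper name below is invented for illustration and is not part of the diff). `ReadParquet._fusion_compression_factor` divides the number of selected columns by the number of on-disk data columns (the `- 1` appears to discount the stored index column from the schema), and `FusedIO._fusion_buckets` turns that factor into a bucket size:

```python
import math


def fusion_buckets(partitions, compression_factor):
    # Mirrors FusedIO._fusion_buckets: the smaller the compression factor
    # (the fraction of on-disk columns actually read), the more input
    # partitions get squashed into each fused output partition.
    step = math.ceil(1 / compression_factor)
    return [partitions[i : i + step] for i in range(0, len(partitions), step)]


# test_io_fusion_blockwise writes 14 data columns and selects one, so the
# reader reports a factor of 1/14, step = ceil(14) = 14, and both input
# partitions collapse into a single fused partition (npartitions == 1).
assert fusion_buckets([0, 1], 1 / 14) == [[0, 1]]

# With half the columns selected, step = 2: ten partitions become five.
assert fusion_buckets(list(range(10)), 0.5) == [
    [0, 1], [2, 3], [4, 5], [6, 7], [8, 9],
]
```

`FusedIO.npartitions` is then just the number of buckets, and each fused task concatenates the bucket's original partition tasks with `methods.concat`.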