Skip to content

Commit

Permalink
Merge branch 'chdir' into dependents_wout_weakrefs
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Feb 14, 2024
2 parents 1faa475 + bda2577 commit babb1ee
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 3 deletions.
19 changes: 19 additions & 0 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import functools
import inspect
import os
import warnings
from collections.abc import Callable, Hashable, Mapping
from numbers import Integral, Number
Expand Down Expand Up @@ -4526,6 +4527,7 @@ def read_csv(
storage_options=storage_options,
kwargs=kwargs,
header=header,
_cwd=_get_cwd(path, kwargs),
)
)

Expand All @@ -4551,6 +4553,7 @@ def read_table(
storage_options=storage_options,
kwargs=kwargs,
header=header,
_cwd=_get_cwd(path, kwargs),
)
)

Expand All @@ -4576,10 +4579,25 @@ def read_fwf(
storage_options=storage_options,
kwargs=kwargs,
header=header,
_cwd=_get_cwd(path, kwargs),
)
)


def _get_protocol(urlpath):
if "://" in urlpath:
protocol, _ = urlpath.split("://", 1)
if len(protocol) > 1:
# excludes Windows paths
return protocol
return None


def _get_cwd(path, kwargs):
protocol = kwargs.pop("protocol", None) or _get_protocol(path) or "file"
return os.getcwd() if protocol == "file" else None


def read_parquet(
path=None,
columns=None,
Expand Down Expand Up @@ -4630,6 +4648,7 @@ def read_parquet(
filesystem=filesystem,
engine=_set_parquet_engine(engine),
kwargs=kwargs,
_cwd=_get_cwd(path, kwargs),
_series=isinstance(columns, str),
)
)
Expand Down
2 changes: 2 additions & 0 deletions dask_expr/io/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ReadCSV(PartitionsFiltered, BlockwiseIO):
"_partitions",
"storage_options",
"kwargs",
"_cwd", # needed for tokenization
"_series",
]
_defaults = {
Expand All @@ -24,6 +25,7 @@ class ReadCSV(PartitionsFiltered, BlockwiseIO):
"_partitions": None,
"storage_options": None,
"_series": False,
"_cwd": None,
}
_absorb_projections = True

Expand Down
2 changes: 2 additions & 0 deletions dask_expr/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"filesystem",
"engine",
"kwargs",
"_cwd", # needed for tokenization
"_partitions",
"_series",
"_dataset_info_cache",
Expand All @@ -449,6 +450,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"_partitions": None,
"_series": False,
"_dataset_info_cache": None,
"_cwd": None,
}
_pq_length_stats = None
_absorb_projections = True
Expand Down
34 changes: 31 additions & 3 deletions dask_expr/io/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import glob
import os
from pathlib import Path

import dask.array as da
import dask.dataframe as dd
Expand Down Expand Up @@ -30,14 +31,14 @@
pd = _backend_library()


def _make_file(dir, format="parquet", df=None):
def _make_file(dir, format="parquet", df=None, **kwargs):
fn = os.path.join(str(dir), f"myfile.{format}")
if df is None:
df = pd.DataFrame({c: range(10) for c in "abcde"})
if format == "csv":
df.to_csv(fn)
df.to_csv(fn, **kwargs)
elif format == "parquet":
df.to_parquet(fn)
df.to_parquet(fn, **kwargs)
else:
ValueError(f"{format} not a supported format")
return fn
Expand Down Expand Up @@ -413,6 +414,33 @@ def test_combine_similar_no_projection_on_one_branch(tmpdir):
assert_eq(df, pdf)


@pytest.mark.parametrize(
"fmt, func, kwargs",
[
("parquet", read_parquet, {}),
("csv", read_csv, {"index": False}),
],
)
def test_chdir_different_files(tmpdir, fmt, func, kwargs):
cwd = os.getcwd()

try:
pdf = pd.DataFrame({"x": [0, 1, 2, 3] * 4, "y": range(16)})
os.chdir(tmpdir)
_make_file(tmpdir, format=fmt, df=pdf, **kwargs)
df = func(f"myfile.{fmt}")

new_dir = Path(tmpdir).joinpath("new_dir")
new_dir.mkdir()
os.chdir(new_dir)
pdf2 = pd.DataFrame({"x": [0, 100, 200, 300] * 4, "y": range(16)})
_make_file(new_dir, format=fmt, df=pdf2, **kwargs)
df2 = func(f"myfile.{fmt}")
assert_eq(df.sum() + df2.sum(), pd.Series([2424, 240], index=["x", "y"]))
finally:
os.chdir(cwd)


@pytest.mark.parametrize("meta", [True, False])
@pytest.mark.parametrize("label", [None, "foo"])
@pytest.mark.parametrize("allow_projection", [True, False])
Expand Down

0 comments on commit babb1ee

Please sign in to comment.