Skip to content

Commit

Permalink
Fix SourceCombiner.get_iter() not interleaving correctly (#45)
Browse files Browse the repository at this point in the history
* revise test to demonstrate bug

* fix

* patch bump
  • Loading branch information
AlpAribal authored May 18, 2022
1 parent b3b7f3f commit a93f62f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion squirrel/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
-__version__ = "0.13.1"
+__version__ = "0.13.2"
2 changes: 1 addition & 1 deletion squirrel/driver/source_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_iter(self, subset: Optional[str] = None, **kwargs) -> Composable:
"""
if subset is None:
return IterableSource(
-more_itertools.interleave_longest(self.get_iter(subset=k, **kwargs) for k in self.subsets)
+more_itertools.interleave_longest(*[self.get_iter(subset=k, **kwargs) for k in self.subsets])
)
return self.get_source(subset).get_driver().get_iter(**kwargs)

Expand Down
23 changes: 16 additions & 7 deletions test/test_driver/test_sourcecombiner.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
+from tempfile import TemporaryDirectory
+
+import pandas as pd
+
from squirrel.catalog import Catalog, CatalogKey
from squirrel.catalog.source import Source
from squirrel.constants import URL


-def test_combiner_in_catalog(test_path: URL) -> None:
+def test_combiner_in_catalog() -> None:
"""Test if sources can be combined into a source combiner"""
c = Catalog()
+tmpdir = TemporaryDirectory()

-c["train"] = Source("file", driver_kwargs={"path": test_path + "train.dummy"})
-c["val"] = Source("file", driver_kwargs={"path": test_path + "val.dummy"})
-c["test"] = Source("file", driver_kwargs={"path": test_path + "test.dummy"})
+for split in ("train", "val", "test"):
+fname = f"{tmpdir.name}/{split}.csv"
+data = pd.DataFrame(dict(split=[split])) # one row, one col df, only value of "split" changes
+data.to_csv(fname)
+c[split] = Source("csv", driver_kwargs={"path": fname})

c["combined"] = Source(
"source_combiner",
Expand All @@ -24,9 +31,11 @@ def test_combiner_in_catalog(test_path: URL) -> None:

d = c["combined"].get_driver()
assert len(d.subsets) == 3
-assert d.get_source("subset1").get_driver().path == test_path + "train.dummy"
-assert d.get_source("subset2").get_driver().path == test_path + "val.dummy"
-assert d.get_source("subset3").get_driver().path == test_path + "test.dummy"
+assert d.get_source("subset1").get_driver().path == f"{tmpdir.name}/train.csv"
+assert d.get_source("subset2").get_driver().path == f"{tmpdir.name}/val.csv"
+assert d.get_source("subset3").get_driver().path == f"{tmpdir.name}/test.csv"
+all_rows = d.get_iter().collect()
+assert [r.split for r in all_rows] == ["train", "val", "test"]


def test_copy_combiner(test_path: URL) -> None:
Expand Down

0 comments on commit a93f62f

Please sign in to comment.