Add list of iterables to recursive env check of torch iterable (#171)
* Add list of iterables to recursive env check of torch iterable

* Add IterableSamplerSource to tests
jotterbach authored Jan 4, 2024
1 parent 0ed5733 commit cfcefd6
Showing 3 changed files with 133 additions and 8 deletions.
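
The change teaches the recursive environment check behind `to_torch_iterable` to also walk composables whose `source` is a list of composables (`Multiplexer`, `IterableSamplerSource`), so worker and rank splits applied inside the individual streams are detected. A minimal sketch of the usage this enables, modeled on the new tests below (the in-memory `IterableSource` streams are illustrative; the tests also exercise catalog-backed drivers):

    # Sketch based on the new tests: each branch is split by worker before being combined,
    # and the combined stream can be wrapped as a torch iterable because the env check
    # now recurses into the list of sources.
    import torch.utils.data as tud

    from squirrel.iterstream.source import IterableSource, IterableSamplerSource

    samples = list(range(20))
    src_0 = IterableSource(samples).split_by_worker_pytorch()
    src_1 = IterableSource(samples).split_by_worker_pytorch()

    it = IterableSamplerSource([src_0, src_1]).batched(5, drop_last_if_not_full=False).to_torch_iterable()
    dl = tud.DataLoader(it, num_workers=2)  # before this change, the splits inside the list were not detected
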
4 changes: 2 additions & 2 deletions squirrel/iterstream/multiplexer.py
@@ -99,12 +99,12 @@ def __init__(
        Note that the algorithm stops whenever max_reinits are hit or all composables
        have been reinitialized at least once.
        """
-        super().__init__()
+        super().__init__(source=composables)
        self.mux_strategy = mux_strategy
        if mux_strategy == MultiplexingStrategy.SAMPLING:
            assert sampling_probas is not None
            assert len(sampling_probas) == len(
-                composables
+                self.source
            ), "Need sampling probas and composables to have same number of entries"

        self.composables, self.sampling_probas = self._init_composables_and_probas(
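Passing `source=composables` to the base `Composable` exposes the multiplexed streams via `self.source`; the recursive rank/worker checks in torch_composables.py below iterate over exactly this attribute.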
12 changes: 9 additions & 3 deletions squirrel/iterstream/torch_composables.py
@@ -6,7 +6,7 @@
import torch
from torch.utils.data import IterableDataset

-from squirrel.iterstream.base import Composable
+from squirrel.iterstream import Composable, Multiplexer, IterableSamplerSource
from squirrel.framework.exceptions import PyTorchSplittingException

logger = logging.getLogger(__name__)
@@ -90,7 +90,10 @@ def _contains_rank_split(self, source: Composable) -> bool:
        elif not isinstance(source, Composable):
            return False
        else:
-            return self._contains_rank_split(source.source)
+            if isinstance(source, (Multiplexer, IterableSamplerSource)):
+                return all(self._contains_rank_split(src) for src in source.source)
+            else:
+                return self._contains_rank_split(source.source)

    def _contains_worker_split(self, source: Composable) -> bool:
        """Check if SplitByWorker was chained to this Composable"""
@@ -99,7 +102,10 @@ def _contains_worker_split(self, source: Composable) -> bool:
        elif not isinstance(source, Composable):
            return False
        else:
-            return self._contains_worker_split(source.source)
+            if isinstance(source, (Multiplexer, IterableSamplerSource)):
+                return all(self._contains_worker_split(src) for src in source.source)
+            else:
+                return self._contains_worker_split(source.source)


def _skip_k(it: Iterable, start: int, step: int) -> Iterator:
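
Note the `all(...)` semantics of the new branch: a `Multiplexer` or `IterableSamplerSource` only counts as rank- or worker-split if every one of its child composables carries the corresponding split step.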
125 changes: 122 additions & 3 deletions test/test_iterstream/test_torch_composables.py
@@ -2,15 +2,17 @@
from functools import partial
from typing import List, Any
from unittest import mock
-from collections import namedtuple
+from collections import namedtuple, Counter

import pytest
import torch
import torch.utils.data as tud

-from squirrel.driver import MessagepackDriver
+from squirrel.catalog import Catalog
+from squirrel.driver import MessagepackDriver, IterDriver
from squirrel.iterstream.iterators import map_
-from squirrel.iterstream.source import IterableSource
+from squirrel.iterstream.multiplexer import Multiplexer, MultiplexingStrategy
+from squirrel.iterstream.source import IterableSource, IterableSamplerSource
from squirrel.iterstream.torch_composables import SplitByRank, SplitByWorker, TorchIterable, skip_k
from squirrel.framework.exceptions import PyTorchSplittingException

@@ -77,6 +79,103 @@ def test_multi_worker_torch_iterable_map(samples: List[int]) -> None:
    assert out.size() == (20, 5)


def _extract_key(d: dict) -> str:
    return d["meta"]["sha"]


@pytest.mark.parametrize("num_workers", [0, 1, 2, 4])
def test_torch_iterable_multiprocessing_with_muxing_and_multiprocess(
    num_workers: int, dummy_data_catalog: Catalog
) -> None:
    """Test that iterable of composable can properly be split by worker."""
    batch_size = 5

    cat = dummy_data_catalog
    d0: IterDriver = cat["data_0"].get_driver()
    d1: IterDriver = cat["data_1"].get_driver()
    d2: IterDriver = cat["data_2"].get_driver()

    d0_count = cat["data_0"].metadata["num_samples"]
    d1_count = cat["data_1"].metadata["num_samples"]
    d2_count = cat["data_2"].metadata["num_samples"]

    p = [0.35, 0.6, 0.05]
    max_reinits = 4
    mux = Multiplexer(
        [
            d0.get_iter().split_by_worker_pytorch(),
            d1.get_iter().split_by_worker_pytorch(),
            d2.get_iter().split_by_worker_pytorch(),
        ],
        mux_strategy=MultiplexingStrategy.ROUND_ROBIN,
        sampling_probas=p,
        seed=42,
        max_reinits=max_reinits,
    )

    it = mux.map(_extract_key).batched(batch_size, drop_last_if_not_full=False).to_torch_iterable()

    dl = tud.DataLoader(it, num_workers=num_workers)
    cntr = 0
    for b in dl:
        cntr += len(b)

    assert cntr == sum([d0_count, d1_count, d2_count])


@mock.patch("torch.distributed.is_available", mock.MagicMock(return_value=True))
@mock.patch("torch.distributed.is_initialized", mock.MagicMock(return_value=True))
@mock.patch("torch.distributed.get_world_size")
@mock.patch("torch.distributed.get_rank")
def test_torch_iterable_multiprocessing_with_muxing_and_multirank(
mock_get_rank: int, mock_get_world_size: int, dummy_data_catalog: Catalog
) -> None:
"""Test that iterable of composable can properly be split by worker and rank."""
batch_size = 5
num_workers = 2

world_size = 2
mock_get_world_size.return_value = world_size

rank_counts = {}
for rank in range(world_size):
mock_get_rank.return_value = rank

cat = dummy_data_catalog
d0: IterDriver = cat["data_0"].get_driver()
d1: IterDriver = cat["data_1"].get_driver()
d2: IterDriver = cat["data_2"].get_driver()

d0_count = cat["data_0"].metadata["num_samples"]
d1_count = cat["data_1"].metadata["num_samples"]
d2_count = cat["data_2"].metadata["num_samples"]

p = [0.35, 0.6, 0.05]
max_reinits = 4
mux = Multiplexer(
[
d0.get_iter().split_by_worker_pytorch().split_by_rank_pytorch(),
d1.get_iter().split_by_worker_pytorch().split_by_rank_pytorch(),
d2.get_iter().split_by_worker_pytorch().split_by_rank_pytorch(),
],
mux_strategy=MultiplexingStrategy.ROUND_ROBIN,
sampling_probas=p,
seed=42,
max_reinits=max_reinits,
)

it = mux.map(_extract_key).batched(batch_size, drop_last_if_not_full=False).to_torch_iterable()

dl = tud.DataLoader(it, num_workers=num_workers)
cntr = 0
for b in dl:
cntr += len(b)

rank_counts[rank] = cntr

assert sum(rank_counts.values()) == sum([d0_count, d1_count, d2_count])


def test_multi_worker_torch_iterable_async_map(samples: List[int]) -> None:
    """Test async_map is picklable and forkable in pytorch multiprocessing context"""
    num_workers = 4
@@ -91,6 +190,26 @@ def test_multi_worker_torch_iterable_async_map(samples: List[int]) -> None:
    assert out.size() == (20, 5)


@pytest.mark.parametrize("num_workers", [0, 1, 2, 4])
def test_split_by_worker_in_iterable_sampler_source_is_captured(num_workers: int, samples: List[int]) -> None:
    """Test that split_by_worker can be captured in an iterable sampler source."""
    batch_size = 5
    src_0 = IterableSource(samples).split_by_worker_pytorch()
    src_1 = IterableSource(samples).split_by_worker_pytorch()

    samp_src = IterableSamplerSource([src_0, src_1])

    dl = tud.DataLoader(
        samp_src.batched(batch_size, drop_last_if_not_full=False).to_torch_iterable(), num_workers=num_workers
    )

    out = torch.Tensor(list(dl))
    cntr = Counter(out.cpu().flatten().long().numpy().tolist())
    assert sorted(list(cntr.keys())) == samples
    assert list(cntr.values()) == len(samples) * [2]
    assert out.size() == (40, 5)


@mock.patch("torch.distributed.is_available", mock.MagicMock(return_value=True))
@mock.patch("torch.distributed.is_initialized", mock.MagicMock(return_value=True))
# @mock.patch("torch.distributed.group.WORLD", mock.MagicMock(return_value="WORLD"))
