From cfcefd6f9c3f0982825466b23b5beaaf9c4d4f27 Mon Sep 17 00:00:00 2001
From: Johannes Otterbach
Date: Thu, 4 Jan 2024 17:06:31 +0100
Subject: [PATCH] Add list of iterables to recursive env check of torch
 iterable (#171)

* Add list of iterables to recursive env check of torch iterable

* Add IterableSamplerSource to tests
---
 squirrel/iterstream/multiplexer.py            |   4 +-
 squirrel/iterstream/torch_composables.py      |  12 +-
 .../test_iterstream/test_torch_composables.py | 125 +++++++++++++++++-
 3 files changed, 133 insertions(+), 8 deletions(-)

diff --git a/squirrel/iterstream/multiplexer.py b/squirrel/iterstream/multiplexer.py
index 87b3503..dc79eeb 100644
--- a/squirrel/iterstream/multiplexer.py
+++ b/squirrel/iterstream/multiplexer.py
@@ -99,12 +99,12 @@ def __init__(
         Note that the algorithm stops whenever max_reinits are hit or all composables have been reinitialized at least once.
         """
-        super().__init__()
+        super().__init__(source=composables)
         self.mux_strategy = mux_strategy
 
         if mux_strategy == MultiplexingStrategy.SAMPLING:
             assert sampling_probas is not None
             assert len(sampling_probas) == len(
-                composables
+                self.source
             ), "Need sampling probas and composables to have same number of entries"
 
         self.composables, self.sampling_probas = self._init_composables_and_probas(
diff --git a/squirrel/iterstream/torch_composables.py b/squirrel/iterstream/torch_composables.py
index a72d6db..c6702cc 100644
--- a/squirrel/iterstream/torch_composables.py
+++ b/squirrel/iterstream/torch_composables.py
@@ -6,7 +6,7 @@
 import torch
 from torch.utils.data import IterableDataset
 
-from squirrel.iterstream.base import Composable
+from squirrel.iterstream import Composable, Multiplexer, IterableSamplerSource
 from squirrel.framework.exceptions import PyTorchSplittingException
 
 logger = logging.getLogger(__name__)
@@ -90,7 +90,10 @@ def _contains_rank_split(self, source: Composable) -> bool:
         elif not isinstance(source, Composable):
             return False
         else:
-            return self._contains_rank_split(source.source)
+            if isinstance(source, (Multiplexer, IterableSamplerSource)):
+                return all(self._contains_rank_split(src) for src in source.source)
+            else:
+                return self._contains_rank_split(source.source)
 
     def _contains_worker_split(self, source: Composable) -> bool:
         """Check if SplitByWorker was chained to this Composable"""
@@ -99,7 +102,10 @@ def _contains_worker_split(self, source: Composable) -> bool:
         elif not isinstance(source, Composable):
             return False
         else:
-            return self._contains_worker_split(source.source)
+            if isinstance(source, (Multiplexer, IterableSamplerSource)):
+                return all(self._contains_worker_split(src) for src in source.source)
+            else:
+                return self._contains_worker_split(source.source)
 
 
 def _skip_k(it: Iterable, start: int, step: int) -> Iterator:
diff --git a/test/test_iterstream/test_torch_composables.py b/test/test_iterstream/test_torch_composables.py
index e3fb601..59c7b51 100644
--- a/test/test_iterstream/test_torch_composables.py
+++ b/test/test_iterstream/test_torch_composables.py
@@ -2,15 +2,17 @@
 from functools import partial
 from typing import List, Any
 from unittest import mock
-from collections import namedtuple
+from collections import namedtuple, Counter
 
 import pytest
 import torch
 import torch.utils.data as tud
 
-from squirrel.driver import MessagepackDriver
+from squirrel.catalog import Catalog
+from squirrel.driver import MessagepackDriver, IterDriver
 from squirrel.iterstream.iterators import map_
-from squirrel.iterstream.source import IterableSource
+from squirrel.iterstream.multiplexer import Multiplexer, MultiplexingStrategy
+from squirrel.iterstream.source import IterableSource, IterableSamplerSource
 from squirrel.iterstream.torch_composables import SplitByRank, SplitByWorker, TorchIterable, skip_k
 from squirrel.framework.exceptions import PyTorchSplittingException
 
@@ -77,6 +79,103 @@ def test_multi_worker_torch_iterable_map(samples: List[int]) -> None:
     assert out.size() == (20, 5)
 
 
+def _extract_key(d: dict) -> str:
+    return d["meta"]["sha"]
+
+
+@pytest.mark.parametrize("num_workers", [0, 1, 2, 4])
+def test_torch_iterable_multiprocessing_with_muxing_and_multiprocess(
+    num_workers: int, dummy_data_catalog: Catalog
+) -> None:
+    """Test that an iterable of composables can properly be split by worker."""
+    batch_size = 5
+
+    cat = dummy_data_catalog
+    d0: IterDriver = cat["data_0"].get_driver()
+    d1: IterDriver = cat["data_1"].get_driver()
+    d2: IterDriver = cat["data_2"].get_driver()
+
+    d0_count = cat["data_0"].metadata["num_samples"]
+    d1_count = cat["data_1"].metadata["num_samples"]
+    d2_count = cat["data_2"].metadata["num_samples"]
+
+    p = [0.35, 0.6, 0.05]
+    max_reinits = 4
+    mux = Multiplexer(
+        [
+            d0.get_iter().split_by_worker_pytorch(),
+            d1.get_iter().split_by_worker_pytorch(),
+            d2.get_iter().split_by_worker_pytorch(),
+        ],
+        mux_strategy=MultiplexingStrategy.ROUND_ROBIN,
+        sampling_probas=p,
+        seed=42,
+        max_reinits=max_reinits,
+    )
+
+    it = mux.map(_extract_key).batched(batch_size, drop_last_if_not_full=False).to_torch_iterable()
+
+    dl = tud.DataLoader(it, num_workers=num_workers)
+    cntr = 0
+    for b in dl:
+        cntr += len(b)
+
+    assert cntr == sum([d0_count, d1_count, d2_count])
+
+
+@mock.patch("torch.distributed.is_available", mock.MagicMock(return_value=True))
+@mock.patch("torch.distributed.is_initialized", mock.MagicMock(return_value=True))
+@mock.patch("torch.distributed.get_world_size")
+@mock.patch("torch.distributed.get_rank")
+def test_torch_iterable_multiprocessing_with_muxing_and_multirank(
+    mock_get_rank: mock.MagicMock, mock_get_world_size: mock.MagicMock, dummy_data_catalog: Catalog
+) -> None:
+    """Test that an iterable of composables can properly be split by worker and rank."""
+    batch_size = 5
+    num_workers = 2
+
+    world_size = 2
+    mock_get_world_size.return_value = world_size
+
+    rank_counts = {}
+    for rank in range(world_size):
+        mock_get_rank.return_value = rank
+
+        cat = dummy_data_catalog
+        d0: IterDriver = cat["data_0"].get_driver()
+        d1: IterDriver = cat["data_1"].get_driver()
+        d2: IterDriver = cat["data_2"].get_driver()
+
+        d0_count = cat["data_0"].metadata["num_samples"]
+        d1_count = cat["data_1"].metadata["num_samples"]
+        d2_count = cat["data_2"].metadata["num_samples"]
+
+        p = [0.35, 0.6, 0.05]
+        max_reinits = 4
+        mux = Multiplexer(
+            [
+                d0.get_iter().split_by_worker_pytorch().split_by_rank_pytorch(),
+                d1.get_iter().split_by_worker_pytorch().split_by_rank_pytorch(),
+                d2.get_iter().split_by_worker_pytorch().split_by_rank_pytorch(),
+            ],
+            mux_strategy=MultiplexingStrategy.ROUND_ROBIN,
+            sampling_probas=p,
+            seed=42,
+            max_reinits=max_reinits,
+        )
+
+        it = mux.map(_extract_key).batched(batch_size, drop_last_if_not_full=False).to_torch_iterable()
+
+        dl = tud.DataLoader(it, num_workers=num_workers)
+        cntr = 0
+        for b in dl:
+            cntr += len(b)
+
+        rank_counts[rank] = cntr
+
+    assert sum(rank_counts.values()) == sum([d0_count, d1_count, d2_count])
+
+
 def test_multi_worker_torch_iterable_async_map(samples: List[int]) -> None:
     """Test async_map is picklable and forkable in pytorch multiprocessing context"""
     num_workers = 4
@@ -91,6 +190,26 @@ def test_multi_worker_torch_iterable_async_map(samples: List[int]) -> None:
     assert out.size() == (20, 5)
 
 
+@pytest.mark.parametrize("num_workers", [0, 1, 2, 4])
+def test_split_by_worker_in_iterable_sampler_source_is_captured(num_workers: int, samples: List[int]) -> None:
+    """Test that split_by_worker can be captured in an iterable sampler source."""
+    batch_size = 5
+    src_0 = IterableSource(samples).split_by_worker_pytorch()
+    src_1 = IterableSource(samples).split_by_worker_pytorch()
+
+    samp_src = IterableSamplerSource([src_0, src_1])
+
+    dl = tud.DataLoader(
+        samp_src.batched(batch_size, drop_last_if_not_full=False).to_torch_iterable(), num_workers=num_workers
+    )
+
+    out = torch.Tensor(list(dl))
+    cntr = Counter(out.cpu().flatten().long().numpy().tolist())
+    assert sorted(list(cntr.keys())) == samples
+    assert list(cntr.values()) == len(samples) * [2]
+    assert out.size() == (40, 5)
+
+
 @mock.patch("torch.distributed.is_available", mock.MagicMock(return_value=True))
 @mock.patch("torch.distributed.is_initialized", mock.MagicMock(return_value=True))
 # @mock.patch("torch.distributed.group.WORLD", mock.MagicMock(return_value="WORLD"))
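
What the patch enables, as a minimal sketch (illustrative only: the in-memory branches and counts below are made up, while every API call mirrors the tests above). Before this change, _contains_worker_split and _contains_rank_split stopped recursing at a Multiplexer or IterableSamplerSource, because their source attribute holds a list of composables rather than a single composable, so per-branch split_by_worker_pytorch() calls were invisible to the environment checks run when converting to a torch iterable. With the recursion over list-valued sources in place, a pipeline like the following passes those checks:

    import torch.utils.data as tud

    from squirrel.iterstream.source import IterableSource, IterableSamplerSource

    # Hypothetical in-memory branches; each one applies its own worker split,
    # which the recursive check can now discover inside the sampler's source list.
    src_0 = IterableSource(list(range(20))).split_by_worker_pytorch()
    src_1 = IterableSource(list(range(20))).split_by_worker_pytorch()

    it = IterableSamplerSource([src_0, src_1]).batched(5, drop_last_if_not_full=False).to_torch_iterable()

    dl = tud.DataLoader(it, num_workers=2)
    # Each element of each branch is yielded exactly once across the two workers.
    total = sum(len(b) for b in dl)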