From 6db8b7e439976533879ab450b9c3a7c3492cabd2 Mon Sep 17 00:00:00 2001
From: Alireza Sohofi
Date: Wed, 16 Aug 2023 16:34:54 +0200
Subject: [PATCH] make the split_by_worker and split_by_rank checks optional

---
 squirrel/iterstream/base.py                    | 18 +++++++++++++++---
 squirrel/iterstream/torch_composables.py       |  8 +++++---
 test/test_iterstream/test_torch_composables.py |  3 +++
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/squirrel/iterstream/base.py b/squirrel/iterstream/base.py
index b15f919d..3bf9ae5f 100644
--- a/squirrel/iterstream/base.py
+++ b/squirrel/iterstream/base.py
@@ -324,11 +324,23 @@ def split_by_rank_pytorch(self, torch_dist_group: t.Optional[str] = None) -> Com
         return self.compose(SplitByRank, torch_dist_group)
 
-    def to_torch_iterable(self) -> Composable:
-        """Convert the stream to a torch iterable."""
+    def to_torch_iterable(self, enforce_rank_check: bool = True, enforce_worker_check: bool = True) -> Composable:
+        """
+        Convert the stream to a torch iterable.
+
+        Args:
+            enforce_rank_check: if set to True, check that `split_by_rank_pytorch` has been called before
+                `to_torch_iterable`. This is important to avoid loading the same sample more than once in a
+                multi-rank PyTorch environment.
+            enforce_worker_check: if set to True, check that `split_by_worker_pytorch` has been called before
+                `to_torch_iterable`. This is important to avoid loading the same sample more than once in a
+                multi-worker PyTorch environment.
+        """
         from squirrel.iterstream.torch_composables import TorchIterable
 
-        return self.compose(TorchIterable)
+        return self.compose(
+            partial(TorchIterable, enforce_rank_check=enforce_rank_check, enforce_worker_check=enforce_worker_check)
+        )
 
 
 class _Iterable(Composable):
diff --git a/squirrel/iterstream/torch_composables.py b/squirrel/iterstream/torch_composables.py
index 6bde718a..a72d6db4 100644
--- a/squirrel/iterstream/torch_composables.py
+++ b/squirrel/iterstream/torch_composables.py
@@ -61,19 +61,21 @@ def __iter__(self) -> Iterator:
 class TorchIterable(Composable, IterableDataset):
     """Mixin-Composable to have squirrel pipeline inherit from PyTorch IterableDataset"""
 
-    def __init__(self) -> None:
+    def __init__(self, enforce_rank_check: bool = True, enforce_worker_check: bool = True) -> None:
         """Init"""
         super().__init__()
+        self.enforce_rank_check = enforce_rank_check
+        self.enforce_worker_check = enforce_worker_check
 
     def __iter__(self) -> Iterator:
         """Method to iterate over the source"""
-        if _in_multi_rank_env():
+        if self.enforce_rank_check and _in_multi_rank_env():
             if not self._contains_rank_split(self.source):
                 raise PyTorchSplittingException(
                     "Composable was not split by rank. This will lead to unexpected iteration behaviour."
                     "Add a 'split_by_rank_pytorch' call to your composable to avoid this error. "
                 )
-        if _in_multi_worker_env():
+        if self.enforce_worker_check and _in_multi_worker_env():
             if not self._contains_worker_split(self.source):
                 raise PyTorchSplittingException(
                     "Composable was not split by worker. This will lead to unexpected iteration behaviour."
diff --git a/test/test_iterstream/test_torch_composables.py b/test/test_iterstream/test_torch_composables.py
index 2820048c..b3787c58 100644
--- a/test/test_iterstream/test_torch_composables.py
+++ b/test/test_iterstream/test_torch_composables.py
@@ -212,6 +212,9 @@ def test_error_when_not_splitting_in_mp(mock_get_worker_info: Any, samples: List
         it = IterableSource(samples).to_torch_iterable()
         next(iter(it))
 
+    res = IterableSource(samples).to_torch_iterable(enforce_worker_check=False, enforce_rank_check=False).collect()
+    assert res == samples
+
     # Split by rank and worker, this should work
     # ADD SIMPLE MAP FN
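
Usage sketch of the new flags, assuming squirrel's `IterableSource` API as exercised in the test above; the `samples` list is a hypothetical stand-in for any data source:

    from squirrel.iterstream import IterableSource

    samples = [{"id": i} for i in range(4)]

    # Default behaviour: in a multi-rank or multi-worker PyTorch environment,
    # iterating raises PyTorchSplittingException unless the pipeline was split,
    # so that no sample is loaded more than once.
    it = (
        IterableSource(samples)
        .split_by_rank_pytorch()
        .split_by_worker_pytorch()
        .to_torch_iterable()  # both checks enforced by default
    )

    # With the new flags, both checks can be disabled, e.g. when splitting is
    # handled elsewhere or duplicate samples are acceptable.
    res = (
        IterableSource(samples)
        .to_torch_iterable(enforce_rank_check=False, enforce_worker_check=False)
        .collect()
    )
    assert res == samples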