Skip to content

Commit

Permalink
Support multiple datasets for validation and test splits.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708302603
  • Loading branch information
kmaninis authored and Scenic Authors committed Dec 20, 2024
1 parent 97d6ac5 commit fe753ff
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions scenic/dataset_lib/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import dataclasses
import functools
import itertools
from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Iterator, Optional, Sequence, Union

from absl import logging
from flax.training import common_utils
Expand All @@ -35,6 +35,8 @@
PyTree = Any
DatasetIterator = Union[Iterator[Any], Dict[str, Iterator[Any]]]
DatasetIteratorProvider = Callable[[], DatasetIterator]
DatasetIteratorType = DatasetIterator | DatasetIteratorProvider
DatasetType = Union[tf.data.Dataset, Dict[str, tf.data.Dataset]]


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -68,14 +70,23 @@ class Dataset:
classification tasks, `num_classes` is used for the configuring head of
the model.
"""
train_iter: DatasetIterator | DatasetIteratorProvider | None = None
valid_iter: DatasetIterator | DatasetIteratorProvider | None = None
test_iter: DatasetIterator | DatasetIteratorProvider | None = None
train_iter: DatasetIteratorType | None = None
valid_iter: DatasetIteratorType | None = None
test_iter: DatasetIteratorType | None = None
meta_data: Dict[str, Any] = dataclasses.field(default_factory=dict)

train_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None
valid_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None
test_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None
train_ds: DatasetType | None = None
valid_ds: DatasetType | None = None
test_ds: DatasetType | None = None

# Multiple dataset support.
train_multi_iter: List[DatasetIteratorType] | None = None
valid_multi_iter: List[DatasetIteratorType] | None = None
test_multi_iter: List[DatasetIteratorType] | None = None

train_multi_ds: List[DatasetType] | None = None
valid_multi_ds: List[DatasetType] | None = None
test_multi_ds: List[DatasetType] | None = None


def maybe_pad_batch(batch: Dict[str, PyTree],
Expand Down

0 comments on commit fe753ff

Please sign in to comment.