Skip to content

Commit

Permalink
introduce trial_indices argument to SupervisedDataset (#2595)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2960

Pull Request resolved: #2595

Adds optional `group_indices` to SupervisedDataset, whose dimensionality should correspond 1:1 with the first few dimensions of X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)).

Reviewed By: Balandat

Differential Revision: D64764019

fbshipit-source-id: 733460c2baa84e3a73573227c48f9bb20047e241
  • Loading branch information
bernardbeckerman authored and facebook-github-bot committed Nov 26, 2024
1 parent 3f2e2c7 commit 2e143c9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
31 changes: 30 additions & 1 deletion botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from botorch.exceptions.errors import InputDataError, UnsupportedError
from botorch.utils.containers import BotorchContainer, SliceContainer
from pyre_extensions import none_throws
from torch import long, ones, Tensor


Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
outcome_names: list[str],
Yvar: BotorchContainer | Tensor | None = None,
validate_init: bool = True,
group_indices: Tensor | None = None,
) -> None:
r"""Constructs a `SupervisedDataset`.
Expand All @@ -66,12 +68,17 @@ def __init__(
Yvar: An optional `Tensor` or `BotorchContainer` representing
the observation noise.
validate_init: If `True`, validates the input shapes.
group_indices: A `Tensor` representing the which rows of X and Y are
grouped together. This is used to support applications in which multiple
observations should be considered as a group, e.g., learning-curve-based
modeling. If provided, its shape must be compatible with X and Y.
"""
self._X = X
self._Y = Y
self._Yvar = Yvar
self.feature_names = feature_names
self.outcome_names = outcome_names
self.group_indices = group_indices
self.validate_init = validate_init
if validate_init:
self._validate()
Expand All @@ -98,6 +105,7 @@ def _validate(
self,
validate_feature_names: bool = True,
validate_outcome_names: bool = True,
validate_group_indices: bool = True,
) -> None:
r"""Checks that the shapes of the inputs are compatible with each other.
Expand All @@ -110,6 +118,8 @@ def _validate(
`outcomes_names` matches the # of columns of `self.Y`. If a
particular dataset, e.g., `RankingDataset`, is known to violate
this assumption, this can be set to `False`.
validate_group_indices: By default, we validate that the shape of
`group_indices` matches the shape of X and Y.
"""
shape_X = self.X.shape
if isinstance(self._X, BotorchContainer):
Expand All @@ -135,8 +145,20 @@ def _validate(
"`Y` must have the same number of columns as the number of "
"outcomes in `outcome_names`."
)
if validate_group_indices and self.group_indices is not None:
if self.group_indices.shape != shape_X:
raise ValueError(
f"shape_X ({shape_X}) must have the same shape as "
f"group_indices ({none_throws(self.group_indices).shape})."
)

def __eq__(self, other: Any) -> bool:
if self.group_indices is None and other.group_indices is None:
group_indices_equal = True
elif self.group_indices is None or other.group_indices is None:
group_indices_equal = False
else:
group_indices_equal = torch.equal(self.group_indices, other.group_indices)
return (
type(other) is type(self)
and torch.equal(self.X, other.X)
Expand All @@ -148,6 +170,7 @@ def __eq__(self, other: Any) -> bool:
)
and self.feature_names == other.feature_names
and self.outcome_names == other.outcome_names
and group_indices_equal
)

def clone(
Expand Down Expand Up @@ -256,7 +279,11 @@ def __init__(
)

def _validate(self) -> None:
super()._validate(validate_feature_names=False, validate_outcome_names=False)
super()._validate(
validate_feature_names=False,
validate_outcome_names=False,
validate_group_indices=False,
)
if len(self.feature_names) != self._X.values.shape[-1]:
raise ValueError(
"The `values` field of `X` must have the same number of columns as "
Expand Down Expand Up @@ -331,6 +358,7 @@ def __init__(
self.has_heterogeneous_features = any(
datasets[0].feature_names != ds.feature_names for ds in datasets[1:]
)
self.group_indices = None

@classmethod
def from_joint_dataset(
Expand Down Expand Up @@ -584,6 +612,7 @@ def __init__(
c: [self.feature_names.index(i) for i in parameter_decomposition[c]]
for c in self.context_buckets
}
self.group_indices = None

@property
def X(self) -> Tensor:
Expand Down
22 changes: 19 additions & 3 deletions test/utils/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,27 +98,35 @@ def make_contextual_dataset(
class TestDatasets(BotorchTestCase):
def test_supervised(self):
# Generate some data
X = rand(3, 2)
Y = rand(3, 1)
n_rows = 3
X = rand(n_rows, 2)
Y = rand(n_rows, 1)
feature_names = ["x1", "x2"]
outcome_names = ["y"]
group_indices = tensor(range(n_rows))

# Test `__init__`
dataset = SupervisedDataset(
X=X, Y=Y, feature_names=feature_names, outcome_names=outcome_names
X=X,
Y=Y,
feature_names=feature_names,
outcome_names=outcome_names,
group_indices=group_indices,
)
self.assertIsInstance(dataset.X, Tensor)
self.assertIsInstance(dataset._X, Tensor)
self.assertIsInstance(dataset.Y, Tensor)
self.assertIsInstance(dataset._Y, Tensor)
self.assertEqual(dataset.feature_names, feature_names)
self.assertEqual(dataset.outcome_names, outcome_names)
self.assertTrue(torch.equal(dataset.group_indices, group_indices))

dataset2 = SupervisedDataset(
X=DenseContainer(X, X.shape[-1:]),
Y=DenseContainer(Y, Y.shape[-1:]),
feature_names=feature_names,
outcome_names=outcome_names,
group_indices=group_indices,
)
self.assertIsInstance(dataset2.X, Tensor)
self.assertIsInstance(dataset2._X, DenseContainer)
Expand Down Expand Up @@ -156,6 +164,14 @@ def test_supervised(self):
feature_names=feature_names,
outcome_names=[],
)
with self.assertRaisesRegex(ValueError, "group_indices"):
SupervisedDataset(
X=rand(2, 2),
Y=rand(2, 1),
feature_names=feature_names,
outcome_names=outcome_names,
group_indices=tensor(range(n_rows + 1)),
)

# Test with Yvar.
dataset = SupervisedDataset(
Expand Down

0 comments on commit 2e143c9

Please sign in to comment.