Skip to content

Commit

Permalink
Make toggle_discrete_candidates expect a collection of constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Nov 19, 2024
1 parent e7db1ca commit 49e4959
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
25 changes: 14 additions & 11 deletions baybe/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import gc
import json
from functools import singledispatchmethod
from collections.abc import Collection
from functools import reduce, singledispatchmethod
from typing import TYPE_CHECKING

import cattrs
Expand Down Expand Up @@ -291,17 +292,17 @@ def _mark_as_measured(
@singledispatchmethod
def toggle_discrete_candidates( # noqa: DOC501
self,
constraint: DiscreteConstraint | pd.DataFrame,
constraints: Collection[DiscreteConstraint] | pd.DataFrame,
exclude: bool,
complement: bool = False,
dry_run: bool = False,
) -> pd.DataFrame:
"""In-/exclude certain discrete points in/from the candidate set.
Args:
constraint: A filtering mechanism determining the candidates subset to be
in-/excluded. Can be either a
:class:`~baybe.constraints.base.DiscreteConstraint` or a dataframe.
constraints: A filtering mechanism determining the candidates subset to be
in-/excluded. Can be either a collection of
:class:`~baybe.constraints.base.DiscreteConstraint`s or a dataframe.
For the latter, see :func:`~baybe.utils.dataframe.filter_df`
for details.
exclude: If ``True``, the specified candidates are excluded.
Expand All @@ -320,20 +321,22 @@ def toggle_discrete_candidates( # noqa: DOC501
"""
raise NotImplementedError(
f"Candidate toggling is not implemented for constraint specifications of "
f"type {type(constraint)}."
f"type {type(constraints)}."
)

@toggle_discrete_candidates.register
@toggle_discrete_candidates.register(Collection)
def _(
self,
constraint: DiscreteConstraint,
constraints: Collection[DiscreteConstraint],
exclude: bool,
complement: bool = False,
dry_run: bool = False,
) -> pd.DataFrame:
# Filter search space dataframe according to the given constraint
df = self.searchspace.discrete.exp_rep
idx = constraint.get_valid(df)
idx = reduce(
lambda x, y: x.intersection(y), (c.get_valid(df) for c in constraints)
)

# Determine the candidate subset to be toggled
points = df.drop(index=idx) if complement else df.loc[idx].copy()
Expand All @@ -346,13 +349,13 @@ def _(
@toggle_discrete_candidates.register
def _(
self,
constraint: pd.DataFrame,
constraints: pd.DataFrame,
exclude: bool,
complement: bool = False,
dry_run: bool = False,
) -> pd.DataFrame:
# Determine the candidate subset to be toggled
points = filter_df(self.searchspace.discrete.exp_rep, constraint, complement)
points = filter_df(self.searchspace.discrete.exp_rep, constraints, complement)

if not dry_run:
self._searchspace_metadata.loc[points.index, _EXCLUDED] = exclude
Expand Down
6 changes: 3 additions & 3 deletions tests/test_campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def test_get_surrogate(campaign, n_iterations, batch_size):
@pytest.mark.parametrize("complement", [False, True], ids=["regular", "complement"])
@pytest.mark.parametrize("exclude", [True, False], ids=["exclude", "include"])
@pytest.mark.parametrize(
"constraint",
"constraints",
[
pd.DataFrame({"a": [0]}),
DiscreteExcludeConstraint(["a"], [SubSelectionCondition([1])]),
],
ids=["dataframe", "constraints"],
)
def test_candidate_toggling(constraint, exclude, complement):
def test_candidate_toggling(constraints, exclude, complement):
"""Toggling discrete candidates updates the campaign metadata accordingly."""
subspace = SubspaceDiscrete.from_product(
[
Expand All @@ -62,7 +62,7 @@ def test_candidate_toggling(constraint, exclude, complement):
campaign._searchspace_metadata[_EXCLUDED] = not exclude

# Toggle the candidates
campaign.toggle_discrete_candidates(constraint, exclude, complement=complement)
campaign.toggle_discrete_candidates(constraints, exclude, complement=complement)

# Extract row indices of candidates whose metadata should have been toggled
matches = campaign.searchspace.discrete.exp_rep["a"] == 0
Expand Down

0 comments on commit 49e4959

Please sign in to comment.