Skip to content

Commit

Permalink
Extend toggling test to also verify the updated campaign metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Nov 12, 2024
1 parent 8bff13b commit c3ba88c
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions tests/test_campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pandas.testing import assert_frame_equal
from pytest import param

from baybe.campaign import Campaign
from baybe.campaign import _EXCLUDED, Campaign
from baybe.parameters.numerical import NumericalDiscreteParameter
from baybe.searchspace.discrete import SubspaceDiscrete

Expand Down Expand Up @@ -51,16 +51,33 @@ def test_get_surrogate(campaign, n_iterations, batch_size):
],
ids=["regular", "anti"],
)
def test_candidate_filter(anti, expected):
"""The candidate filter extracts the correct subset of points."""
@pytest.mark.parametrize("exclude", [True, False], ids=["exclude", "include"])
def test_candidate_filter(exclude, anti, expected):
"""The candidate filter extracts the correct subset of points and the campaign
metadata is updated accordingly.""" # noqa

subspace = SubspaceDiscrete.from_product(
[
NumericalDiscreteParameter("a", [0, 1]),
NumericalDiscreteParameter("b", [3, 4, 5]),
]
)
campaign = Campaign(subspace)
df = campaign.toggle_discrete_candidates(pd.DataFrame({"a": [0]}), False, anti=anti)
assert_frame_equal(
df, pd.merge(df.reset_index(), expected).set_index("index"), check_names=False

# Set metadata to opposite of targeted value so that we can verify the effect later
campaign._searchspace_metadata[_EXCLUDED] = not exclude

# Toggle the candidates
df = campaign.toggle_discrete_candidates(
pd.DataFrame({"a": [0]}), exclude, anti=anti
)

# Assert that the filtering is correct
rows = pd.merge(df.reset_index(), expected).set_index("index")
assert_frame_equal(df, rows, check_names=False)

# Assert that metadata is set correctly
target = campaign._searchspace_metadata.loc[rows.index, _EXCLUDED]
other = campaign._searchspace_metadata[_EXCLUDED].drop(index=rows.index)
assert all(target == exclude) # must contain the updated values
assert all(other != exclude) # must contain the original values

0 comments on commit c3ba88c

Please sign in to comment.