Rename anti argument to complement
AdrianSosic committed Nov 19, 2024
1 parent 66bc0ca · commit e7db1ca
Showing 6 changed files with 26 additions and 21 deletions.
baybe/campaign.py (17 changes: 9 additions & 8 deletions)
@@ -293,7 +293,7 @@ def toggle_discrete_candidates(  # noqa: DOC501
self,
constraint: DiscreteConstraint | pd.DataFrame,
exclude: bool,
anti: bool = False,
complement: bool = False,
dry_run: bool = False,
) -> pd.DataFrame:
"""In-/exclude certain discrete points in/from the candidate set.
@@ -306,9 +306,10 @@ def toggle_discrete_candidates(  # noqa: DOC501
for details.
exclude: If ``True``, the specified candidates are excluded.
If ``False``, the candidates are considered for recommendation.
anti: Boolean flag deciding if the points specified by the filter or their
complement is to be kept. For details, see
:func:`~baybe.utils.dataframe.filter_df`.
complement: If ``False``, the filtering mechanism is used as is.
If ``True``, the filtering mechanism is inverted so that
the complement of the subset specified by the filter is toggled.
For details, see :func:`~baybe.utils.dataframe.filter_df`.
dry_run: If ``True``, the target subset is only extracted but not
affected. If ``False``, the candidate set is updated correspondingly.
Useful for setting up the correct filtering mechanism.
@@ -327,15 +328,15 @@ def _(
self,
constraint: DiscreteConstraint,
exclude: bool,
anti: bool = False,
complement: bool = False,
dry_run: bool = False,
) -> pd.DataFrame:
# Filter search space dataframe according to the given constraint
df = self.searchspace.discrete.exp_rep
idx = constraint.get_valid(df)

# Determine the candidate subset to be toggled
points = df.drop(index=idx) if anti else df.loc[idx].copy()
points = df.drop(index=idx) if complement else df.loc[idx].copy()

if not dry_run:
self._searchspace_metadata.loc[points.index, _EXCLUDED] = exclude
@@ -347,11 +348,11 @@ def _(
self,
constraint: pd.DataFrame,
exclude: bool,
anti: bool = False,
complement: bool = False,
dry_run: bool = False,
) -> pd.DataFrame:
# Determine the candidate subset to be toggled
points = filter_df(self.searchspace.discrete.exp_rep, constraint, anti)
points = filter_df(self.searchspace.discrete.exp_rep, constraint, complement)

if not dry_run:
self._searchspace_metadata.loc[points.index, _EXCLUDED] = exclude
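For illustration, a minimal usage sketch of the renamed argument (not taken from the changed files): it assumes an already constructed campaign object whose discrete search space contains a parameter named "a"; the parameter name and filter values are made up, only the call signature follows the code above.

import pandas as pd

# Exclude all candidates where a == 0.
campaign.toggle_discrete_candidates(pd.DataFrame({"a": [0]}), exclude=True)

# With complement=True, exclude every candidate EXCEPT those where a == 0.
campaign.toggle_discrete_candidates(
    pd.DataFrame({"a": [0]}), exclude=True, complement=True
)

# dry_run=True only returns the affected subset without updating the metadata.
subset = campaign.toggle_discrete_candidates(
    pd.DataFrame({"a": [0]}), exclude=True, complement=True, dry_run=True
)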
baybe/simulation/core.py (4 changes: 3 additions & 1 deletion)
@@ -127,7 +127,9 @@ def simulate_experiment(
# available in the lookup
if impute_mode == "ignore":
campaign.toggle_discrete_candidates(
lookup[[p.name for p in campaign.parameters]], exclude=True, anti=True
lookup[[p.name for p in campaign.parameters]],
exclude=True,
complement=True,
)

# Run the DOE loop
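The change above makes simulate_experiment exclude every discrete candidate that has no matching row in the lookup by toggling the complement of the lookup subset. A standalone pandas sketch of that selection logic, with purely illustrative dataframes and column name:

import pandas as pd

# Illustrative search space and lookup, restricted to the shared parameter column.
searchspace = pd.DataFrame({"x": [1, 2, 3, 4]})
lookup = pd.DataFrame({"x": [1, 3]})

# complement=True corresponds to keeping the rows of the search space that do
# NOT appear in the lookup, i.e. the candidates that cannot be looked up.
merged = searchspace.merge(lookup, how="left", indicator=True)
not_in_lookup = merged[merged["_merge"] == "left_only"].drop(columns="_merge")
print(not_in_lookup)  # rows with x == 2 and x == 4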
baybe/simulation/scenarios.py (4 changes: 3 additions & 1 deletion)
@@ -226,7 +226,9 @@ def _simulate_groupby(
for c in group.columns
if c in campaign.searchspace.discrete.parameter_names
]
campaign_group.toggle_discrete_candidates(group[cols], exclude=True, anti=True)
campaign_group.toggle_discrete_candidates(
group[cols], exclude=True, complement=True
)

# Run the group simulation
try:
baybe/simulation/transfer_learning.py (2 changes: 1 addition & 1 deletion)
@@ -77,7 +77,7 @@ def simulate_transfer_learning(
# TODO: Reconsider if deepcopies are required once [16605] is resolved
campaign_task = deepcopy(campaign)
campaign_task.toggle_discrete_candidates(
pd.DataFrame({task_param.name: [task]}), exclude=True, anti=True
pd.DataFrame({task_param.name: [task]}), exclude=True, complement=True
)

# Use all off-task data as training data
baybe/utils/dataframe.py (12 changes: 6 additions & 6 deletions)
@@ -604,17 +604,17 @@ def get_transform_objects(


def filter_df(
df: pd.DataFrame, filter: pd.DataFrame, anti: bool = False
df: pd.DataFrame, filter: pd.DataFrame, complement: bool = False
) -> pd.DataFrame:
"""Filter a dataframe based on a second dataframe defining filtering conditions.
Filtering is done via a join (see ``anti`` argument for details) between the
Filtering is done via a join (see ``complement`` argument for details) between the
input dataframe and the filter dataframe.
Args:
df: The dataframe to be filtered.
filter: The dataframe defining the filtering conditions.
anti: If ``False``, the filter dataframe determines the rows to be kept
complement: If ``False``, the filter dataframe determines the rows to be kept
(i.e. selection via regular join). If ``True``, the filtering mechanism is
inverted so that the complement set of rows is kept (i.e. selection
via anti-join).
@@ -634,12 +634,12 @@ def filter_df(
2 1 a
3 1 b
>>> filter_df(df, pd.DataFrame([0], columns=["num"]), anti=False)
>>> filter_df(df, pd.DataFrame([0], columns=["num"]), complement=False)
num cat
0 0 a
1 0 b
>>> filter_df(df, pd.DataFrame([0], columns=["num"]), anti=True)
>>> filter_df(df, pd.DataFrame([0], columns=["num"]), complement=True)
num cat
2 1 a
3 1 b
@@ -651,7 +651,7 @@
out = pd.merge(
df.reset_index(names="_df_index"), filter, how="left", indicator=True
).set_index("_df_index")
to_drop = out["_merge"] == ("both" if anti else "left_only")
to_drop = out["_merge"] == ("both" if complement else "left_only")

# Drop the points
out.drop(index=out[to_drop].index, inplace=True)
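To spell out the renamed flag: with complement=False the filter dataframe selects the rows to keep (regular join), while complement=True keeps the complement of that subset (anti-join). A standalone pandas sketch mirroring the merge logic shown above, using the same toy data as the doctest; no BayBE import is needed:

import pandas as pd

df = pd.DataFrame({"num": [0, 0, 1, 1], "cat": ["a", "b", "a", "b"]})
flt = pd.DataFrame({"num": [0]})

merged = pd.merge(
    df.reset_index(names="_df_index"), flt, how="left", indicator=True
).set_index("_df_index")

# complement=False: keep the rows matching the filter (rows 0 and 1).
regular = merged[merged["_merge"] == "both"].drop(columns="_merge")

# complement=True: keep the complement of the matching subset (rows 2 and 3).
inverted = merged[merged["_merge"] == "left_only"].drop(columns="_merge")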
tests/test_campaign.py (8 changes: 4 additions & 4 deletions)
@@ -38,7 +38,7 @@ def test_get_surrogate(campaign, n_iterations, batch_size):
assert model is not None, "Something went wrong during surrogate model extraction."


@pytest.mark.parametrize("anti", [False, True], ids=["regular", "anti"])
@pytest.mark.parametrize("complement", [False, True], ids=["regular", "complement"])
@pytest.mark.parametrize("exclude", [True, False], ids=["exclude", "include"])
@pytest.mark.parametrize(
"constraint",
@@ -48,7 +48,7 @@ def test_get_surrogate(campaign, n_iterations, batch_size):
],
ids=["dataframe", "constraints"],
)
def test_candidate_toggling(constraint, exclude, anti):
def test_candidate_toggling(constraint, exclude, complement):
"""Toggling discrete candidates updates the campaign metadata accordingly."""
subspace = SubspaceDiscrete.from_product(
[
@@ -62,11 +62,11 @@ def test_candidate_toggling(constraint, exclude, anti):
campaign._searchspace_metadata[_EXCLUDED] = not exclude

# Toggle the candidates
campaign.toggle_discrete_candidates(constraint, exclude, anti=anti)
campaign.toggle_discrete_candidates(constraint, exclude, complement=complement)

# Extract row indices of candidates whose metadata should have been toggled
matches = campaign.searchspace.discrete.exp_rep["a"] == 0
idx = matches.index[~matches] if anti else matches.index[matches]
idx = matches.index[~matches] if complement else matches.index[matches]

# Assert that metadata is set correctly
target = campaign._searchspace_metadata.loc[idx, _EXCLUDED]