Skip to content

Commit

Permalink
Add support for continuous relaxation within optimize_acqf_mixed_alte…
Browse files Browse the repository at this point in the history
…rnating (#2635)

Summary:
Pull Request resolved: #2635

`optimize_acqf_mixed_alternating` utilizes local search to optimize discrete dimensions. This works well when there are a small number of values for the discrete dimensions but it does not scale well as the number of values increases. To address this, we have been transforming the high-cardinality dimensions in Ax and only passing in the low-cardinality dimensions as part of `discrete_dims`.
This diff adds support for using continuous relaxation for discrete dimensions that have more than `max_discrete_values` (configurable via `options`).

Also updates the optimizer to fall back to `optimize_acqf` if there are no discrete dimensions left. This is more user friendly than erroring out (particularly when used through Ax).

Reviewed By: Balandat

Differential Revision: D66239005

fbshipit-source-id: 0878115eb08ea75acb34ad8e891cf88393d4e36c
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 21, 2024
1 parent de46059 commit 5d37606
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 8 deletions.
58 changes: 57 additions & 1 deletion botorch/optim/optimize_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
MAX_ITER_ALTER = 64 # Maximum number of alternating iterations.
MAX_ITER_DISCRETE = 4 # Maximum number of discrete iterations.
MAX_ITER_CONT = 8 # Maximum number of continuous iterations.
# Maximum number of discrete values for a discrete dimension.
# If there are more values for a dimension, we will use continuous
# relaxation to optimize it.
MAX_DISCRETE_VALUES = 20
# Maximum number of iterations for optimizing the continuous relaxation
# during initialization
MAX_ITER_INIT = 100
Expand All @@ -52,6 +56,7 @@
"maxiter_discrete",
"maxiter_continuous",
"maxiter_init",
"max_discrete_values",
"num_spray_points",
"std_cont_perturbation",
"batch_limit",
Expand All @@ -60,6 +65,40 @@
SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}


def _setup_continuous_relaxation(
discrete_dims: list[int],
bounds: Tensor,
max_discrete_values: int,
post_processing_func: Callable[[Tensor], Tensor] | None,
) -> tuple[list[int], Callable[[Tensor], Tensor] | None]:
r"""Update `discrete_dims` and `post_processing_func` to use
continuous relaxation for discrete dimensions that have more than
`max_discrete_values` values. These dimensions are removed from
`discrete_dims` and `post_processing_func` is updated to round
them to the nearest integer.
"""
discrete_dims_t = torch.tensor(discrete_dims, dtype=torch.long)
num_discrete_values = (
bounds[1, discrete_dims_t] - bounds[0, discrete_dims_t]
).cpu()
dims_to_relax = discrete_dims_t[num_discrete_values > max_discrete_values]
if dims_to_relax.numel() == 0:
# No dimension needs continuous relaxation.
return discrete_dims, post_processing_func
# Remove relaxed dims from `discrete_dims`.
discrete_dims = list(set(discrete_dims).difference(dims_to_relax.tolist()))

def new_post_processing_func(X: Tensor) -> Tensor:
r"""Round the relaxed dimensions to the nearest integer and apply the original
`post_processing_func`."""
X[..., dims_to_relax] = X[..., dims_to_relax].round()
if post_processing_func is not None:
X = post_processing_func(X)
return X

return discrete_dims, new_post_processing_func


def _filter_infeasible(
X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]] | None
) -> Tensor:
Expand Down Expand Up @@ -532,6 +571,9 @@ def optimize_acqf_mixed_alternating(
iterations.
NOTE: This method assumes that all discrete variables are integer valued.
The discrete dimensions that have more than
`options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
be optimized using continuous relaxation.
# TODO: Support categorical variables.
Expand All @@ -549,6 +591,9 @@ def optimize_acqf_mixed_alternating(
Defaults to 4.
- "maxiter_continuous": Maximum number of iterations in each continuous step.
Defaults to 8.
- "max_discrete_values": Maximum number of values for a discrete dimension
to be optimized using discrete step / local search. The discrete dimensions
with more values will be optimized using continuous relaxation.
- "num_spray_points": Number of spray points (around `X_baseline`) to add to
the points generated by the initialization strategy. Defaults to 20 if
all discrete variables are binary and to 0 otherwise.
Expand Down Expand Up @@ -598,6 +643,17 @@ def optimize_acqf_mixed_alternating(
f"Received an unsupported option {unsupported_keys}. {SUPPORTED_OPTIONS=}."
)

# Update discrete dims and post processing functions to account for any
# dimensions that should be using continuous relaxation.
discrete_dims, post_processing_func = _setup_continuous_relaxation(
discrete_dims=discrete_dims,
bounds=bounds,
max_discrete_values=assert_is_instance(
options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
),
post_processing_func=post_processing_func,
)

opt_inputs = OptimizeAcqfInputs(
acq_function=acq_function,
bounds=bounds,
Expand All @@ -623,7 +679,7 @@ def optimize_acqf_mixed_alternating(
# Remove fixed features from dims, so they don't get optimized.
discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features]
if len(discrete_dims) == 0:
raise ValueError("There must be at least one discrete parameter.")
return _optimize_acqf(opt_inputs=opt_inputs)
if not (
isinstance(discrete_dims, list)
and len(set(discrete_dims)) == len(discrete_dims)
Expand Down
91 changes: 85 additions & 6 deletions test/optim/test_optimize_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim.optimize import _optimize_acqf, OptimizeAcqfInputs
from botorch.optim.optimize_mixed import (
_setup_continuous_relaxation,
complement_indices,
continuous_step,
discrete_step,
generate_starting_points,
get_nearest_neighbors,
get_spray_points,
MAX_DISCRETE_VALUES,
optimize_acqf_mixed_alternating,
sample_feasible_points,
)
Expand Down Expand Up @@ -544,20 +546,29 @@ def test_optimize_acqf_mixed_binary_only(self) -> None:
self.assertEqual(candidates.shape[-1], dim)
c_binary = candidates[:, binary_dims + [2]]
self.assertTrue(((c_binary == 0) | (c_binary == 1)).all())
# Only continuous parameters will raise an error.
with self.assertRaisesRegex(
ValueError,
"There must be at least one discrete parameter",
):
# Only continuous parameters should fallback to optimize_acqf.
with mock.patch(
f"{OPT_MODULE}._optimize_acqf", wraps=_optimize_acqf
) as wrapped_optimize:
optimize_acqf_mixed_alternating(
acq_function=acqf,
bounds=bounds,
discrete_dims=[],
options=options,
q=1,
raw_samples=20,
num_restarts=20,
num_restarts=2,
)
wrapped_optimize.assert_called_once_with(
opt_inputs=_make_opt_inputs(
acq_function=acqf,
bounds=bounds,
options=options,
q=1,
raw_samples=20,
num_restarts=2,
)
)
# Only discrete works fine.
candidates, _ = optimize_acqf_mixed_alternating(
acq_function=acqf,
Expand Down Expand Up @@ -720,3 +731,71 @@ def test_optimize_acqf_mixed_integer(self) -> None:
wrapped_sample_feasible.assert_called_once()
# Should request 4 candidates, since all 4 are infeasible.
self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4)

def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
# Testing with integer variables.
train_X, train_Y, binary_dims, cont_dims = self._get_data()
# Update the data to introduce integer dimensions.
binary_dims = [0]
integer_dims = [3, 4]
discrete_dims = binary_dims + integer_dims
bounds = torch.tensor(
[[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 40.0, 15.0]],
dtype=torch.double,
device=self.device,
)
# Update the model to have a different optimizer.
root = torch.tensor([0.0, 0.0, 0.0, 25.0, 10.0], device=self.device)
model = QuadraticDeterministicModel(root)
acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)

for max_discrete_values, post_processing_func in (
(None, None),
(5, lambda X: X + 10),
):
options = {
"batch_limit": 5,
"init_batch_limit": 20,
"maxiter_alternating": 1,
}
if max_discrete_values is not None:
options["max_discrete_values"] = max_discrete_values
with mock.patch(
f"{OPT_MODULE}._setup_continuous_relaxation",
wraps=_setup_continuous_relaxation,
) as wrapped_setup, mock.patch(
f"{OPT_MODULE}.discrete_step", wraps=discrete_step
) as wrapped_discrete:
candidates, _ = optimize_acqf_mixed_alternating(
acq_function=acqf,
bounds=bounds,
discrete_dims=discrete_dims,
q=3,
raw_samples=32,
num_restarts=4,
options=options,
post_processing_func=post_processing_func,
)
wrapped_setup.assert_called_once_with(
discrete_dims=discrete_dims,
bounds=bounds,
max_discrete_values=max_discrete_values or MAX_DISCRETE_VALUES,
post_processing_func=post_processing_func,
)
discrete_call_args = wrapped_discrete.call_args.kwargs
expected_dims = [0, 4] if max_discrete_values is None else [0]
self.assertAllClose(
discrete_call_args["discrete_dims"],
torch.tensor(expected_dims, device=self.device),
)
# Check that dim 3 is rounded.
X = torch.ones(1, 5, device=self.device) * 0.6
X_expected = X.clone()
X_expected[0, 3] = 1.0
if max_discrete_values is not None:
X_expected[0, 4] = 1.0
if post_processing_func is not None:
X_expected = post_processing_func(X_expected)
self.assertAllClose(
discrete_call_args["opt_inputs"].post_processing_func(X), X_expected
)
2 changes: 1 addition & 1 deletion test/test_utils/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_mock_optimize_mixed_alternating(self) -> None:
) as mock_neighbors:
optimize_acqf_mixed_alternating(
acq_function=SinAcqusitionFunction(),
bounds=torch.tensor([[-2.0, 0.0], [2.0, 200.0]]),
bounds=torch.tensor([[-2.0, 0.0], [2.0, 20.0]]),
discrete_dims=[1],
num_restarts=1,
)
Expand Down

0 comments on commit 5d37606

Please sign in to comment.