Optimization over mixed spaces in optimize_acqf_homotopy (#2639)
Summary:

## Motivation

So far, `optimize_acqf_homotopy` can only handle `fixed_features` but not `fixed_features_list`, which is useful for optimization over mixed spaces. In the spirit of `optimize_acqf_list`, this PR adds support for `fixed_features_list` by calling `optimize_acqf_mixed` instead of `optimize_acqf` inside `optimize_acqf_homotopy` when `fixed_features_list` is provided.

Out of the box this does not work, because `optimize_acqf_mixed` lacks the `return_best_only=False` option that `optimize_acqf` offers. The easiest way to solve this (maybe not the best) is to add an option to `optimize_acqf_mixed` that returns the candidates from all restarts for the finally selected fixed-features combination. As a start, it is sufficient to implement this only for `q=1` and raise an error for `q>1`.
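
To illustrate the intended call pattern, here is a minimal sketch (the model, homotopy schedule, and search space are hypothetical placeholders, not part of this PR):

```python
# Hypothetical usage sketch: optimize over a mixed space in which feature 0 is
# binary, trying each fixed value and keeping the best candidate.
import torch
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.homotopy import (
    FixedHomotopySchedule,
    Homotopy,
    HomotopyParameter,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy

model = GenericDeterministicModel(
    f=lambda x: 5 - (x - 2).sum(dim=-1, keepdim=True) ** 2
)
acqf = PosteriorMean(model=model)
hp = HomotopyParameter(
    parameter=torch.nn.Parameter(torch.tensor(1.0)),
    schedule=FixedHomotopySchedule(values=[1.0, 0.1, 0.01]),
)

candidate, acq_value = optimize_acqf_homotopy(
    q=1,
    acq_function=acqf,
    bounds=torch.tensor([[0.0, -5.0], [1.0, 5.0]]),
    homotopy=Homotopy(homotopy_parameters=[hp]),
    num_restarts=2,
    raw_samples=16,
    fixed_features_list=[{0: 0.0}, {0: 1.0}],  # one dict per discrete choice
)
```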

What do you think?

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #2639

Test Plan: Unit tests.

Reviewed By: esantorella

Differential Revision: D66671709

Pulled By: Balandat

fbshipit-source-id: 1afdb8be2ba21735e832c9aef56fca008f8f30ea
jduerholt authored and facebook-github-bot committed Dec 3, 2024
1 parent 4190f74 commit 88f47bc
Showing 4 changed files with 208 additions and 39 deletions.
56 changes: 48 additions & 8 deletions botorch/optim/optimize.py
@@ -933,7 +933,11 @@ def optimize_acqf_mixed(
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
post_processing_func: Callable[[Tensor], Tensor] | None = None,
batch_initial_conditions: Tensor | None = None,
+ return_best_only: bool = True,
gen_candidates: TGenCandidates | None = None,
ic_generator: TGenInitialConditions | None = None,
timeout_sec: float | None = None,
retry_on_optimization_warning: bool = True,
ic_gen_kwargs: dict | None = None,
) -> tuple[Tensor, Tensor]:
r"""Optimize over a list of fixed_features and returns the best solution.
@@ -982,20 +986,38 @@ def optimize_acqf_mixed(
transformations).
batch_initial_conditions: A tensor to specify the initial conditions. Set
this if you do not want to use default initialization strategy.
return_best_only: If False, outputs the solutions corresponding to all
random restart initializations of the optimization. Setting this keyword
to False is only allowed for `q=1`. Defaults to True.
gen_candidates: A callable for generating candidates (and their associated
acquisition values) given a tensor of initial conditions and an
acquisition function. Other common inputs include lower and upper bounds
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
for method-specific inputs. Default: `gen_candidates_scipy`
ic_generator: Function for generating initial conditions. Not needed when
`batch_initial_conditions` are provided. Defaults to
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
`ic_generator`
Returns:
A two-element tuple containing
- - a `q x d`-dim tensor of generated candidates.
- - an associated acquisition value.
+ - A tensor of generated candidates. The shape is
+ -- `q x d` if `return_best_only` is True (default)
+ -- `num_restarts x q x d` if `return_best_only` is False
+ - a tensor of associated acquisition values of dim `num_restarts`
+ if `return_best_only=False` else a scalar acquisition value.
"""
if not return_best_only and q > 1:
raise NotImplementedError("`return_best_only=False` is only supported for q=1.")

if not fixed_features_list:
raise ValueError("fixed_features_list must be non-empty.")

@@ -1010,11 +1032,12 @@
ic_gen_kwargs = ic_gen_kwargs or {}

if q == 1:
timeout_sec = timeout_sec / len(fixed_features_list) if timeout_sec else None
ff_candidate_list, ff_acq_value_list = [], []
num_candidate_generation_failures = 0
for fixed_features in fixed_features_list:
try:
- candidate, acq_value = optimize_acqf(
+ candidates, acq_values = optimize_acqf(
acq_function=acq_function,
bounds=bounds,
q=q,
Expand All @@ -1028,15 +1051,19 @@ def optimize_acqf_mixed(
post_processing_func=post_processing_func,
batch_initial_conditions=batch_initial_conditions,
ic_generator=ic_generator,
- return_best_only=True,
+ return_best_only=False,  # here we always return all candidates
+ # and filter later
gen_candidates=gen_candidates,
timeout_sec=timeout_sec,
retry_on_optimization_warning=retry_on_optimization_warning,
**ic_gen_kwargs,
)
except CandidateGenerationError:
# if candidate generation fails, we skip this candidate
num_candidate_generation_failures += 1
continue
- ff_candidate_list.append(candidate)
- ff_acq_value_list.append(acq_value)
+ ff_candidate_list.append(candidates)
+ ff_acq_value_list.append(acq_values)

if len(ff_candidate_list) == 0:
raise CandidateGenerationError(
@@ -1051,16 +1078,25 @@
OptimizationWarning,
stacklevel=3,
)

ff_acq_values = torch.stack(ff_acq_value_list)
- best = torch.argmax(ff_acq_values)
- return ff_candidate_list[best], ff_acq_values[best]
+ max_res = torch.max(ff_acq_values, dim=-1)
+ best_batch_idx = torch.argmax(max_res.values)
+ best_batch_candidates = ff_candidate_list[best_batch_idx]
+ best_acq_values = ff_acq_value_list[best_batch_idx]
+ if not return_best_only:
+ return best_batch_candidates, best_acq_values
+
+ best_idx = max_res.indices[best_batch_idx]
+ return best_batch_candidates[best_idx], best_acq_values[best_idx]

# For batch optimization with q > 1 we do not want to enumerate all n_combos^n
# possible combinations of discrete choices. Instead, we use sequential greedy
# optimization.
base_X_pending = acq_function.X_pending
candidates = torch.tensor([], device=bounds.device, dtype=bounds.dtype)

timeout_sec = timeout_sec / q if timeout_sec else None
for _ in range(q):
candidate, acq_value = optimize_acqf_mixed(
acq_function=acq_function,
@@ -1075,8 +1111,12 @@
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
post_processing_func=post_processing_func,
batch_initial_conditions=batch_initial_conditions,
gen_candidates=gen_candidates,
ic_generator=ic_generator,
ic_gen_kwargs=ic_gen_kwargs,
timeout_sec=timeout_sec,
retry_on_optimization_warning=retry_on_optimization_warning,
return_best_only=True,
)
candidates = torch.cat([candidates, candidate], dim=-2)
acq_function.set_X_pending(
…
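
With this in place, the `q=1` path optimizes every fixed-features combination with all restarts retained, takes the per-combination maximum over restarts, and then argmaxes across combinations; `return_best_only=False` hands the caller all restarts of the winning combination. A hedged sketch of the new option (the model and search space are illustrative placeholders, not part of this commit):

```python
# Illustrative sketch of the `return_best_only=False` option added to
# `optimize_acqf_mixed` in this commit; it is supported for q=1 only.
import torch
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.optimize import optimize_acqf_mixed

model = GenericDeterministicModel(f=lambda x: -(x**2).sum(dim=-1, keepdim=True))
acqf = PosteriorMean(model=model)

candidates, acq_values = optimize_acqf_mixed(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, -5.0], [1.0, 5.0]]),
    q=1,
    num_restarts=4,
    raw_samples=32,
    fixed_features_list=[{0: 0.0}, {0: 1.0}],
    return_best_only=False,  # keep every restart of the winning combination
)
# candidates: `num_restarts x q x d`; acq_values: `num_restarts`-dim.
assert candidates.shape == torch.Size([4, 1, 2])
assert acq_values.shape == torch.Size([4])
```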
52 changes: 40 additions & 12 deletions botorch/optim/optimize_homotopy.py
@@ -5,6 +5,8 @@

from __future__ import annotations

import warnings

from collections.abc import Callable

from typing import Any
@@ -15,7 +17,7 @@
from botorch.generation.gen import TGenCandidates
from botorch.optim.homotopy import Homotopy
from botorch.optim.initializers import TGenInitialConditions
- from botorch.optim.optimize import optimize_acqf
+ from botorch.optim.optimize import optimize_acqf, optimize_acqf_mixed
from torch import Tensor


@@ -67,14 +69,13 @@ def optimize_acqf_homotopy(
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
fixed_features: dict[int, float] | None = None,
+ fixed_features_list: list[dict[int, float]] | None = None,
post_processing_func: Callable[[Tensor], Tensor] | None = None,
batch_initial_conditions: Tensor | None = None,
gen_candidates: TGenCandidates | None = None,
- sequential: bool = False,
*,
ic_generator: TGenInitialConditions | None = None,
timeout_sec: float | None = None,
- return_full_tree: bool = False,
retry_on_optimization_warning: bool = True,
**ic_gen_kwargs: Any,
) -> tuple[Tensor, Tensor]:
@@ -129,6 +130,10 @@ def optimize_acqf_homotopy(
`options`.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
+ fixed_features_list: A list of maps `{feature_index: value}`. The i-th
+ item represents the fixed_feature for the i-th optimization. If
+ `fixed_features_list` is provided, `optimize_acqf_mixed` is invoked.
+ All indices (`feature_index`) should be non-negative.
post_processing_func: A function that post-processes an optimization
result appropriately (i.e., according to `round-trip`
transformations).
@@ -140,37 +145,57 @@
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
for method-specific inputs. Default: `gen_candidates_scipy`
- sequential: If False, uses joint optimization, otherwise uses sequential
- optimization.
ic_generator: Function for generating initial conditions. Not needed when
`batch_initial_conditions` are provided. Defaults to
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
- return_full_tree: Return the full tree of optimizers of the previous
- iteration.
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
`ic_generator`
"""
if fixed_features and fixed_features_list:
raise ValueError(
"Either `fixed_feature` or `fixed_features_list` can be provided, not both."
)

if fixed_features:
message = (
"The `fixed_features` argument is deprecated, "
"use `fixed_features_list` instead."
)
warnings.warn(
message,
DeprecationWarning,
stacklevel=2,
)

shared_optimize_acqf_kwargs = {
"num_restarts": num_restarts,
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
"fixed_features": fixed_features,
"return_best_only": False, # False to make n_restarts persist through homotopy.
"gen_candidates": gen_candidates,
"sequential": sequential,
"ic_generator": ic_generator,
"timeout_sec": timeout_sec,
"return_full_tree": return_full_tree,
"retry_on_optimization_warning": retry_on_optimization_warning,
**ic_gen_kwargs,
}

if fixed_features_list and len(fixed_features_list) > 1:
optimization_fn = optimize_acqf_mixed
fixed_features_kwargs = {"fixed_features_list": fixed_features_list}
else:
optimization_fn = optimize_acqf
fixed_features_kwargs = {
"fixed_features": fixed_features_list[0]
if fixed_features_list
else fixed_features
}

candidate_list, acq_value_list = [], []
if q > 1:
base_X_pending = acq_function.X_pending
@@ -181,15 +206,17 @@
homotopy.restart()

while not homotopy.should_stop:
- candidates, acq_values = optimize_acqf(
+ candidates, acq_values = optimization_fn(
acq_function=acq_function,
bounds=bounds,
q=1,
options=options,
batch_initial_conditions=candidates,
raw_samples=q_raw_samples,
**fixed_features_kwargs,
**shared_optimize_acqf_kwargs,
)

homotopy.step()

# Set raw_samples to None such that pruned restarts are not repopulated
Expand All @@ -204,13 +231,14 @@ def optimize_acqf_homotopy(
).unsqueeze(1)

# Optimize one more time with the final options
- candidates, acq_values = optimize_acqf(
+ candidates, acq_values = optimization_fn(
acq_function=acq_function,
bounds=bounds,
q=1,
options=final_options,
raw_samples=q_raw_samples,
batch_initial_conditions=candidates,
**fixed_features_kwargs,
**shared_optimize_acqf_kwargs,
)

…
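
As a usage note, the deprecation path added above can be sketched as follows (the setup mirrors the unit tests below and is purely illustrative):

```python
# Illustrative sketch: the deprecated `fixed_features` argument still works but
# now emits a DeprecationWarning; `fixed_features_list=[{...}]` is the
# equivalent new-style call. All setup objects are placeholders.
import warnings

import torch
from botorch.acquisition import PosteriorMean
from botorch.models.deterministic import GenericDeterministicModel
from botorch.optim.homotopy import (
    FixedHomotopySchedule,
    Homotopy,
    HomotopyParameter,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy

model = GenericDeterministicModel(f=lambda x: -(x**2).sum(dim=-1, keepdim=True))
acqf = PosteriorMean(model=model)
hp = HomotopyParameter(
    parameter=torch.nn.Parameter(torch.tensor(1.0)),
    schedule=FixedHomotopySchedule(values=[1.0, 0.1]),
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    optimize_acqf_homotopy(
        q=1,
        acq_function=acqf,
        bounds=torch.tensor([[-5.0, -5.0], [5.0, 5.0]]),
        homotopy=Homotopy(homotopy_parameters=[hp]),
        num_restarts=2,
        raw_samples=8,
        fixed_features={0: 1.0},  # deprecated; prefer fixed_features_list=[{0: 1.0}]
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```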
60 changes: 54 additions & 6 deletions test/optim/test_homotopy.py
@@ -117,7 +117,7 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
- bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
+ bounds=torch.tensor([[-10], [5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
@@ -132,14 +132,62 @@
f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
)
acqf = PosteriorMean(model=model)
# test raise warning on using `fixed_features` argument
message = (
"The `fixed_features` argument is deprecated, "
"use `fixed_features_list` instead."
)
with self.assertWarnsRegex(DeprecationWarning, message):
optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features=fixed_features,
)

candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features_list=[fixed_features],
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

# test fixed feature list
fixed_features_list = [{0: 1.0}, {1: 3.0}]
model = GenericDeterministicModel(
f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
)
acqf = PosteriorMean(model=model)
# test raise error when fixed_features and fixed_features_list are both provided
with self.assertRaisesRegex(
ValueError,
"Either `fixed_feature` or `fixed_features_list` can be provided",
):
optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10, -10], [5, 5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
fixed_features_list=fixed_features_list,
fixed_features=fixed_features,
)
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
- bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+ bounds=torch.tensor([[-10, -10, -10], [5, 5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
- fixed_features=fixed_features,
+ fixed_features_list=fixed_features_list,
)
self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

@@ -148,11 +196,11 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=3,
acq_function=acqf,
- bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+ bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
- fixed_features=fixed_features,
+ fixed_features_list=[fixed_features],
)
self.assertEqual(candidate.shape, torch.Size([3, 2]))
self.assertEqual(acqf_val.shape, torch.Size([3]))
@@ -170,7 +218,7 @@
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
- bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+ bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
…
