Add ability to mix batch initial conditions and internal IC generation (#2610)

Summary:
## Motivation

See #2609

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2610

Test Plan:
Basic testing of the code is easy; the challenge is working out what the knock-on implications might be: will this break people's code?

## Related PRs

facebook/Ax#2938

Reviewed By: Balandat

Differential Revision: D66102868

Pulled By: saitcakmak

fbshipit-source-id: b3491581a205b0fbe62edd670510e95f13e08177
CompRhys authored and facebook-github-bot committed Dec 3, 2024
1 parent a1763a1 commit 4190f74
Showing 3 changed files with 296 additions and 60 deletions.
202 changes: 165 additions & 37 deletions botorch/optim/optimize.py
@@ -109,12 +109,43 @@ def __post_init__(self) -> None:
"3-dimensional. Its shape is "
f"{batch_initial_conditions_shape}."
)

if batch_initial_conditions_shape[-1] != d:
raise ValueError(
f"batch_initial_conditions.shape[-1] must be {d}. The "
f"shape is {batch_initial_conditions_shape}."
)

if len(batch_initial_conditions_shape) == 2:
warnings.warn(
"If using a 2-dim `batch_initial_conditions` botorch will "
"default to old behavior of ignoring `num_restarts` and just "
"use the given `batch_initial_conditions` by setting "
"`raw_samples` to None.",
RuntimeWarning,
stacklevel=3,
)
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

if (
len(batch_initial_conditions_shape) == 3
and batch_initial_conditions_shape[0] < self.num_restarts
and batch_initial_conditions_shape[-2] != self.q
):
warnings.warn(
"If using a 3-dim `batch_initial_conditions` where the "
"first dimension is less than `num_restarts` and the second "
"dimension is not equal to `q`, botorch will default to "
"old behavior of ignoring `num_restarts` and just use the "
"given `batch_initial_conditions` by setting `raw_samples` "
"to None.",
RuntimeWarning,
stacklevel=3,
)
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

elif self.ic_generator is None:
if self.nonlinear_inequality_constraints is not None:
raise RuntimeError(
@@ -126,6 +157,7 @@ def __post_init__(self) -> None:
"Must specify `raw_samples` when "
"`batch_initial_conditions` is None`."
)

if self.fixed_features is not None and any(
(k < 0 for k in self.fixed_features)
):
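The user-visible effect of the hunk above: a 3-dim `batch_initial_conditions` with fewer rows than `num_restarts` no longer has to carry all starting points itself. A minimal usage sketch (not part of the diff; assumes standard BoTorch imports and illustrative toy data):

```python
import torch
from botorch.acquisition import ExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy model on random data, just to have a valid acquisition function.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = (train_X * 6 - 3).sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = ExpectedImprovement(model, best_f=train_Y.max())

# Provide 2 of the 5 restarts explicitly (`n x q x d`); after this change the
# remaining 3 are generated internally from `raw_samples`, instead of the
# provided points silently replacing all restarts.
provided = torch.rand(2, 1, 2, dtype=torch.double)
candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,
    num_restarts=5,
    raw_samples=32,
    batch_initial_conditions=provided,
)
```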
@@ -253,20 +285,49 @@ def _optimize_acqf_sequential_q(
return candidates, torch.stack(acq_value_list)


def _combine_initial_conditions(
provided_initial_conditions: Tensor | None = None,
generated_initial_conditions: Tensor | None = None,
dim=0,
) -> Tensor:
if (
provided_initial_conditions is not None
and generated_initial_conditions is not None
):
return torch.cat(
[provided_initial_conditions, generated_initial_conditions], dim=dim
)
elif provided_initial_conditions is not None:
return provided_initial_conditions
elif generated_initial_conditions is not None:
return generated_initial_conditions
else:
raise ValueError(
"Either `batch_initial_conditions` or `raw_samples` must be set."
)
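For illustration, a short sketch of the new helper's semantics (calling the private function directly, which is not stable API):

```python
import torch
from botorch.optim.optimize import _combine_initial_conditions  # private helper

provided = torch.zeros(2, 1, 3)   # 2 user-supplied restarts, `n x q x d`
generated = torch.ones(3, 1, 3)   # 3 internally generated restarts

# Both given: concatenated along the restart dimension (dim=0).
combined = _combine_initial_conditions(provided, generated)
assert combined.shape == (5, 1, 3)

# Only one given: returned unchanged; neither given: ValueError.
assert _combine_initial_conditions(provided, None).shape == (2, 1, 3)
```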


def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
options = opt_inputs.options or {}

initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
required_num_restarts = opt_inputs.num_restarts
provided_initial_conditions = opt_inputs.batch_initial_conditions
generated_initial_conditions = None

if initial_conditions_provided:
batch_initial_conditions = opt_inputs.batch_initial_conditions
else:
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
batch_initial_conditions = opt_inputs.get_ic_generator()(
if (
provided_initial_conditions is not None
and len(provided_initial_conditions.shape) == 3
):
required_num_restarts -= provided_initial_conditions.shape[0]

if opt_inputs.raw_samples is not None and required_num_restarts > 0:
# pyre-ignore[28]: Unexpected keyword argument `acq_function`
# to anonymous call.
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
num_restarts=required_num_restarts,
raw_samples=opt_inputs.raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
@@ -275,6 +336,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
**opt_inputs.ic_gen_kwargs,
)

batch_initial_conditions = _combine_initial_conditions(
provided_initial_conditions=provided_initial_conditions,
generated_initial_conditions=generated_initial_conditions,
)
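The restart bookkeeping above reduces to a small amount of arithmetic; a sketch with illustrative numbers:

```python
# Mirrors the logic in `_optimize_acqf_batch` above (illustrative only):
num_restarts = 5
provided_shape = (2, 1, 3)  # a 3-dim `batch_initial_conditions`, `n x q x d`

# Provided restarts count toward the budget; only the shortfall is generated,
# and only if `raw_samples` is set. The same check gates the retry path below.
required_num_restarts = num_restarts - provided_shape[0]  # -> 3
generate_more = required_num_restarts > 0  # True: generate 3 more ICs
```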

batch_limit: int = options.get(
"batch_limit",
(
@@ -344,23 +410,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
first_warn_msg = (
"Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
"`batch_initial_conditions`, optimization will not be retried with "
"new initial conditions and will proceed with the current solution."
" Suggested remediation: Try again with different "
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
if initial_conditions_provided
"`batch_initial_conditions` larger than required `num_restarts`, "
"optimization will not be retried with new initial conditions and "
"will proceed with the current solution. Suggested remediation: "
"Try again with different `batch_initial_conditions`, don't provide "
"`batch_initial_conditions`, or increase `num_restarts`."
if batch_initial_conditions is not None and required_num_restarts <= 0
else "Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
"set of initial conditions."
)
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)

if not initial_conditions_provided:
batch_initial_conditions = opt_inputs.get_ic_generator()(
if opt_inputs.raw_samples is not None and required_num_restarts > 0:
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
num_restarts=required_num_restarts,
raw_samples=opt_inputs.raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
@@ -369,6 +436,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
**opt_inputs.ic_gen_kwargs,
)

batch_initial_conditions = _combine_initial_conditions(
provided_initial_conditions=provided_initial_conditions,
generated_initial_conditions=generated_initial_conditions,
)

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()

optimization_warning_raised = any(
@@ -1177,7 +1249,7 @@ def _gen_batch_initial_conditions_local_search(
inequality_constraints: list[tuple[Tensor, Tensor, float]],
min_points: int,
max_tries: int = 100,
):
) -> Tensor:
"""Generate initial conditions for local search."""
device = discrete_choices[0].device
dtype = discrete_choices[0].dtype
@@ -1197,6 +1269,58 @@
raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")


def _gen_starting_points_local_search(
discrete_choices: list[Tensor],
raw_samples: int,
batch_initial_conditions: Tensor,
X_avoid: Tensor,
inequality_constraints: list[tuple[Tensor, Tensor, float]],
min_points: int,
acq_function: AcquisitionFunction,
max_batch_size: int = 2048,
max_tries: int = 100,
) -> Tensor:
required_min_points = min_points
provided_X0 = None
generated_X0 = None

if batch_initial_conditions is not None:
provided_X0 = _filter_invalid(
X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
)
provided_X0 = _filter_infeasible(
X=provided_X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
required_min_points -= batch_initial_conditions.shape[0]

if required_min_points > 0:
generated_X0 = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=min_points,
max_tries=max_tries,
)

# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=generated_X0.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)

generated_X0 = generated_X0[
acqvals_init.topk(k=min_points, largest=True, dim=0).indices
]

return _combine_initial_conditions(
provided_initial_conditions=provided_X0 if provided_X0 is not None else None,
generated_initial_conditions=generated_X0 if generated_X0 is not None else None,
)
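A condensed sketch of the filter-then-top-up flow this helper implements (hypothetical standalone version; the real code also filters provided points against `X_avoid` and the constraints, and ranks generated points by acquisition value):

```python
import torch
from collections.abc import Callable

def top_up(
    provided: torch.Tensor | None,
    min_points: int,
    generate: Callable[[int], torch.Tensor],
) -> torch.Tensor:
    """Keep provided points; generate extra ones only if they fall short."""
    shortfall = min_points - (0 if provided is None else provided.shape[0])
    parts = [] if provided is None else [provided]
    if shortfall > 0:
        # The generated set is ranked and truncated by acquisition value upstream.
        parts.append(generate(min_points))
    return torch.cat(parts, dim=0)

# 2 provided points with min_points=5 -> generation is triggered:
X0 = top_up(torch.zeros(2, 1, 4), 5, lambda n: torch.rand(n, 1, 4))
```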


def optimize_acqf_discrete_local_search(
acq_function: AcquisitionFunction,
discrete_choices: list[Tensor],
@@ -1207,6 +1331,7 @@ def optimize_acqf_discrete_local_search(
X_avoid: Tensor | None = None,
batch_initial_conditions: Tensor | None = None,
max_batch_size: int = 2048,
max_tries: int = 100,
unique: bool = True,
) -> tuple[Tensor, Tensor]:
r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1363,8 @@ def optimize_acqf_discrete_local_search(
max_batch_size: The maximum number of choices to evaluate in batch.
A large limit can cause excessive memory usage if the model has
a large training set.
max_tries: Maximum number of iterations to try when generating initial
conditions.
unique: If True return unique choices, o/w choices may be repeated
(only relevant if `q > 1`).
@@ -1247,6 +1374,16 @@
- a `q x d`-dim tensor of generated candidates.
- an associated acquisition value.
"""
if batch_initial_conditions is not None:
if not (
len(batch_initial_conditions.shape) == 3
and batch_initial_conditions.shape[-2] == 1
):
raise ValueError(
"batch_initial_conditions must have shape `n x 1 x d` if "
f"given (received shape {batch_initial_conditions.shape})."
)

candidate_list = []
base_X_pending = acq_function.X_pending if q > 1 else None
base_X_avoid = X_avoid
@@ -1259,27 +1396,18 @@
inequality_constraints = inequality_constraints or []
for i in range(q):
# generate some starting points
if i == 0 and batch_initial_conditions is not None:
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
X0 = _filter_infeasible(
X=X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
else:
X_init = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
)
# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=X_init.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)
X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
X0 = _gen_starting_points_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
batch_initial_conditions=batch_initial_conditions,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
acq_function=acq_function,
max_batch_size=max_batch_size,
max_tries=max_tries,
)
batch_initial_conditions = None

# optimize from the best starting points
best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype)
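With the refactor above, provided and generated starting points can also be mixed in discrete local search. A hedged usage sketch (reusing the `acqf` from the earlier example; the choices and points are illustrative, and provided points must have shape `n x 1 x d`):

```python
import torch
from botorch.optim import optimize_acqf_discrete_local_search

# An 11-point grid per dimension; the provided points lie on the grid.
discrete_choices = [
    torch.linspace(0, 1, 11, dtype=torch.double) for _ in range(2)
]
provided = torch.tensor([[[0.0, 0.5]], [[1.0, 0.2]]], dtype=torch.double)

candidates, acq_value = optimize_acqf_discrete_local_search(
    acq_function=acqf,
    discrete_choices=discrete_choices,
    q=1,
    num_restarts=5,  # 2 provided starting points plus generated ones
    raw_samples=512,
    batch_initial_conditions=provided,
)
```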
8 changes: 7 additions & 1 deletion botorch/optim/optimize_homotopy.py
@@ -157,7 +157,6 @@ def optimize_acqf_homotopy(
"""
shared_optimize_acqf_kwargs = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
@@ -178,6 +177,7 @@

for _ in range(q):
candidates = batch_initial_conditions
q_raw_samples = raw_samples
homotopy.restart()

while not homotopy.should_stop:
@@ -187,10 +187,15 @@
q=1,
options=options,
batch_initial_conditions=candidates,
raw_samples=q_raw_samples,
**shared_optimize_acqf_kwargs,
)
homotopy.step()

# Set raw_samples to None such that pruned restarts are not repopulated
# at each step in the homotopy.
q_raw_samples = None

# Prune candidates
candidates = prune_candidates(
candidates=candidates.squeeze(1),
@@ -204,6 +209,7 @@
bounds=bounds,
q=1,
options=final_options,
raw_samples=q_raw_samples,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)
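The homotopy change follows the same principle: `raw_samples` is forwarded only on the first `optimize_acqf` call of each homotopy pass, so restarts pruned by `prune_candidates` are not repopulated at later steps. A toy sketch of that pattern (`optimize_step` is a stand-in, not real API):

```python
import torch

def optimize_step(batch_initial_conditions, raw_samples):
    """Stand-in for `optimize_acqf`: generates ICs only if `raw_samples` is set."""
    if raw_samples is not None:
        return torch.rand(4, 1, 2)  # fresh initial conditions
    return batch_initial_conditions  # refine the surviving candidates only

raw_samples: int | None = 16
candidates = None
for _ in range(3):  # homotopy steps
    candidates = optimize_step(candidates, raw_samples)
    raw_samples = None  # later steps must not repopulate pruned restarts
```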