Add TopK downselection for initial batch generation. (#2636)
Summary:
## Motivation

In order to get facebook/Ax#2938 over the line with initial candidate generation that obeys the constraints, we want to use the existing tooling within `botorch`. The hard-coded logic currently in Ax uses top-k to downselect the Sobol samples. To make that change without impacting existing users, we first need to implement top-k downselection in `botorch`.
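For illustration, the downselection itself is just a top-k over the acquisition values of the raw (e.g. Sobol) candidates. A minimal sketch with illustrative names, not the actual Ax or BoTorch code:

```python
import torch

def topk_downselect(X_raw: torch.Tensor, acq_vals: torch.Tensor, n: int) -> torch.Tensor:
    """Keep the `n` raw candidates with the highest acquisition values.

    X_raw: `b x q x d` raw candidates; acq_vals: `b` acquisition values.
    """
    _, idcs = acq_vals.topk(n, largest=True)
    return X_raw[idcs]

# e.g. keep the 10 best of 100 random q=1 candidates in a 3-dim space
X_raw = torch.rand(100, 1, 3)
acq_vals = torch.randn(100)  # stand-in for real acquisition values
X_init = topk_downselect(X_raw, acq_vals, n=10)  # shape: 10 x 1 x 3
```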

Pull Request resolved: #2636

Test Plan:
TODO:
- [x] add tests for `initialize_q_batch_topn`

## Related PRs

facebook/Ax#2938. (#2610 was initially intended to be part of this solution, but I then realized that the pattern I wanted to use conflated repeats and the batch dimension.)

Reviewed By: Balandat

Differential Revision: D66413947

Pulled By: saitcakmak

fbshipit-source-id: 39e71f5cc0468d554419fa25dd545d9ee25289dc
CompRhys authored and facebook-github-bot committed Dec 6, 2024
1 parent 0d5e131 commit c1eb255
Showing 6 changed files with 476 additions and 268 deletions.
7 changes: 6 additions & 1 deletion botorch/optim/__init__.py
@@ -22,7 +22,11 @@
LinearHomotopySchedule,
LogLinearHomotopySchedule,
)
from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.initializers import (
initialize_q_batch,
initialize_q_batch_nonneg,
initialize_q_batch_topn,
)
from botorch.optim.optimize import (
gen_batch_initial_conditions,
optimize_acqf,
@@ -43,6 +47,7 @@
"gen_batch_initial_conditions",
"initialize_q_batch",
"initialize_q_batch_nonneg",
"initialize_q_batch_topn",
"OptimizationResult",
"OptimizationStatus",
"optimize_acqf",
104 changes: 90 additions & 14 deletions botorch/optim/initializers.py
@@ -271,13 +271,15 @@ def gen_batch_initial_conditions(
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
options: Options for initial condition generation. For valid options see
`initialize_q_batch` and `initialize_q_batch_nonneg`. If `options`
contains a `nonnegative=True` entry, then `acq_function` is
assumed to be non-negative (useful when using custom acquisition
functions). In addition, an "init_batch_limit" option can be passed
to specify the batch limit for the initialization. This is useful
for avoiding memory limits when computing the batch posterior over
raw samples.
`initialize_q_batch_topn`, `initialize_q_batch_nonneg`, and
`initialize_q_batch`. If `options` contains a `topn=True` entry, then
`initialize_q_batch_topn` will be used. Else if `options` contains a
`nonnegative=True` entry, then `acq_function` is assumed to be
non-negative (useful when using custom acquisition functions).
`initialize_q_batch` will be used otherwise. In addition, an
"init_batch_limit" option can be passed to specify the batch limit
for the initialization. This is useful for avoiding memory limits
when computing the batch posterior over raw samples.
inequality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
@@ -328,14 +330,24 @@ def gen_batch_initial_conditions(
init_kwargs = {}
device = bounds.device
bounds_cpu = bounds.cpu()
if "eta" in options:
init_kwargs["eta"] = options.get("eta")
if options.get("nonnegative") or is_nonnegative(acq_function):

if options.get("topn"):
init_func = initialize_q_batch_topn
init_func_opts = ["sorted", "largest"]
elif options.get("nonnegative") or is_nonnegative(acq_function):
init_func = initialize_q_batch_nonneg
if "alpha" in options:
init_kwargs["alpha"] = options.get("alpha")
init_func_opts = ["alpha", "eta"]
else:
init_func = initialize_q_batch
init_func_opts = ["eta"]

for opt in init_func_opts:
# Default "largest" to `acq_function.maximize` if that attribute exists
if opt == "largest" and hasattr(acq_function, "maximize"):
init_kwargs[opt] = acq_function.maximize

if opt in options:
init_kwargs[opt] = options.get(opt)

q = 1 if q is None else q
# the dimension the samples are drawn from
@@ -363,7 +375,9 @@
X_rnd_nlzd = torch.rand(
n, q, bounds_cpu.shape[-1], dtype=bounds.dtype
)
X_rnd = bounds_cpu[0] + (bounds_cpu[1] - bounds_cpu[0]) * X_rnd_nlzd
X_rnd = unnormalize(
X_rnd_nlzd, bounds_cpu, update_constant_bounds=False
)
else:
X_rnd = sample_q_batches_from_polytope(
n=n,
@@ -375,7 +389,8 @@
equality_constraints=equality_constraints,
inequality_constraints=inequality_constraints,
)
# sample points around best

# sample additional points around best
if sample_around_best:
X_best_rnd = sample_points_around_best(
acq_function=acq_function,
@@ -395,6 +410,8 @@
)
# Keep X on CPU for consistency & to limit GPU memory usage.
X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu()

# Append the fixed fantasies to the randomly generated points
if fixed_X_fantasies is not None:
if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
raise BotorchTensorDimensionError(
@@ -411,6 +428,9 @@
],
dim=-2,
)

# Evaluate the acquisition function on `X_rnd` using `batch_limit`
# sized chunks.
with torch.no_grad():
if batch_limit is None:
batch_limit = X_rnd.shape[0]
@@ -423,16 +443,22 @@
],
dim=0,
)

# Downselect the initial conditions based on the acquisition function values
batch_initial_conditions, _ = init_func(
X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
)
batch_initial_conditions = batch_initial_conditions.to(device=device)

# Return the initial conditions if no warnings were raised
if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
return batch_initial_conditions

if factor < max_factor:
factor += 1
if seed is not None:
seed += 1 # make sure to sample different X_rnd

warnings.warn(
"Unable to find non-zero acquisition function values - initial conditions "
"are being selected randomly.",
@@ -1057,6 +1083,56 @@ def initialize_q_batch_nonneg(
return X[idcs], acq_vals[idcs]


def initialize_q_batch_topn(
X: Tensor, acq_vals: Tensor, n: int, largest: bool = True, sorted: bool = True
) -> tuple[Tensor, Tensor]:
r"""Take the top `n` initial conditions for candidate generation.
Args:
X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
feature space. Typically, these are generated using qMC.
acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
is the value of the batch acquisition function to be maximized.
n: The number of initial conditions to be generated. Must not be larger than `b`.
largest: If True, take the `n` samples with the largest acquisition values; otherwise take the `n` smallest.
sorted: If True, return the selected samples in sorted order of their acquisition values.
Returns:
- An `n x q x d` tensor of `n` `q`-batch initial conditions.
- An `n` tensor of the corresponding acquisition values.
Example:
>>> # To get `n=10` starting points of q-batch size `q=3`
>>> # for model with `d=6`:
>>> qUCB = qUpperConfidenceBound(model, beta=0.1)
>>> X_rnd = torch.rand(500, 3, 6)
>>> X_init, acq_init = initialize_q_batch_topn(
... X=X_rnd, acq_vals=qUCB(X_rnd), n=10
... )
"""
n_samples = X.shape[0]
if n > n_samples:
raise RuntimeError(
f"n ({n}) cannot be larger than the number of "
f"provided samples ({n_samples})"
)
elif n == n_samples:
return X, acq_vals

Ystd = acq_vals.std(dim=0)
if torch.any(Ystd == 0):
warnings.warn(
"All acquisition values for raw samples points are the same for "
"at least one batch. Choosing initial conditions at random.",
BadInitialCandidatesWarning,
stacklevel=3,
)
idcs = torch.randperm(n=n_samples, device=X.device)[:n]
return X[idcs], acq_vals[idcs]

topk_out, topk_idcs = acq_vals.topk(n, largest=largest, sorted=sorted)
return X[topk_idcs], topk_out


def sample_points_around_best(
acq_function: AcquisitionFunction,
n_discrete_points: int,
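With the above in place, the new path can be exercised end-to-end by passing `topn=True` in `options`. A rough usage sketch (the model and acquisition setup below is illustrative, not part of this change):

```python
import torch
from botorch.acquisition import qUpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import gen_batch_initial_conditions

# Illustrative model and acquisition function.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
acqf = qUpperConfidenceBound(model, beta=0.1)

bounds = torch.stack([torch.zeros(3), torch.ones(3)]).to(torch.double)

# `topn=True` routes downselection through `initialize_q_batch_topn`;
# "largest"/"sorted" are forwarded to it, and "largest" defaults to
# `acq_function.maximize` when that attribute exists.
ics = gen_batch_initial_conditions(
    acq_function=acqf,
    bounds=bounds,
    q=2,
    num_restarts=5,
    raw_samples=128,
    options={"topn": True, "sorted": True},
)
# ics has shape `num_restarts x q x d`, i.e. 5 x 2 x 3 here.
```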
5 changes: 3 additions & 2 deletions botorch/utils/feasible_volume.py
@@ -11,7 +11,7 @@
import botorch.models.model as model
import torch
from botorch.logging import _get_logger
from botorch.utils.sampling import manual_seed
from botorch.utils.sampling import manual_seed, unnormalize
from torch import Tensor


@@ -164,9 +164,10 @@ def estimate_feasible_volume(
seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()

with manual_seed(seed=seed):
box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(
samples_nlzd = torch.rand(
(nsample_feature, bounds.size(1)), dtype=dtype, device=device
)
box_samples = unnormalize(samples_nlzd, bounds, update_constant_bounds=False)

features, p_feature = get_feasible_samples(
samples=box_samples, inequality_constraints=inequality_constraints
8 changes: 3 additions & 5 deletions botorch/utils/sampling.py
@@ -98,14 +98,12 @@ def draw_sobol_samples(
batch_shape = batch_shape or torch.Size()
batch_size = int(torch.prod(torch.tensor(batch_shape)))
d = bounds.shape[-1]
lower = bounds[0]
rng = bounds[1] - bounds[0]
sobol_engine = SobolEngine(q * d, scramble=True, seed=seed)
samples_raw = sobol_engine.draw(batch_size * n, dtype=lower.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=lower.device)
samples_raw = sobol_engine.draw(batch_size * n, dtype=bounds.dtype)
samples_raw = samples_raw.view(*batch_shape, n, q, d).to(device=bounds.device)
if batch_shape != torch.Size():
samples_raw = samples_raw.permute(-3, *range(len(batch_shape)), -2, -1)
return lower + rng * samples_raw
return unnormalize(samples_raw, bounds, update_constant_bounds=False)


def draw_sobol_normal_samples(
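The `draw_sobol_samples` change is a pure refactor onto `unnormalize`; behavior is unchanged. A quick sanity-check sketch of the (pre-existing) API:

```python
import torch
from botorch.utils.sampling import draw_sobol_samples

bounds = torch.tensor([[0.0, -1.0], [2.0, 1.0]])  # `2 x d` box with d=2
samples = draw_sobol_samples(bounds=bounds, n=8, q=3, seed=0)
print(samples.shape)  # torch.Size([8, 3, 2])
# All draws land inside the box defined by `bounds`.
assert ((samples >= bounds[0]) & (samples <= bounds[1])).all()
```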
33 changes: 21 additions & 12 deletions botorch/utils/transforms.py
@@ -66,17 +66,18 @@ def _update_constant_bounds(bounds: Tensor) -> Tensor:
return bounds


def normalize(X: Tensor, bounds: Tensor) -> Tensor:
def normalize(X: Tensor, bounds: Tensor, update_constant_bounds: bool = True) -> Tensor:
r"""Min-max normalize X w.r.t. the provided bounds.
NOTE: If the upper and lower bounds are identical for a dimension, that dimension
will not be scaled. Such dimensions will only be shifted as
`new_X[..., i] = X[..., i] - bounds[0, i]`. This avoids division by zero issues.
Args:
X: `... x d` tensor of data
bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
columns.
update_constant_bounds: If `True`, update the constant bounds in order to
avoid division by zero issues. When the upper and lower bounds are
identical for a dimension, that dimension will not be scaled. Such
dimensions will only be shifted as
`new_X[..., i] = X[..., i] - bounds[0, i]`.
Returns:
A `... x d`-dim tensor of normalized data, given by
@@ -89,21 +90,27 @@ def normalize(X: Tensor, bounds: Tensor) -> Tensor:
>>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
>>> X_normalized = normalize(X, bounds)
"""
bounds = _update_constant_bounds(bounds=bounds)
bounds = (
_update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds
)
return (X - bounds[0]) / (bounds[1] - bounds[0])


def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
def unnormalize(
X: Tensor, bounds: Tensor, update_constant_bounds: bool = True
) -> Tensor:
r"""Un-normalizes X w.r.t. the provided bounds.
NOTE: If the upper and lower bounds are identical for a dimension, that dimension
will not be scaled. Such dimensions will only be shifted as
`new_X[..., i] = X[..., i] + bounds[0, i]`, matching the behavior of `normalize`.
Args:
X: `... x d` tensor of data
bounds: `2 x d` tensor of lower and upper bounds for each of the X's d
columns.
update_constant_bounds: If `True`, update the constant bounds in order to
avoid division by zero issues. When the upper and lower bounds are
identical for a dimension, that dimension will not be scaled. Such
dimensions will only be shifted as
`new_X[..., i] = X[..., i] + bounds[0, i]`. This is the inverse of
the behavior of `normalize` when `update_constant_bounds=True`.
Returns:
A `... x d`-dim tensor of unnormalized data, given by
@@ -116,7 +123,9 @@ def unnormalize(X: Tensor, bounds: Tensor) -> Tensor:
>>> bounds = torch.stack([torch.zeros(3), 0.5 * torch.ones(3)])
>>> X = unnormalize(X_normalized, bounds)
"""
bounds = _update_constant_bounds(bounds=bounds)
bounds = (
_update_constant_bounds(bounds=bounds) if update_constant_bounds else bounds
)
return X * (bounds[1] - bounds[0]) + bounds[0]


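A small sketch of what the new `update_constant_bounds` flag controls (values are illustrative):

```python
import torch
from botorch.utils.transforms import normalize, unnormalize

# The second dimension has identical lower and upper bounds (a "constant" dim).
bounds = torch.tensor([[0.0, 2.0], [4.0, 2.0]])

# Default (update_constant_bounds=True): the constant dimension is only
# shifted, which avoids division by zero in `normalize`.
X = torch.tensor([[1.0, 2.0]])
print(normalize(X, bounds))  # tensor([[0.2500, 0.0000]])

# With update_constant_bounds=False the raw bounds are used as-is, so a
# unit-cube sample maps straight onto the box, which is what the raw-sample
# generation above relies on.
unit = torch.tensor([[0.5, 0.5]])
print(unnormalize(unit, bounds, update_constant_bounds=False))  # tensor([[2., 2.]])
```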