Skip to content

Commit

Permalink
Merge: Add efficient from_simplex validation (#166)
Browse files Browse the repository at this point in the history
Object creation from JSON with the `from_simplex` constructor has so
far been done with the default method, which constructs the entire
object in full.

This is inefficient for Cartesian-product-like search spaces such as
those made with `from_product` and `from_simplex`.

This introduces a separate validator for that function. It also
consolidates some of the constraint filtering into a small utility.
  • Loading branch information
Scienfitz authored Mar 11, 2024
2 parents 5df7632 + ba3882d commit 16c07e0
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 48 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Replaced unmaintained `mordred` dependency by `mordredcommunity`
- `SearchSpace`s now use `ndarray` instead of `Tensor`

### Fixed
- `from_simplex` now efficiently validated in `Campaign.validate_config`

## [0.8.0] - 2024-02-29
### Changed
- BoTorch dependency bumped to `>=0.9.3`
Expand Down
20 changes: 16 additions & 4 deletions baybe/searchspace/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
)
from baybe.parameters.base import ContinuousParameter, DiscreteParameter, Parameter
from baybe.searchspace.continuous import SubspaceContinuous
from baybe.searchspace.discrete import SubspaceDiscrete
from baybe.searchspace.discrete import (
SubspaceDiscrete,
validate_simplex_subspace_from_config,
)
from baybe.searchspace.validation import validate_parameters
from baybe.serialization import SerialMixin, converter, select_constructor_hook
from baybe.telemetry import TELEM_LABELS, telemetry_record_value
Expand Down Expand Up @@ -300,7 +303,7 @@ def transform(

def validate_searchspace_from_config(specs: dict, _) -> None:
"""Validate the search space specifications while skipping costly creation steps."""
# For product spaces, only validate the inputs
# Validate product inputs without constructing it
if specs.get("constructor", None) == "from_product":
parameters = converter.structure(specs["parameters"], List[Parameter])
validate_parameters(parameters)
Expand All @@ -310,9 +313,18 @@ def validate_searchspace_from_config(specs: dict, _) -> None:
constraints = converter.structure(specs["constraints"], List[Constraint])
validate_constraints(constraints, parameters)

# For all other types, validate by construction
else:
converter.structure(specs, SearchSpace)
discrete_subspace_specs = specs.get("discrete", {})
if discrete_subspace_specs.get("constructor", None) == "from_simplex":
# Validate discrete simplex subspace
_validation_converter = converter.copy()
_validation_converter.register_structure_hook(
SubspaceDiscrete, validate_simplex_subspace_from_config
)
_validation_converter.structure(discrete_subspace_specs, SubspaceDiscrete)
else:
# For all other types, validate by construction
converter.structure(specs, SearchSpace)


# Register deserialization hook
Expand Down
95 changes: 75 additions & 20 deletions baybe/searchspace/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from baybe.parameters.base import DiscreteParameter, Parameter
from baybe.parameters.utils import get_parameters_from_dataframe
from baybe.searchspace.validation import validate_parameter_names
from baybe.searchspace.validation import validate_parameter_names, validate_parameters
from baybe.serialization import SerialMixin, converter, select_constructor_hook
from baybe.utils.boolean import eq_dataframe
from baybe.utils.dataframe import (
Expand Down Expand Up @@ -197,24 +197,14 @@ def from_product(
empty_encoding: bool = False,
) -> SubspaceDiscrete:
"""See :class:`baybe.searchspace.core.SearchSpace`."""
# Store the input
if constraints is None:
constraints = []
else:
# Reorder the constraints according to their execution order
constraints = sorted(
constraints,
key=lambda x: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index(x.__class__),
)
# Set defaults
constraints = constraints or []

# Create a dataframe representing the experimental search space
exp_rep = parameter_cartesian_prod_to_df(parameters)

# Remove entries that violate parameter constraints:
for constraint in (c for c in constraints if c.eval_during_creation):
idxs = constraint.get_invalid(exp_rep)
exp_rep.drop(index=idxs, inplace=True)
exp_rep.reset_index(inplace=True, drop=True)
# Remove entries that violate parameter constraints
_apply_constraint_filter(exp_rep, constraints)

return SubspaceDiscrete(
parameters=parameters,
Expand Down Expand Up @@ -354,7 +344,7 @@ def from_simplex(
max_values = [max(p.values) for p in simplex_parameters]
if not (min(min_values) >= 0.0):
raise ValueError(
f"All parameters passed to '{cls.from_simplex.__name__}' "
f"All simplex_parameters passed to '{cls.from_simplex.__name__}' "
f"must have non-negative values only."
)

Expand Down Expand Up @@ -463,10 +453,7 @@ def drop_invalid(
exp_rep = pd.merge(exp_rep, product_space, how="cross")

# Remove entries that violate parameter constraints:
for constraint in (c for c in constraints if c.eval_during_creation):
idxs = constraint.get_invalid(exp_rep)
exp_rep.drop(index=idxs, inplace=True)
exp_rep.reset_index(inplace=True, drop=True)
_apply_constraint_filter(exp_rep, constraints)

return cls(
parameters=simplex_parameters + product_parameters,
Expand Down Expand Up @@ -587,6 +574,27 @@ def transform(
return comp_rep


def _apply_constraint_filter(df: pd.DataFrame, constraints: List[DiscreteConstraint]):
"""Remove discrete search space entries inplace based on constraints.
Args:
df: The data in experimental representation to be modified inplace.
constraints: List of discrete constraints.
"""
# Reorder the constraints according to their execution order
constraints = sorted(
constraints,
key=lambda x: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index(x.__class__),
)

# Remove entries that violate parameter constraints:
for constraint in (c for c in constraints if c.eval_during_creation):
idxs = constraint.get_invalid(df)
df.drop(index=idxs, inplace=True)
df.reset_index(inplace=True, drop=True)


def parameter_cartesian_prod_to_df(
parameters: Iterable[Parameter],
) -> pd.DataFrame:
Expand All @@ -613,5 +621,52 @@ def parameter_cartesian_prod_to_df(
return ret


def validate_simplex_subspace_from_config(specs: dict, _) -> None:
    """Validate the discrete space while skipping costly creation steps.

    For the ``from_product`` and ``from_simplex`` constructors, only the
    parameters and constraints are structured and validated; the subspace
    itself is not built. Any other specification is validated by structuring
    it into a ``SubspaceDiscrete``.

    Args:
        specs: The unstructured specification dictionary of the discrete subspace.
        _: Ignored by this function.

    Raises:
        ValueError: If a simplex parameter contains a negative value.
    """
    constructor = specs.get("constructor", None)

    # Validate product inputs without constructing it
    if constructor == "from_product":
        parameters = converter.structure(specs["parameters"], List[DiscreteParameter])
        validate_parameters(parameters)

        if specs.get("constraints", None):
            constraints = converter.structure(
                specs["constraints"], List[DiscreteConstraint]
            )
            validate_constraints(constraints, parameters)

    # Validate simplex inputs without constructing it
    elif constructor == "from_simplex":
        simplex_parameters = converter.structure(
            specs["simplex_parameters"], List[NumericalDiscreteParameter]
        )

        if not all(min(p.values) >= 0.0 for p in simplex_parameters):
            raise ValueError(
                f"All simplex_parameters passed to "
                f"'{SubspaceDiscrete.from_simplex.__name__}' must have non-negative "
                f"values only."
            )

        # The product parameters are optional. Default to an empty list so that
        # the concatenation below also works when they are missing or empty
        # (previously this ended up as `list + None` and raised a TypeError).
        product_parameters = []
        if specs.get("product_parameters", None):
            product_parameters = converter.structure(
                specs["product_parameters"], List[DiscreteParameter]
            )

        all_parameters = simplex_parameters + product_parameters
        validate_parameters(all_parameters)

        if specs.get("constraints", None):
            constraints = converter.structure(
                specs["constraints"], List[DiscreteConstraint]
            )
            validate_constraints(constraints, all_parameters)

    # For all other types, validate by construction
    else:
        converter.structure(specs, SubspaceDiscrete)


# Register deserialization hook
converter.register_structure_hook(SubspaceDiscrete, select_constructor_hook)
96 changes: 74 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,30 +655,37 @@ def fixture_default_config():
# default campaign object instead of hardcoding it here. This avoids redundant
# code and automatically keeps them synced.
cfg = """{
"parameters": [
{
"type": "NumericalDiscreteParameter",
"name": "Temp_C",
"values": [10, 20, 30, 40]
},
{
"type": "NumericalDiscreteParameter",
"name": "Concentration",
"values": [0.2, 0.3, 1.4]
},
__fillin__
"searchspace": {
"constructor": "from_product",
"parameters": [
{
"type": "NumericalDiscreteParameter",
"name": "Temp_C",
"values": [10, 20, 30, 40]
},
{
"type": "NumericalDiscreteParameter",
"name": "Concentration",
"values": [0.2, 0.3, 1.4]
},
__fillin__
{
"type": "CategoricalParameter",
"name": "Base",
"values": ["base1", "base2", "base3", "base4", "base5"]
}
],
"constraints": []
},
"objective": {
"mode": "SINGLE",
"targets": [
{
"type": "CategoricalParameter",
"name": "Base",
"values": ["base1", "base2", "base3", "base4", "base5"]
"type": "NumericalTarget",
"name": "Yield",
"mode": "MAX"
}
],
"constraints": [],
"objective": {
"mode": "SINGLE",
"targets": [
{"name": "Yield", "mode": "MAX"}
]
]
},
"recommender": {
"type": "TwoPhaseMetaRecommender",
Expand Down Expand Up @@ -716,6 +723,51 @@ def fixture_default_config():
return cfg


@pytest.fixture(name="simplex_config")
def fixture_default_simplex_config():
    """Provide the default simplex campaign config as a JSON string.

    The config describes a discrete subspace built via ``from_simplex`` with
    two numerical simplex parameters and one categorical product parameter,
    plus a single-target objective.
    """
    return """{
"searchspace": {
"discrete": {
"constructor": "from_simplex",
"simplex_parameters": [
{
"type": "NumericalDiscreteParameter",
"name": "simplex1",
"values": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
},
{
"type": "NumericalDiscreteParameter",
"name": "simplex2",
"values": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
}
],
"product_parameters": [
{
"type": "CategoricalParameter",
"name": "Granularity",
"values": ["coarse", "medium", "fine"]
}
],
"max_sum": 1.0,
"boundary_only": true
}
},
"objective": {
"mode": "SINGLE",
"targets": [
{
"type": "NumericalTarget",
"name": "Yield",
"mode": "MAX"
}
]
}
}"""


@pytest.fixture(name="onnx_str")
def fixture_default_onnx_str() -> Union[bytes, None]:
"""The default ONNX model string to be used if not specified differently."""
Expand Down
14 changes: 12 additions & 2 deletions tests/serialization/test_campaign_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,21 @@ def test_campaign_serialization(campaign):
assert campaign == campaign2


def test_valid_config(config):
def test_valid_product_config(config):
Campaign.validate_config(config)


def test_invalid_config(config):
def test_invalid_product_config(config):
config = config.replace("CategoricalParameter", "CatParam")
with pytest.raises(ClassValidationError):
Campaign.validate_config(config)


def test_valid_simplex_config(simplex_config):
    """The unmodified simplex config is expected to validate without raising."""
    Campaign.validate_config(simplex_config)


def test_invalid_simplex_config(simplex_config):
    """A simplex config containing a negative parameter value must be rejected."""
    # Prepend a negative entry to every value list that starts with 0.0
    broken_config = simplex_config.replace("0.0, ", "-1.0, 0.0, ")

    with pytest.raises(ClassValidationError):
        Campaign.validate_config(broken_config)

0 comments on commit 16c07e0

Please sign in to comment.