Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor parameter tolerance handling #114

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions baybe/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import json
from typing import List
from typing import List, Literal

import cattrs
import numpy as np
Expand Down Expand Up @@ -159,7 +159,7 @@ def validate_config(cls, config_json: str) -> None:
def add_measurements(
self,
data: pd.DataFrame,
numerical_measurements_must_be_within_tolerance: bool = True,
on_tolerance_violation: Literal["raise", "warn", "ignore"] = "raise",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since these violations are not an issue for scaling, I would go with warn as default, not raise. This is also needed in the API I think.

The user can explicitly still disallow violations but setting raise which would be similar to explicitly setting the old numerical_measurements_must_be_within_tolerance to True

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree.

) -> None:
"""Add results from a dataframe to the internal database.

Expand All @@ -172,8 +172,12 @@ def add_measurements(
Args:
data: The data to be added (with filled values for targets). Preferably
created via :func:`baybe.campaign.Campaign.recommend`.
numerical_measurements_must_be_within_tolerance: Flag indicating if
numerical parameters need to be within their tolerances.
on_tolerance_violation: The mode determining how to handle the attempt
of adding numerical data that violates parameter tolerances. Unless
set to ``raise``, the measurements will be added to the database
despite potential violations. However, note that values lying
significantly outside the convex hull of numerical parameters can
lead to scaling problems in model training.

Raises:
ValueError: If one of the targets has missing values or NaNs in the provided
Expand Down Expand Up @@ -211,9 +215,7 @@ def add_measurements(

# Update meta data
# TODO: refactor responsibilities
self.searchspace.discrete.mark_as_measured(
data, numerical_measurements_must_be_within_tolerance
)
self.searchspace.discrete.mark_as_measured(data, on_tolerance_violation)

# Read in measurements and add them to the database
self.n_batches_done += 1
Expand All @@ -226,12 +228,13 @@ def add_measurements(
)

# Telemetry
# TODO: Code is inefficient because of unnecessary second fuzzy matching
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which unnecessary second fuzzy matching are you talking about here? Is this something happening within one of the functions?

telemetry_record_value(TELEM_LABELS["COUNT_ADD_RESULTS"], 1)
telemetry_record_recommended_measurement_percentage(
self._cached_recommendation,
data,
self.parameters,
numerical_measurements_must_be_within_tolerance,
on_tolerance_violation,
)

def recommend(
Expand Down
52 changes: 12 additions & 40 deletions baybe/parameters/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import numpy as np
import pandas as pd
from attrs import define, field
from attrs.validators import min_len
from attrs.validators import ge, min_len

from baybe.exceptions import NumericalUnderflowError
from baybe.parameters.base import DiscreteParameter, Parameter
from baybe.parameters.validation import validate_is_finite, validate_unique_values
from baybe.utils import DTypeFloatNumpy, InfiniteIntervalError, Interval, convert_bounds
from baybe.utils import InfiniteIntervalError, Interval, convert_bounds


@define(frozen=True, slots=False)
Expand All @@ -24,7 +23,7 @@ class NumericalDiscreteParameter(DiscreteParameter):
# See base class.

# object variables
# NOTE: The parameter values are assumed to be sorted by the tolerance validator.
# NOTE: The values are assumed to be sorted by the tolerance default method.
_values: Tuple[float, ...] = field(
# FIXME[typing]: https://github.com/python-attrs/cattrs/issues/111
converter=lambda x: sorted(cattrs.structure(x, Tuple[float, ...])), # type: ignore
Expand All @@ -37,43 +36,16 @@ class NumericalDiscreteParameter(DiscreteParameter):
)
"""The values the parameter can take."""

tolerance: float = field(default=0.0)
"""The absolute tolerance used for deciding whether a value is in range. A tolerance
larger than half the minimum distance between parameter values is not allowed
because that could cause ambiguity when inputting data points later."""
tolerance: float = field(validator=ge(0.0))
"""The absolute tolerance used for deciding whether a value is considered in range.
A value is considered in range if its distance to the closest parameter value
is smaller than the specified tolerance."""

@tolerance.validator
def _validate_tolerance( # noqa: DOC101, DOC103
self, _: Any, tolerance: float
) -> None:
"""Validate that the given tolerance is safe.

The tolerance is the allowed experimental uncertainty when
reading in measured values. A tolerance larger than half the minimum
distance between parameter values is not allowed because that could cause
ambiguity when inputting data points later.

Raises:
ValueError: If the tolerance is not safe.
"""
# For zero tolerance, the only left requirement is that all parameter values
# are distinct, which is already ensured by the corresponding validator.
if tolerance == 0.0:
return

min_dist = np.diff(self.values).min()
if min_dist == (eps := np.nextafter(0, 1, dtype=DTypeFloatNumpy)):
raise NumericalUnderflowError(
f"The distance between any two parameter values must be at least "
f"twice the size of the used floating point resolution of {eps}."
)

if tolerance >= (max_tol := min_dist / 2.0):
raise ValueError(
f"Parameter '{self.name}' is initialized with tolerance {tolerance} "
f"but due to the given parameter values {self.values}, the specified "
f"tolerance must be smaller than {max_tol} to avoid ambiguity."
)
@tolerance.default
def default_tolerance(self) -> float:
"""Set the tolerance to fraction of the smallest value distance."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reads as if the fraction can be chosen by the user. Also, we should add the information in the docstring that this fraction is set to 0.1

fraction = 0.1
return fraction * np.diff(self.values).min().item()

@property
def values(self) -> tuple: # noqa: D102
Expand Down
9 changes: 4 additions & 5 deletions baybe/searchspace/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Collection, Iterable, List, Optional, Tuple, cast
from typing import Any, Collection, Iterable, List, Literal, Optional, Tuple, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -252,21 +252,20 @@ def param_bounds_comp(self) -> torch.Tensor:
def mark_as_measured(
self,
measurements: pd.DataFrame,
numerical_measurements_must_be_within_tolerance: bool,
on_tolerance_violation: Literal["raise", "warn", "ignore"],
) -> None:
"""Mark the given elements of the space as measured.

Args:
measurements: A dataframe containing parameter settings that should be
marked as measured.
numerical_measurements_must_be_within_tolerance: See
:func:`baybe.utils.dataframe.fuzzy_row_match`.
on_tolerance_violation: See :func:`baybe.utils.dataframe.fuzzy_row_match`.
"""
inds_matched = fuzzy_row_match(
self.exp_rep,
measurements,
self.parameters,
numerical_measurements_must_be_within_tolerance,
on_tolerance_violation,
)
self.metadata.loc[inds_matched, "was_measured"] = True

Expand Down
11 changes: 4 additions & 7 deletions baybe/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import logging
import os
import socket
from typing import TYPE_CHECKING, Dict, List, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Union
from urllib.parse import urlparse

import pandas as pd
Expand Down Expand Up @@ -275,7 +275,7 @@ def telemetry_record_recommended_measurement_percentage(
cached_recommendation: pd.DataFrame,
measurements: pd.DataFrame,
parameters: List[Parameter],
numerical_measurements_must_be_within_tolerance: bool,
on_tolerance_violation: Literal["raise", "warn", "ignore"] = "raise",
) -> None:
"""Submit the percentage of added measurements.

Expand All @@ -293,10 +293,7 @@ def telemetry_record_recommended_measurement_percentage(
measurements: The measurements which are supposed to be checked against cached
recommendations.
parameters: The list of parameters spanning the entire search space.
numerical_measurements_must_be_within_tolerance: If ``True``, numerical
parameter entries are matched with the reference elements only if there is
a match within the parameter tolerance. If ``False``, the closest match
is considered, irrespective of the distance.
on_tolerance_violation: See :func:`baybe.utils.dataframe.fuzzy_row_match`.
"""
if is_enabled():
if len(cached_recommendation) > 0:
Expand All @@ -306,7 +303,7 @@ def telemetry_record_recommended_measurement_percentage(
cached_recommendation,
measurements,
parameters,
numerical_measurements_must_be_within_tolerance,
on_tolerance_violation,
)
)
/ len(cached_recommendation)
Expand Down
58 changes: 35 additions & 23 deletions baybe/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -321,7 +322,7 @@ def fuzzy_row_match(
left_df: pd.DataFrame,
right_df: pd.DataFrame,
parameters: List[Parameter],
numerical_measurements_must_be_within_tolerance: bool,
on_tolerance_violation: Literal["raise", "warn", "ignore"],
) -> pd.Index:
"""Match row of the right dataframe to the rows of the left dataframe.

Expand All @@ -337,10 +338,8 @@ def fuzzy_row_match(
dataframe.
parameters: List of baybe parameter objects that are needed to identify
potential tolerances.
numerical_measurements_must_be_within_tolerance: If ``True``, numerical
parameters are matched with the search space elements only if there is a
match within the parameter tolerance. If ``False``, the closest match is
considered, irrespective of the distance.
on_tolerance_violation: The mode determining what how to handle a missing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either add some words about the different modes here or refer to the place where you described them.

match due to parameter tolerance violation.

Returns:
The index of the matching rows in ``left_df``.
Expand All @@ -349,6 +348,13 @@ def fuzzy_row_match(
ValueError: If some rows are present in the right but not in the left dataframe.
ValueError: If the input data has invalid values.
"""
# Assert that the passed violation mode is valid
if on_tolerance_violation not in ["raise", "warn", "ignore"]:
raise ValueError(
"""Argument passed to `on_tolerance_violation` must be one """
"""of '["raise", "warn", "ignore"]'."""
)

# Assert that all parameters appear in the given dataframe
if not all(col in right_df.columns for col in left_df.columns):
raise ValueError(
Expand All @@ -360,25 +366,31 @@ def fuzzy_row_match(

# Iterate over all input rows
for ind, row in right_df.iterrows():
# Check if the row represents a valid input
valid = True
# Check if all values of the row are in the respective parameter ranges
for param in parameters:
if param.is_numeric:
if numerical_measurements_must_be_within_tolerance:
valid &= param.is_in_range(row[param.name])
else:
valid &= param.is_in_range(row[param.name])
if not valid:
raise ValueError(
f"Input data on row with the index {row.name} has invalid "
f"values in parameter '{param.name}'. "
f"For categorical parameters, values need to exactly match a "
f"valid choice defined in your config. "
f"For numerical parameters, a match is accepted only if "
f"the input value is within the specified tolerance/range. Set "
f"the flag 'numerical_measurements_must_be_within_tolerance' "
f"to 'False' to disable this behavior."
)
if not param.is_in_range((val := row[param.name])):
if param.is_numeric and on_tolerance_violation == "ignore":
break
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldnt this be a continue? still need to checkt he other parameters in the list

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also think that this needs to be continue

if param.is_numeric and on_tolerance_violation == "warn":
warnings.warn(
f"The value '{val}' is outside the range of parameter "
f"'{param.name}'. "
f"If you expected a match between your input "
f"and the parameter, consider increasing the parameter's "
f"tolerance value or adding more parameter values. "
f"You can silence this warning using the 'ignore' mode.",
UserWarning,
)
break
else:
raise ValueError(
f"The value '{val}' is outside the range of parameter "
f"'{param.name}'. "
f"If you expected a match between your input "
f"and the parameter, consider increasing the parameter's "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this error is shown for both numerical and non-numerical parameters, hence should make the possible distinct situations clearer. Eg a categorical parameter has no tolerance

f"tolerance value or adding more parameter values. "
f"You can bypass this check using the 'ignore' or 'warn' mode."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i cant comment on the line, but we shouldnt forget about whats below here in line 411+

  1. it still uses old logger mechanism
  2. the warnings there should also be controlled by on_tolerance_validation. eg no need to print a no match found warning if its set to ignore


# Differentiate category-like and discrete numerical parameters
cat_cols = [p.name for p in parameters if not p.is_numeric]
Expand Down
14 changes: 1 addition & 13 deletions tests/hypothesis_strategies/parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Hypothesis strategies for parameters."""

import hypothesis.strategies as st
import numpy as np
from hypothesis.extra.pandas import columns, data_frames

from baybe.parameters.categorical import (
Expand Down Expand Up @@ -88,18 +87,7 @@ def numerical_discrete_parameter(
unique=True,
)
)
max_tolerance = np.diff(np.sort(values)).min() / 2
if max_tolerance == 0.0:
tolerance = 0.0
else:
tolerance = draw(
st.floats(
min_value=0.0,
max_value=max_tolerance,
allow_nan=False,
exclude_max=True,
)
)
tolerance = draw(st.floats(min_value=0.0))
return NumericalDiscreteParameter(name=name, values=values, tolerance=tolerance)


Expand Down
Loading