-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor parameter tolerance handling #114
Changes from all commits
6e8c7f4
a089087
57f90cd
411465e
69aeac6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
from __future__ import annotations | ||
|
||
import json | ||
from typing import List | ||
from typing import List, Literal | ||
|
||
import cattrs | ||
import numpy as np | ||
|
@@ -159,7 +159,7 @@ def validate_config(cls, config_json: str) -> None: | |
def add_measurements( | ||
self, | ||
data: pd.DataFrame, | ||
numerical_measurements_must_be_within_tolerance: bool = True, | ||
on_tolerance_violation: Literal["raise", "warn", "ignore"] = "raise", | ||
) -> None: | ||
"""Add results from a dataframe to the internal database. | ||
|
||
|
@@ -172,8 +172,12 @@ def add_measurements( | |
Args: | ||
data: The data to be added (with filled values for targets). Preferably | ||
created via :func:`baybe.campaign.Campaign.recommend`. | ||
numerical_measurements_must_be_within_tolerance: Flag indicating if | ||
numerical parameters need to be within their tolerances. | ||
on_tolerance_violation: The mode determining how to handle the attempt | ||
of adding numerical data that violates parameter tolerances. Unless | ||
set to ``raise``, the measurements will be added to the database | ||
despite potential violations. However, note that values lying | ||
significantly outside the convex hull of numerical parameters can | ||
lead to scaling problems in model training. | ||
|
||
Raises: | ||
ValueError: If one of the targets has missing values or NaNs in the provided | ||
|
@@ -211,9 +215,7 @@ def add_measurements( | |
|
||
# Update meta data | ||
# TODO: refactor responsibilities | ||
self.searchspace.discrete.mark_as_measured( | ||
data, numerical_measurements_must_be_within_tolerance | ||
) | ||
self.searchspace.discrete.mark_as_measured(data, on_tolerance_violation) | ||
|
||
# Read in measurements and add them to the database | ||
self.n_batches_done += 1 | ||
|
@@ -226,12 +228,13 @@ def add_measurements( | |
) | ||
|
||
# Telemetry | ||
# TODO: Code is inefficient because of unnecessary second fuzzy matching | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which unnecessary second fuzzy matching are you talking about here? Is this something happening within one of the functions? |
||
telemetry_record_value(TELEM_LABELS["COUNT_ADD_RESULTS"], 1) | ||
telemetry_record_recommended_measurement_percentage( | ||
self._cached_recommendation, | ||
data, | ||
self.parameters, | ||
numerical_measurements_must_be_within_tolerance, | ||
on_tolerance_violation, | ||
) | ||
|
||
def recommend( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,12 +7,11 @@ | |
import numpy as np | ||
import pandas as pd | ||
from attrs import define, field | ||
from attrs.validators import min_len | ||
from attrs.validators import ge, min_len | ||
|
||
from baybe.exceptions import NumericalUnderflowError | ||
from baybe.parameters.base import DiscreteParameter, Parameter | ||
from baybe.parameters.validation import validate_is_finite, validate_unique_values | ||
from baybe.utils import DTypeFloatNumpy, InfiniteIntervalError, Interval, convert_bounds | ||
from baybe.utils import InfiniteIntervalError, Interval, convert_bounds | ||
|
||
|
||
@define(frozen=True, slots=False) | ||
|
@@ -24,7 +23,7 @@ class NumericalDiscreteParameter(DiscreteParameter): | |
# See base class. | ||
|
||
# object variables | ||
# NOTE: The parameter values are assumed to be sorted by the tolerance validator. | ||
# NOTE: The values are assumed to be sorted by the tolerance default method. | ||
_values: Tuple[float, ...] = field( | ||
# FIXME[typing]: https://github.com/python-attrs/cattrs/issues/111 | ||
converter=lambda x: sorted(cattrs.structure(x, Tuple[float, ...])), # type: ignore | ||
|
@@ -37,43 +36,16 @@ class NumericalDiscreteParameter(DiscreteParameter): | |
) | ||
"""The values the parameter can take.""" | ||
|
||
tolerance: float = field(default=0.0) | ||
"""The absolute tolerance used for deciding whether a value is in range. A tolerance | ||
larger than half the minimum distance between parameter values is not allowed | ||
because that could cause ambiguity when inputting data points later.""" | ||
tolerance: float = field(validator=ge(0.0)) | ||
"""The absolute tolerance used for deciding whether a value is considered in range. | ||
A value is considered in range if its distance to the closest parameter value | ||
is smaller than the specified tolerance.""" | ||
|
||
@tolerance.validator | ||
def _validate_tolerance( # noqa: DOC101, DOC103 | ||
self, _: Any, tolerance: float | ||
) -> None: | ||
"""Validate that the given tolerance is safe. | ||
|
||
The tolerance is the allowed experimental uncertainty when | ||
reading in measured values. A tolerance larger than half the minimum | ||
distance between parameter values is not allowed because that could cause | ||
ambiguity when inputting data points later. | ||
|
||
Raises: | ||
ValueError: If the tolerance is not safe. | ||
""" | ||
# For zero tolerance, the only left requirement is that all parameter values | ||
# are distinct, which is already ensured by the corresponding validator. | ||
if tolerance == 0.0: | ||
return | ||
|
||
min_dist = np.diff(self.values).min() | ||
if min_dist == (eps := np.nextafter(0, 1, dtype=DTypeFloatNumpy)): | ||
raise NumericalUnderflowError( | ||
f"The distance between any two parameter values must be at least " | ||
f"twice the size of the used floating point resolution of {eps}." | ||
) | ||
|
||
if tolerance >= (max_tol := min_dist / 2.0): | ||
raise ValueError( | ||
f"Parameter '{self.name}' is initialized with tolerance {tolerance} " | ||
f"but due to the given parameter values {self.values}, the specified " | ||
f"tolerance must be smaller than {max_tol} to avoid ambiguity." | ||
) | ||
@tolerance.default | ||
def default_tolerance(self) -> float: | ||
"""Set the tolerance to fraction of the smallest value distance.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This reads as if the |
||
fraction = 0.1 | ||
return fraction * np.diff(self.values).min().item() | ||
|
||
@property | ||
def values(self) -> tuple: # noqa: D102 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
from __future__ import annotations | ||
|
||
import logging | ||
import warnings | ||
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
|
@@ -321,7 +322,7 @@ def fuzzy_row_match( | |
left_df: pd.DataFrame, | ||
right_df: pd.DataFrame, | ||
parameters: List[Parameter], | ||
numerical_measurements_must_be_within_tolerance: bool, | ||
on_tolerance_violation: Literal["raise", "warn", "ignore"], | ||
) -> pd.Index: | ||
"""Match row of the right dataframe to the rows of the left dataframe. | ||
|
||
|
@@ -337,10 +338,8 @@ def fuzzy_row_match( | |
dataframe. | ||
parameters: List of baybe parameter objects that are needed to identify | ||
potential tolerances. | ||
numerical_measurements_must_be_within_tolerance: If ``True``, numerical | ||
parameters are matched with the search space elements only if there is a | ||
match within the parameter tolerance. If ``False``, the closest match is | ||
considered, irrespective of the distance. | ||
on_tolerance_violation: The mode determining what how to handle a missing | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either add some words about the different modes here or refer to the place where you described them. |
||
match due to parameter tolerance violation. | ||
|
||
Returns: | ||
The index of the matching rows in ``left_df``. | ||
|
@@ -349,6 +348,13 @@ def fuzzy_row_match( | |
ValueError: If some rows are present in the right but not in the left dataframe. | ||
ValueError: If the input data has invalid values. | ||
""" | ||
# Assert that the passed violation mode is valid | ||
if on_tolerance_violation not in ["raise", "warn", "ignore"]: | ||
raise ValueError( | ||
"""Argument passed to `on_tolerance_violation` must be one """ | ||
"""of '["raise", "warn", "ignore"]'.""" | ||
) | ||
|
||
# Assert that all parameters appear in the given dataframe | ||
if not all(col in right_df.columns for col in left_df.columns): | ||
raise ValueError( | ||
|
@@ -360,25 +366,31 @@ def fuzzy_row_match( | |
|
||
# Iterate over all input rows | ||
for ind, row in right_df.iterrows(): | ||
# Check if the row represents a valid input | ||
valid = True | ||
# Check if all values of the row are in the respective parameter ranges | ||
for param in parameters: | ||
if param.is_numeric: | ||
if numerical_measurements_must_be_within_tolerance: | ||
valid &= param.is_in_range(row[param.name]) | ||
else: | ||
valid &= param.is_in_range(row[param.name]) | ||
if not valid: | ||
raise ValueError( | ||
f"Input data on row with the index {row.name} has invalid " | ||
f"values in parameter '{param.name}'. " | ||
f"For categorical parameters, values need to exactly match a " | ||
f"valid choice defined in your config. " | ||
f"For numerical parameters, a match is accepted only if " | ||
f"the input value is within the specified tolerance/range. Set " | ||
f"the flag 'numerical_measurements_must_be_within_tolerance' " | ||
f"to 'False' to disable this behavior." | ||
) | ||
if not param.is_in_range((val := row[param.name])): | ||
if param.is_numeric and on_tolerance_violation == "ignore": | ||
break | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldnt this be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also think that this needs to be |
||
if param.is_numeric and on_tolerance_violation == "warn": | ||
warnings.warn( | ||
f"The value '{val}' is outside the range of parameter " | ||
f"'{param.name}'. " | ||
f"If you expected a match between your input " | ||
f"and the parameter, consider increasing the parameter's " | ||
f"tolerance value or adding more parameter values. " | ||
f"You can silence this warning using the 'ignore' mode.", | ||
UserWarning, | ||
) | ||
break | ||
else: | ||
raise ValueError( | ||
f"The value '{val}' is outside the range of parameter " | ||
f"'{param.name}'. " | ||
f"If you expected a match between your input " | ||
f"and the parameter, consider increasing the parameter's " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this error is shown for both numerical and non-numerical parameters, hence should make the possible distinct situations clearer. Eg a categorical parameter has no tolerance |
||
f"tolerance value or adding more parameter values. " | ||
f"You can bypass this check using the 'ignore' or 'warn' mode." | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i cant comment on the line, but we shouldnt forget about whats below here in line 411+
|
||
|
||
# Differentiate category-like and discrete numerical parameters | ||
cat_cols = [p.name for p in parameters if not p.is_numeric] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since these violations are not an issue for scaling, I would go with
warn
as default, notraise
. This is also needed in the API I think.The user can explicitly still disallow violations but setting
raise
which would be similar to explicitly setting the oldnumerical_measurements_must_be_within_tolerance
toTrue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree.