Commit 2f5f851: Introduce ScalerProtocol class

AdrianSosic committed Jul 19, 2024 (1 parent: 59eed75)

Showing 3 changed files with 28 additions and 37 deletions.
baybe/surrogates/base.py: 16 changes (10 additions, 6 deletions)
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar

 import pandas as pd
+from typing import Literal
 from attrs import define, field
 from cattrs import override
 from cattrs.dispatch import (
@@ -16,6 +17,7 @@
     UnstructuredValue,
     UnstructureHook,
 )
+from sklearn.preprocessing import MinMaxScaler

 from baybe.exceptions import ModelNotTrainedError
 from baybe.objectives.base import Objective
@@ -28,7 +30,7 @@
 )
 from baybe.serialization.mixin import SerialMixin
 from baybe.utils.dataframe import to_tensor
-from baybe.utils.scaling import ScalingMethod, make_scaler
+from baybe.utils.scaling import ScalerProtocol

 if TYPE_CHECKING:
     from botorch.models.model import Model
@@ -86,9 +88,11 @@ def to_botorch(self) -> Model:
         return AdapterModel(self)

     @staticmethod
-    def _get_parameter_scaling(parameter: Parameter) -> ScalingMethod:
+    def _get_parameter_scaler(
+        parameter: Parameter,
+    ) -> ScalerProtocol | Literal["passthrough"]:
         """Return the scaling method to be used for the given parameter."""
-        return ScalingMethod.MINMAX
+        return MinMaxScaler()

     def _make_input_scaler(
         self, searchspace: SearchSpace, measurements: pd.DataFrame
@@ -100,7 +104,7 @@ def _make_input_scaler(
         # TODO: Filter down to columns that actually remain in the comp rep of the
         #   searchspace, since the transformer can break down otherwise.
         transformers = [
-            (make_scaler(self._get_parameter_scaling(p)), p.comp_df.columns)
+            (self._get_parameter_scaler(p), p.comp_df.columns)
             for p in searchspace.parameters
         ]
         scaler = make_column_transformer(*transformers)
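For illustration, here is a minimal standalone sketch (not baybe code) of the assembly pattern this hunk feeds: one (scaler, columns) pair per parameter, combined via sklearn's make_column_transformer. The dataframe and column names are hypothetical.

import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import MinMaxScaler

# Hypothetical measurements, one column per "parameter".
measurements = pd.DataFrame(
    {"temperature": [10.0, 20.0, 40.0], "pressure": [1.0, 2.0, 5.0]}
)

# One (scaler, columns) entry per parameter, mirroring the list comprehension above.
transformers = [
    (MinMaxScaler(), ["temperature"]),
    (MinMaxScaler(), ["pressure"]),
]
scaler = make_column_transformer(*transformers)
scaler.fit(measurements)
print(scaler.transform(measurements))  # each column min-max scaled to [0, 1]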
@@ -225,7 +229,7 @@ def _estimate_moments(self, candidates: Tensor) -> tuple[Tensor, Tensor]:


 def _make_hook_decode_onnx_str(
-    raw_unstructure_hook: UnstructureHook
+    raw_unstructure_hook: UnstructureHook,
 ) -> UnstructureHook:
     """Wrap an unstructuring hook to let it also decode the contained ONNX string."""

@@ -253,7 +257,7 @@ def wrapper(dct: UnstructuredValue, _: TargetType) -> StructuredValue:


 def _block_serialize_custom_architecture(
-    raw_unstructure_hook: UnstructureHook
+    raw_unstructure_hook: UnstructureHook,
 ) -> UnstructureHook:
     """Raise error if attempt to serialize a custom architecture surrogate."""
     # TODO: Ideally, this hook should be removed and unstructuring the Surrogate
baybe/surrogates/gaussian_process/core.py: 10 changes (6 additions, 4 deletions)
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, ClassVar
+from typing import TYPE_CHECKING, ClassVar, Literal

 from attrs import define, field
 from attrs.validators import instance_of
@@ -23,7 +23,7 @@
     DefaultKernelFactory,
     _default_noise_factory,
 )
-from baybe.utils.scaling import ScalingMethod
+from baybe.utils.scaling import ScalerProtocol

 if TYPE_CHECKING:
     from botorch.models.model import Model
@@ -111,11 +111,13 @@ def to_botorch(self) -> Model:  # noqa: D102
         return self._model

     @staticmethod
-    def _get_parameter_scaling(parameter: Parameter) -> ScalingMethod:
+    def _get_parameter_scaler(
+        parameter: Parameter,
+    ) -> ScalerProtocol | Literal["passthrough"]:
         # See base class.

         # For GPs, we use botorch's built-in machinery for scaling.
-        return ScalingMethod.IDENTITY
+        return "passthrough"

     @staticmethod
     def _get_model_context(
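For context, "passthrough" is sklearn's built-in way of declaring a no-op transformer inside a column transformer, so the GP's input columns reach botorch unscaled. A small sketch (not baybe code) with a hypothetical column:

import pandas as pd
from sklearn.compose import make_column_transformer

df = pd.DataFrame({"x": [1.0, 2.0, 4.0]})

# "passthrough" forwards the selected columns unchanged.
ct = make_column_transformer(("passthrough", ["x"]))
print(ct.fit_transform(df))  # [[1.], [2.], [4.]] -- values untouched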
baybe/utils/scaling.py: 39 changes (12 additions, 27 deletions)
@@ -2,36 +2,21 @@

 from __future__ import annotations

-from enum import Enum
-from typing import TYPE_CHECKING, Literal, TypeAlias
-
-if TYPE_CHECKING:
-    from sklearn.base import BaseEstimator, TransformerMixin
-
-    Scaler: TypeAlias = BaseEstimator | TransformerMixin
-
-
-class ScalingMethod(Enum):
-    """Available scaling methods."""
-
-    IDENTITY = "IDENTITY"
-    """Identity transformation (no scaling applied)."""
-
-    MINMAX = "MINMAX"
-    """Min-max scaling, mapping the observed value range to [0, 1]."""
-
-    MAXABS = "MAXABS"
-    """Max-abs scaling, normalizing by the largest observed absolute value."""
-
-
-def make_scaler(method: ScalingMethod, /) -> Scaler | Literal["passthrough"]:
-    """Create a scaler object based on the specified method."""
-    from sklearn.preprocessing import MaxAbsScaler, MinMaxScaler
-
-    match method:
-        case ScalingMethod.IDENTITY:
-            return "passthrough"
-        case ScalingMethod.MINMAX:
-            return MinMaxScaler()
-        case ScalingMethod.MAXABS:
-            return MaxAbsScaler()
+from typing import Protocol
+
+import pandas as pd
+
+
+class ScalerProtocol(Protocol):
+    """Type protocol specifying the interface scalers need to implement.
+
+    The protocol is compatible with sklearn scalers such as
+    :class:`sklearn.preprocessing.MinMaxScaler` or
+    :class:`sklearn.preprocessing.MaxAbsScaler`.
+    """
+
+    def fit(self, df: pd.DataFrame, /) -> None:
+        """Fit the scaler to a given data set."""
+
+    def transform(self, df: pd.DataFrame, /) -> pd.DataFrame:
+        """Transform a data set using the fitted scaling logic."""
