Skip to content

Commit

Permalink
Replace literal return type with None
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Jul 23, 2024
1 parent e7f3f67 commit 8360c67
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar

import pandas as pd
from attrs import define, field
Expand Down Expand Up @@ -110,7 +110,7 @@ def to_botorch(self) -> Model:
@staticmethod
def _make_parameter_scaler(
parameter: Parameter,
) -> ParameterScalerProtocol | Literal["passthrough"]:
) -> ParameterScalerProtocol | None:
"""Return the scaler to be used for the given parameter."""
return MinMaxScaler()

Expand All @@ -129,7 +129,7 @@ def _make_input_scaler(self, searchspace: SearchSpace) -> ColumnTransformer:
# Create the composite scaler from the parameter-wise scaler objects
transformers = [
(
self._make_parameter_scaler(p),
"passthrough" if (s := self._make_parameter_scaler(p) is None) else s,
[c for c in p.comp_rep_columns if c in searchspace.comp_rep_columns],
)
for p in searchspace.parameters
Expand Down
6 changes: 3 additions & 3 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field
from attrs.validators import instance_of
Expand Down Expand Up @@ -115,12 +115,12 @@ def to_botorch(self) -> Model: # noqa: D102
@staticmethod
def _make_parameter_scaler(
parameter: Parameter,
) -> ParameterScalerProtocol | Literal["passthrough"]:
) -> ParameterScalerProtocol | None:
# See base class.

# Task parameters are handled separately through an index kernel
if isinstance(parameter, TaskParameter):
return "passthrough"
return

return MinMaxScaler()

Expand Down
6 changes: 3 additions & 3 deletions baybe/surrogates/ngboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar

from attr import define, field
from ngboost import NGBRegressor
Expand Down Expand Up @@ -56,11 +56,11 @@ def __attrs_post_init__(self):
@staticmethod
def _make_parameter_scaler(
parameter: Parameter,
) -> ParameterScalerProtocol | Literal["passthrough"]:
) -> ParameterScalerProtocol | None:
# See base class.

# Tree-like models do not require any input scaling
return "passthrough"
return

@batchify
def _estimate_moments(self, candidates: Tensor, /) -> tuple[Tensor, Tensor]:
Expand Down
6 changes: 3 additions & 3 deletions baybe/surrogates/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar

import numpy as np
from attr import define, field
Expand Down Expand Up @@ -51,11 +51,11 @@ class RandomForestSurrogate(GaussianSurrogate):
@staticmethod
def _make_parameter_scaler(
parameter: Parameter,
) -> ParameterScalerProtocol | Literal["passthrough"]:
) -> ParameterScalerProtocol | None:
# See base class.

# Tree-like models do not require any input scaling
return "passthrough"
return

@batchify
def _estimate_moments(self, candidates: Tensor, /) -> tuple[Tensor, Tensor]:
Expand Down

0 comments on commit 8360c67

Please sign in to comment.