From 7de89c8f1745ec6042868eca705e2388f5b5d768 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 11 Oct 2022 21:05:26 -0700 Subject: [PATCH] Add type-hints to adaptive/learner/skopt_learner.py --- adaptive/learner/skopt_learner.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/adaptive/learner/skopt_learner.py b/adaptive/learner/skopt_learner.py index e12f49daa..1c3c18fd7 100644 --- a/adaptive/learner/skopt_learner.py +++ b/adaptive/learner/skopt_learner.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +from typing import Callable import numpy as np from skopt import Optimizer @@ -25,8 +26,8 @@ class SKOptLearner(Optimizer, BaseLearner): Arguments to pass to ``skopt.Optimizer``. """ - def __init__(self, function, **kwargs): - self.function = function + def __init__(self, function: Callable, **kwargs) -> None: + self.function = function # type: ignore self.pending_points = set() self.data = collections.OrderedDict() self._kwargs = kwargs @@ -36,7 +37,7 @@ def new(self) -> SKOptLearner: """Return a new `~adaptive.SKOptLearner` without the data.""" return SKOptLearner(self.function, **self._kwargs) - def tell(self, x, y, fit=True): + def tell(self, x: float | list[float], y: float, fit: bool = True) -> None: if isinstance(x, collections.abc.Iterable): self.pending_points.discard(tuple(x)) self.data[tuple(x)] = y @@ -55,7 +56,7 @@ def remove_unfinished(self): pass @cache_latest - def loss(self, real=True): + def loss(self, real: bool = True) -> float: if not self.models: return np.inf else: @@ -65,7 +66,12 @@ def loss(self, real=True): # estimator of loss, but it is the cheapest. return 1 - model.score(self.Xi, self.yi) - def ask(self, n, tell_pending=True): + def ask( + self, n: int, tell_pending: bool = True + ) -> ( + tuple[list[float], list[float]] + | tuple[list[list[float]], list[float]] # XXX: this indicates a bug! + ): if not tell_pending: raise NotImplementedError( "Asking points is an irreversible " @@ -79,7 +85,7 @@ def ask(self, n, tell_pending=True): return [p[0] for p in points], [self.loss() / n] * n @property - def npoints(self): + def npoints(self) -> int: """Number of evaluated points.""" return len(self.Xi)