Skip to content

Commit

Permalink
TYP: extend typing to curves (#554)
Browse files Browse the repository at this point in the history
  • Loading branch information
attack68 authored Dec 14, 2024
1 parent ddafac0 commit e7bf3dd
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 60 deletions.
21 changes: 10 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,16 @@ ignore = [
"python/tests/*" = ["F401", "B", "N", "S", "ANN"]

[tool.mypy]
files = [
"python/rateslib/calendars/**/*.py",
"python/rateslib/curves/**/*.py",
"python/rateslib/dual/**/*.py",
"python/rateslib/fx/**/*.py",
"python/rateslib/__init__.py",
"python/rateslib/_spec_loader.py",
"python/rateslib/default.py",
"python/rateslib/json.py",
"python/rateslib/scheduling.py",
"python/rateslib/splines.py",
packages = [
"rateslib"
]
exclude = [
"python/rateslib/instruments",
"python/rateslib/fx_volatility.py",
"python/rateslib/legs.py",
"python/rateslib/periods.py",
"python/rateslib/solver.py",
# "python/rateslib/rs.pyi",
]
strict = true

Expand Down
48 changes: 28 additions & 20 deletions python/rateslib/curves/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from rateslib import defaults
from rateslib.calendars import CalInput, add_tenor, dcf
from rateslib.calendars.dcfs import _DCF1d
from rateslib.calendars.rs import CalTypes, Modifier, get_calendar
from rateslib.calendars.rs import CalTypes, get_calendar
from rateslib.default import NoInput, PlotOutput, _drb, plot
from rateslib.dual import (
from rateslib.dual import ( # type: ignore[attr-defined]
Arr1dF64,
Arr1dObj,
Dual,
Expand All @@ -35,8 +35,8 @@
dual_log,
set_order_convert,
)
from rateslib.rs import Modifier, index_left_f64
from rateslib.rs import from_json as from_json_rs
from rateslib.rs import index_left_f64
from rateslib.splines import PPSplineDual, PPSplineDual2, PPSplineF64

if TYPE_CHECKING:
Expand Down Expand Up @@ -171,7 +171,7 @@ class Curve:
_base_type: str = "dfs"
collateral: str | None = None

def __init__(
def __init__( # type: ignore[no-untyped-def]
self,
nodes: dict[datetime, DualTypes],
*,
Expand All @@ -186,7 +186,7 @@ def __init__(
modifier: str | NoInput = NoInput(0),
calendar: CalInput = NoInput(0),
ad: int = 0,
**kwargs, # type: ignore[no-utyped-def]
**kwargs,
) -> None:
self.clear_cache()
self.id: str = _drb(uuid4().hex[:5], id) # 1 in a million clash
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(
self.t = t
self._c_input: bool = not isinstance(c, NoInput)
if not isinstance(self.t, NoInput):
self.t_posix: list[float] = [_.replace(tzinfo=UTC).timestamp() for _ in t]
self.t_posix: list[float] | None = [_.replace(tzinfo=UTC).timestamp() for _ in self.t]
if not isinstance(c, NoInput):
self.spline: PPSplineF64 | PPSplineDual | PPSplineDual2 | None = PPSplineF64(
4, self.t_posix, c
Expand All @@ -240,6 +240,9 @@ def __init__(
self.t_posix = None
self.spline = None

self.index_base: DualTypes | NoInput = NoInput(0)
self.index_lag: int | NoInput = NoInput(0)

self._set_ad_order(order=ad)

def __eq__(self, other: Any) -> bool:
Expand Down Expand Up @@ -298,11 +301,16 @@ def to_json(self) -> str:
else:
t = [t.strftime("%Y-%m-%d") for t in self.t]

container = {
if self._c_input and self.spline is not None:
c_ = self.spline.c
else:
c_ = None

container: dict[str, Any] = {
"nodes": {dt.strftime("%Y-%m-%d"): v.real for dt, v in self.nodes.items()},
"interpolation": self.interpolation if isinstance(self.interpolation, str) else None,
"t": t,
"c": self.spline.c if self._c_input else None,
"c": c_,
"id": self.id,
"convention": self.convention,
"endpoints": self.spline_endpoints,
Expand Down Expand Up @@ -362,7 +370,7 @@ def __getitem__(self, date: datetime) -> DualTypes:
UserWarning,
)
# self.spline cannot be None becuase self.t is given and it has been calibrated
val = self._op_exp(self.spline.ppev_single(date_posix)) # type: ignore[operator]
val = self._op_exp(self.spline.ppev_single(date_posix)) # type: ignore[union-attr]

self._maybe_add_to_cache(date, val)
return val
Expand Down Expand Up @@ -556,9 +564,9 @@ def clear_cache(self) -> None:
Alternatively the curve caching as a feature can be set to *False* in ``defaults``.
"""
self._cache: dict[datetime, Number] = dict()
self._cache: dict[datetime, DualTypes] = dict()

def _maybe_add_to_cache(self, date: datetime, val: Number) -> None:
def _maybe_add_to_cache(self, date: datetime, val: DualTypes) -> None:
if defaults.curve_caching:
self._cache[date] = val

Expand All @@ -581,13 +589,14 @@ def csolve(self) -> None:
if isinstance(self.t, NoInput) or self._c_input:
return None

t_posix = self.t_posix.copy()
# attributes relating to splines will then exist
t_posix = self.t_posix.copy() # type: ignore[union-attr]
tau_posix = [k.replace(tzinfo=UTC).timestamp() for k in self.nodes if k >= self.t[0]]
y = [self._op_log(v) for k, v in self.nodes.items() if k >= self.t[0]]

# Left side constraint
if self.spline_endpoints[0].lower() == "natural":
tau_posix.insert(0, self.t_posix[0])
tau_posix.insert(0, self.t_posix[0]) # type: ignore[index]
y.insert(0, set_order_convert(0.0, self.ad, None))
left_n = 2
elif self.spline_endpoints[0].lower() == "not_a_knot":
Expand All @@ -600,7 +609,7 @@ def csolve(self) -> None:

# Right side constraint
if self.spline_endpoints[1].lower() == "natural":
tau_posix.append(self.t_posix[-1])
tau_posix.append(self.t_posix[-1]) # type: ignore[index]
y.append(set_order_convert(0, self.ad, None))
right_n = 2
elif self.spline_endpoints[1].lower() == "not_a_knot":
Expand Down Expand Up @@ -741,8 +750,8 @@ def shift(
calendar=self.calendar,
modifier=self.modifier,
interpolation="log_linear",
index_base=self.index_base, # type: ignore[attr-defined]
index_lag=self.index_lag, # type: ignore[attr-defined]
index_base=self.index_base,
index_lag=self.index_lag,
)

_: CompositeCurve = CompositeCurve(curves=[self, shifted], id=id)
Expand Down Expand Up @@ -792,7 +801,7 @@ def shift(
self._set_ad_order(_ad)
return __

def _translate_nodes(self, start: datetime) -> dict[datetime, Number]:
def _translate_nodes(self, start: datetime) -> dict[datetime, DualTypes]:
scalar = 1 / self[start]
new_nodes = {k: scalar * v for k, v in self.nodes.items()}

Expand Down Expand Up @@ -940,7 +949,7 @@ def translate(self, start: datetime, t: bool = False) -> Curve:
if start <= self.node_dates[0]:
raise ValueError("Cannot translate into the past. Review the docs.")

new_nodes = self._translate_nodes(start)
new_nodes: dict[datetime, DualTypes] = self._translate_nodes(start)

# re-organise the t-knot sequence
if isinstance(self.t, NoInput):
Expand Down Expand Up @@ -1877,13 +1886,12 @@ def __init__( # type: ignore[no-untyped-def]
index_lag: int | NoInput = NoInput(0),
**kwargs,
) -> None:
super().__init__(*args, **{"interpolation": "linear_index", **kwargs})
self.index_lag = _drb(defaults.index_lag, index_lag)
if isinstance(index_base, NoInput):
raise ValueError("`index_base` must be given for IndexCurve.")
self.index_base: DualTypes = index_base

super().__init__(*args, **{"interpolation": "linear_index", **kwargs})

def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes:
"""
Calculate the accrued value of the index from the ``index_base``.
Expand Down
14 changes: 10 additions & 4 deletions python/rateslib/curves/rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rateslib.calendars import CalInput, _get_modifier, get_calendar # type: ignore[attr-defined]
from rateslib.calendars.dcfs import _get_convention
from rateslib.default import NoInput, _drb
from rateslib.dual import Number, _get_adorder
from rateslib.dual import DualTypes, Number, _get_adorder
from rateslib.rs import (
ADOrder,
FlatBackwardInterpolator,
Expand Down Expand Up @@ -38,15 +38,19 @@ def __init__(
self,
nodes: dict[datetime, Number],
*,
interpolation: str | Callable | NoInput = NoInput(0),
interpolation: str
| Callable[[datetime, dict[datetime, DualTypes]], DualTypes]
| NoInput = NoInput(0),
id: str | NoInput = NoInput(0),
convention: str | NoInput = NoInput(0),
modifier: str | NoInput = NoInput(0),
calendar: CalInput = NoInput(0),
ad: int = 0,
index_base: float | NoInput = NoInput(0),
):
self._py_interpolator = interpolation if callable(interpolation) else None
self._py_interpolator: Callable[[datetime, dict[datetime, DualTypes]], DualTypes] | None = (
interpolation if callable(interpolation) else None
)

self.obj = CurveObj(
nodes=nodes,
Expand Down Expand Up @@ -92,7 +96,9 @@ def _set_ad_order(self, ad: int) -> None:
self.obj.set_ad_order(_get_adorder(ad))

@staticmethod
def _validate_interpolator(interpolation: str | Callable | NoInput) -> CurveInterpolator:
def _validate_interpolator(
interpolation: str | Callable[[datetime, dict[datetime, DualTypes]], DualTypes] | NoInput,
) -> CurveInterpolator:
if interpolation is NoInput.blank:
return _get_interpolator(defaults.interpolation["Curve"])
elif isinstance(interpolation, str):
Expand Down
8 changes: 5 additions & 3 deletions python/rateslib/dual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Contact rateslib at gmail.com if this code is observed outside its intended sphere.


def set_order(val: Number, order: int) -> Number:
def set_order(val: DualTypes, order: int) -> DualTypes:
"""
Changes the order of a :class:`Dual` or :class:`Dual2` leaving floats and ints
unchanged.
Expand All @@ -42,6 +42,8 @@ def set_order(val: Number, order: int) -> Number:
elif order == 1 and isinstance(val, Dual2):
return val.to_dual()
elif order == 0:
if isinstance(val, Variable): # TODO (low): remove branch when float(Variable) is fixed
return val.real
return float(val)
# otherwise:
# - val is a Float or an Int
Expand All @@ -50,8 +52,8 @@ def set_order(val: Number, order: int) -> Number:


def set_order_convert(
val: Number, order: int, tag: list[str] | None, vars_from: Dual | Dual2 | None = None
) -> Number:
val: DualTypes, order: int, tag: list[str] | None, vars_from: Dual | Dual2 | None = None
) -> DualTypes:
"""
Convert a float, :class:`Dual` or :class:`Dual2` type to a specified alternate type.
Expand Down
21 changes: 11 additions & 10 deletions python/rateslib/fx/fx_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def rate(
self,
pair: str,
settlement: datetime | NoInput = NoInput(0),
) -> Number:
) -> DualTypes:
"""
Return the fx forward rate for a currency pair.
Expand Down Expand Up @@ -505,7 +505,7 @@ def _rate_with_path(
pair: str,
settlement: datetime | NoInput = NoInput(0),
path: list[dict[str, int]] | NoInput = NoInput(0),
) -> tuple[Number, list[dict[str, int]]]:
) -> tuple[DualTypes, list[dict[str, int]]]:
"""
Return the fx forward rate for a currency pair, including the path taken to traverse ccys.
Expand Down Expand Up @@ -543,7 +543,7 @@ def _get_d_f_idx_and_path(
raise ValueError("`settlement` cannot be before immediate FX rate date.")

if settlement_ == self.fx_rates_immediate.settlement:
rate_ = self.fx_rates_immediate.rate(pair)
rate_: DualTypes = self.fx_rates_immediate.rate(pair)
_, _, path = _get_d_f_idx_and_path(pair, path)
return rate_, path

Expand All @@ -554,22 +554,23 @@ def _get_d_f_idx_and_path(

# otherwise must rely on curves and path search which is slower
d_idx, f_idx, path = _get_d_f_idx_and_path(pair, path)
rate_, current_idx = 1.0, f_idx
rate_ = 1.0
current_idx = f_idx
for route in path:
if "col" in route:
coll_ccy = self.currencies_list[current_idx]
cash_ccy = self.currencies_list[route["col"]]
w_i = self.fx_curves[f"{cash_ccy}{coll_ccy}"][settlement_]
v_i = self.fx_curves[f"{coll_ccy}{coll_ccy}"][settlement_]
rate_ *= self.fx_rates_immediate.fx_array[route["col"], current_idx]
rate_ *= self.fx_rates_immediate._fx_array_el(route["col"], current_idx)
rate_ *= w_i / v_i
current_idx = route["col"]
elif "row" in route:
coll_ccy = self.currencies_list[route["row"]]
cash_ccy = self.currencies_list[current_idx]
w_i = self.fx_curves[f"{cash_ccy}{coll_ccy}"][settlement_]
v_i = self.fx_curves[f"{coll_ccy}{coll_ccy}"][settlement_]
rate_ *= self.fx_rates_immediate.fx_array[route["row"], current_idx]
rate_ *= self.fx_rates_immediate._fx_array_el(route["row"], current_idx)
rate_ *= v_i / w_i
current_idx = route["row"]

Expand Down Expand Up @@ -733,7 +734,7 @@ def convert(
settlement_: datetime = self.immediate if isinstance(settlement, NoInput) else settlement
value_date_: datetime = settlement_ if isinstance(value_date, NoInput) else value_date

fx_rate: Number = self.rate(domestic + foreign, settlement_)
fx_rate: DualTypes = self.rate(domestic + foreign, settlement_)
if value_date_ == settlement_:
return fx_rate * value
else:
Expand Down Expand Up @@ -822,7 +823,7 @@ def swap(
pair: str,
settlements: list[datetime],
path: list[dict[str, int]] | NoInput = NoInput(0),
) -> Number:
) -> DualTypes:
"""
Return the FXSwap mid-market rate for the given currency pair.
Expand Down Expand Up @@ -1018,9 +1019,9 @@ def plot(
points: int = (right_ - left_).days
x = [left_ + timedelta(days=i) for i in range(points)]
_, path = self._rate_with_path(pair, x[0])
rates: list[Number] = [self._rate_with_path(pair, _, path=path)[0] for _ in x]
rates: list[DualTypes] = [self._rate_with_path(pair, _, path=path)[0] for _ in x]
if not fx_swap:
y: list[list[Number]] = [rates]
y: list[list[DualTypes]] = [rates]
else:
y = [[(rate - rates[0]) * 10000 for rate in rates]]
return plot(x, y)
Expand Down
Empty file added python/rateslib/py.typed
Empty file.
Loading

0 comments on commit e7bf3dd

Please sign in to comment.