Skip to content

Commit

Permalink
validate caches on FXForwards
Browse files Browse the repository at this point in the history
  • Loading branch information
attack68 committed Dec 24, 2024
1 parent 32e598a commit 1e0d035
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 28 deletions.
26 changes: 13 additions & 13 deletions python/rateslib/curves/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,8 +2309,8 @@ def _check_init_attribute(self, attr: str) -> None:
f"Cannot composite curves with different attributes, got for '{attr}': {attrs},",
)

@_validate_caches # type: ignore[arg-type]
def rate(
@_validate_caches
def rate( # type: ignore[override]
self,
effective: datetime,
termination: datetime | str | NoInput = NoInput(0),
Expand Down Expand Up @@ -2392,7 +2392,7 @@ def rate(

return _

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def __getitem__(self, date: datetime) -> DualTypes:
if defaults.curve_caching and date in self._cache:
return self._cache[date]
Expand Down Expand Up @@ -2424,7 +2424,7 @@ def __getitem__(self, date: datetime) -> DualTypes:
f"Base curve type is unrecognised: {self._base_type}",
) # pragma: no cover

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def shift(
self,
spread: DualTypes,
Expand Down Expand Up @@ -2471,7 +2471,7 @@ def shift(
_.collateral = _drb(None, collateral)
return _

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def translate(self, start: datetime, t: bool = False) -> CompositeCurve:
"""
Create a new curve with an initial node date moved forward keeping all else
Expand All @@ -2497,7 +2497,7 @@ def translate(self, start: datetime, t: bool = False) -> CompositeCurve:
# cache check unnecessary since translate is constructed from up-to-date objects directly
return CompositeCurve(curves=[curve.translate(start, t) for curve in self.curves])

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def roll(self, tenor: datetime | str) -> CompositeCurve:
"""
Create a new curve with its shape translated in time
Expand All @@ -2523,7 +2523,7 @@ def roll(self, tenor: datetime | str) -> CompositeCurve:
# cache check unnecessary since roll is constructed from up-to-date objects directly
return CompositeCurve(curves=[curve.roll(tenor) for curve in self.curves])

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes:
"""
Calculate the accrued value of the index from the ``index_base``, which is taken
Expand Down Expand Up @@ -2603,8 +2603,8 @@ def __init__(
self.multi_csa_max_step = min(1825, multi_csa_max_step)
super().__init__(curves, id)

@_validate_caches # type: ignore[arg-type]
def rate(
@_validate_caches
def rate( # type: ignore[override]
self,
effective: datetime,
termination: datetime | str,
Expand Down Expand Up @@ -2646,7 +2646,7 @@ def rate(
_: DualTypes = (df_num / df_den - 1) * 100 / (d * n)
return _

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def __getitem__(self, date: datetime) -> DualTypes:
# will return a composited discount factor
if date == self.curves[0].node_dates[0]:
Expand Down Expand Up @@ -2691,7 +2691,7 @@ def _get_step(step: int) -> int:
_ *= min_ratio
return _

@_validate_caches # type: ignore[arg-type]
@_validate_caches
# unnecessary because up-to-date objects are referred to directly
def translate(self, start: datetime, t: bool = False) -> MultiCsaCurve:
"""
Expand Down Expand Up @@ -2721,7 +2721,7 @@ def translate(self, start: datetime, t: bool = False) -> MultiCsaCurve:
multi_csa_min_step=self.multi_csa_min_step,
)

@_validate_caches # type: ignore[arg-type]
@_validate_caches
# unnecessary because up-to-date objects are referred to directly
def roll(self, tenor: datetime | str) -> MultiCsaCurve:
"""
Expand Down Expand Up @@ -2751,7 +2751,7 @@ def roll(self, tenor: datetime | str) -> MultiCsaCurve:
multi_csa_min_step=self.multi_csa_min_step,
)

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def shift(
self,
spread: DualTypes,
Expand Down
15 changes: 9 additions & 6 deletions python/rateslib/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Any, ParamSpec, TypeVar

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -392,16 +392,19 @@ def _make_py_json(json: str, class_name: str) -> str:
"""Modifies the output JSON output for Rust structs wrapped by Python classes."""
return '{"Py":' + json + "}"

P = ParamSpec("P")
R = TypeVar("R")

def _validate_caches(func: Callable[[*tuple[Any, ...]], Any]) -> Callable[[*tuple[Any, ...]], Any]:
def _validate_caches(func: Callable[P, R]) -> Callable[P, R]:
"""
Add a decorator to a class instance method to first validate the cache before performing
additional operations. If a change is detected the implemented `validate_cache` function
is responsible for resetting the cache and updating any `cache_id`s.
"""

def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
self._validate_cache()
return func(self, *args, **kwargs)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
self = args[0]
self._validate_cache() # type: ignore[attr-defined]
return func(*args, **kwargs)

return wrapper # type: ignore[return-value]
return wrapper
18 changes: 9 additions & 9 deletions python/rateslib/fx/fx_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _get_recursive_chain(
# Commercial use of this code, and/or copying and redistribution is prohibited.
# Contact rateslib at gmail.com if this code is observed outside its intended sphere.

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def rate(
self,
pair: str,
Expand Down Expand Up @@ -580,7 +580,7 @@ def _get_d_f_idx_and_path(

return rate_, path

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def positions(
self, value: Number, base: str | NoInput = NoInput(0), aggregate: bool = False
) -> Series[float] | DataFrame:
Expand Down Expand Up @@ -659,7 +659,7 @@ def positions(
_d: DataFrame = df.sort_index(axis=1)
return _d

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def convert(
self,
value: DualTypes,
Expand Down Expand Up @@ -747,7 +747,7 @@ def convert(
crv = self.curve(foreign, collateral)
return fx_rate * value * crv[settlement_] / crv[value_date_]

@_validate_caches # type: ignore[arg-type]
@_validate_caches
# this is technically unnecessary since calls pre-cached method: convert
def convert_positions(
self,
Expand Down Expand Up @@ -817,16 +817,16 @@ def convert_positions(
d_sum: DualTypes = 0.0
for ccy in array_.index:
# typing d is a datetime by default.
value_: DualTypes | None = self.convert(array_.loc[ccy, d], ccy, base, d)
value_: DualTypes | None = self.convert(array_.loc[ccy, d], ccy, base, d) # type: ignore[arg-type]
d_sum += 0.0 if value_ is None else value_
if abs(d_sum) < 1e-2:
sum += d_sum
else: # only discount if there is a real value
value_ = self.convert(d_sum, base, base, d, self.immediate)
value_ = self.convert(d_sum, base, base, d, self.immediate) # type: ignore[arg-type]
sum += 0.0 if value_ is None else value_
return sum

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def swap(
self,
pair: str,
Expand Down Expand Up @@ -978,7 +978,7 @@ def curve(
id=id,
)

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def plot(
self,
pair: str,
Expand Down Expand Up @@ -1052,7 +1052,7 @@ def _set_ad_order(self, order: int) -> None:
self.fx_rates_immediate._set_ad_order(order)
self._cache_id = self._cache_id_associate # update the cache id after changing values

@_validate_caches # type: ignore[arg-type]
@_validate_caches
def to_json(self) -> str:
if isinstance(self.fx_rates, list):
fx_rates: list[str] | str = [_.to_json() for _ in self.fx_rates]
Expand Down

0 comments on commit 1e0d035

Please sign in to comment.