From 1e0d035894de8546aef54949afc80eafa0fc2bac Mon Sep 17 00:00:00 2001 From: "JHM Darbyshire (M1)" Date: Tue, 24 Dec 2024 21:16:24 +0100 Subject: [PATCH] validate caches on FXForwards --- python/rateslib/curves/curves.py | 26 +++++++++++++------------- python/rateslib/default.py | 15 +++++++++------ python/rateslib/fx/fx_forwards.py | 18 +++++++++--------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/python/rateslib/curves/curves.py b/python/rateslib/curves/curves.py index 1e20ee36..64d7f81a 100644 --- a/python/rateslib/curves/curves.py +++ b/python/rateslib/curves/curves.py @@ -2309,8 +2309,8 @@ def _check_init_attribute(self, attr: str) -> None: f"Cannot composite curves with different attributes, got for '{attr}': {attrs},", ) - @_validate_caches # type: ignore[arg-type] - def rate( + @_validate_caches + def rate( # type: ignore[override] self, effective: datetime, termination: datetime | str | NoInput = NoInput(0), @@ -2392,7 +2392,7 @@ def rate( return _ - @_validate_caches # type: ignore[arg-type] + @_validate_caches def __getitem__(self, date: datetime) -> DualTypes: if defaults.curve_caching and date in self._cache: return self._cache[date] @@ -2424,7 +2424,7 @@ def __getitem__(self, date: datetime) -> DualTypes: f"Base curve type is unrecognised: {self._base_type}", ) # pragma: no cover - @_validate_caches # type: ignore[arg-type] + @_validate_caches def shift( self, spread: DualTypes, @@ -2471,7 +2471,7 @@ def shift( _.collateral = _drb(None, collateral) return _ - @_validate_caches # type: ignore[arg-type] + @_validate_caches def translate(self, start: datetime, t: bool = False) -> CompositeCurve: """ Create a new curve with an initial node date moved forward keeping all else @@ -2497,7 +2497,7 @@ def translate(self, start: datetime, t: bool = False) -> CompositeCurve: # cache check unnecessary since translate is constructed from up-to-date objects directly return CompositeCurve(curves=[curve.translate(start, t) for curve in self.curves]) - @_validate_caches # type: ignore[arg-type] + @_validate_caches def roll(self, tenor: datetime | str) -> CompositeCurve: """ Create a new curve with its shape translated in time @@ -2523,7 +2523,7 @@ def roll(self, tenor: datetime | str) -> CompositeCurve: # cache check unnecessary since roll is constructed from up-to-date objects directly return CompositeCurve(curves=[curve.roll(tenor) for curve in self.curves]) - @_validate_caches # type: ignore[arg-type] + @_validate_caches def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes: """ Calculate the accrued value of the index from the ``index_base``, which is taken @@ -2603,8 +2603,8 @@ def __init__( self.multi_csa_max_step = min(1825, multi_csa_max_step) super().__init__(curves, id) - @_validate_caches # type: ignore[arg-type] - def rate( + @_validate_caches + def rate( # type: ignore[override] self, effective: datetime, termination: datetime | str, @@ -2646,7 +2646,7 @@ def rate( _: DualTypes = (df_num / df_den - 1) * 100 / (d * n) return _ - @_validate_caches # type: ignore[arg-type] + @_validate_caches def __getitem__(self, date: datetime) -> DualTypes: # will return a composited discount factor if date == self.curves[0].node_dates[0]: @@ -2691,7 +2691,7 @@ def _get_step(step: int) -> int: _ *= min_ratio return _ - @_validate_caches # type: ignore[arg-type] + @_validate_caches # unnecessary because up-to-date objects are referred to directly def translate(self, start: datetime, t: bool = False) -> MultiCsaCurve: """ @@ -2721,7 +2721,7 @@ def translate(self, start: datetime, t: bool = False) -> MultiCsaCurve: multi_csa_min_step=self.multi_csa_min_step, ) - @_validate_caches # type: ignore[arg-type] + @_validate_caches # unnecessary because up-to-date objects are referred to directly def roll(self, tenor: datetime | str) -> MultiCsaCurve: """ @@ -2751,7 +2751,7 @@ def roll(self, tenor: datetime | str) -> MultiCsaCurve: multi_csa_min_step=self.multi_csa_min_step, ) - @_validate_caches # type: ignore[arg-type] + @_validate_caches def shift( self, spread: DualTypes, diff --git a/python/rateslib/default.py b/python/rateslib/default.py index 218252bb..006669fa 100644 --- a/python/rateslib/default.py +++ b/python/rateslib/default.py @@ -4,7 +4,7 @@ from collections.abc import Callable from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, ParamSpec, TypeVar import matplotlib.dates as mdates import matplotlib.pyplot as plt @@ -392,16 +392,19 @@ def _make_py_json(json: str, class_name: str) -> str: """Modifies the output JSON output for Rust structs wrapped by Python classes.""" return '{"Py":' + json + "}" +P = ParamSpec("P") +R = TypeVar("R") -def _validate_caches(func: Callable[[*tuple[Any, ...]], Any]) -> Callable[[*tuple[Any, ...]], Any]: +def _validate_caches(func: Callable[P, R]) -> Callable[P, R]: """ Add a decorator to a class instance method to first validate the cache before performing additional operations. If a change is detected the implemented `validate_cache` function is responsible for resetting the cache and updating any `cache_id`s. """ - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - self._validate_cache() - return func(self, *args, **kwargs) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + self = args[0] + self._validate_cache() # type: ignore[attr-defined] + return func(*args, **kwargs) - return wrapper # type: ignore[return-value] + return wrapper diff --git a/python/rateslib/fx/fx_forwards.py b/python/rateslib/fx/fx_forwards.py index 80361557..8916192c 100644 --- a/python/rateslib/fx/fx_forwards.py +++ b/python/rateslib/fx/fx_forwards.py @@ -455,7 +455,7 @@ def _get_recursive_chain( # Commercial use of this code, and/or copying and redistribution is prohibited. # Contact rateslib at gmail.com if this code is observed outside its intended sphere. - @_validate_caches # type: ignore[arg-type] + @_validate_caches def rate( self, pair: str, @@ -580,7 +580,7 @@ def _get_d_f_idx_and_path( return rate_, path - @_validate_caches # type: ignore[arg-type] + @_validate_caches def positions( self, value: Number, base: str | NoInput = NoInput(0), aggregate: bool = False ) -> Series[float] | DataFrame: @@ -659,7 +659,7 @@ def positions( _d: DataFrame = df.sort_index(axis=1) return _d - @_validate_caches # type: ignore[arg-type] + @_validate_caches def convert( self, value: DualTypes, @@ -747,7 +747,7 @@ def convert( crv = self.curve(foreign, collateral) return fx_rate * value * crv[settlement_] / crv[value_date_] - @_validate_caches # type: ignore[arg-type] + @_validate_caches # this is technically unnecessary since calls pre-cached method: convert def convert_positions( self, @@ -817,16 +817,16 @@ def convert_positions( d_sum: DualTypes = 0.0 for ccy in array_.index: # typing d is a datetime by default. - value_: DualTypes | None = self.convert(array_.loc[ccy, d], ccy, base, d) + value_: DualTypes | None = self.convert(array_.loc[ccy, d], ccy, base, d) # type: ignore[arg-type] d_sum += 0.0 if value_ is None else value_ if abs(d_sum) < 1e-2: sum += d_sum else: # only discount if there is a real value - value_ = self.convert(d_sum, base, base, d, self.immediate) + value_ = self.convert(d_sum, base, base, d, self.immediate) # type: ignore[arg-type] sum += 0.0 if value_ is None else value_ return sum - @_validate_caches # type: ignore[arg-type] + @_validate_caches def swap( self, pair: str, @@ -978,7 +978,7 @@ def curve( id=id, ) - @_validate_caches # type: ignore[arg-type] + @_validate_caches def plot( self, pair: str, @@ -1052,7 +1052,7 @@ def _set_ad_order(self, order: int) -> None: self.fx_rates_immediate._set_ad_order(order) self._cache_id = self._cache_id_associate # update the cache id after changing values - @_validate_caches # type: ignore[arg-type] + @_validate_caches def to_json(self) -> str: if isinstance(self.fx_rates, list): fx_rates: list[str] | str = [_.to_json() for _ in self.fx_rates]