diff --git a/python/rateslib/curves/curves.py b/python/rateslib/curves/curves.py index 0b07c0f0..890556ff 100644 --- a/python/rateslib/curves/curves.py +++ b/python/rateslib/curves/curves.py @@ -13,6 +13,7 @@ from collections.abc import Callable from datetime import datetime, timedelta from math import comb, floor +from os import urandom from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -252,9 +253,8 @@ def __init__( # type: ignore[no-untyped-def] self._set_ad_order(order=ad) - @property - def _cache_id(self) -> int: - return self._cache_id_store + def __hash__(self) -> int: + return self._state_id def __eq__(self, other: Any) -> bool: """Test two curves are identical""" @@ -636,7 +636,7 @@ def clear_cache(self) -> None: Alternatively the curve caching as a feature can be set to *False* in ``defaults``. """ self._cache: dict[datetime, DualTypes] = dict() - self._cache_id_store: int = hash(uuid4()) + self._state_id: int = hash(urandom(8)) # 64-bit entropy def _cached_value(self, date: datetime, val: DualTypes) -> DualTypes: if defaults.curve_caching: @@ -702,6 +702,7 @@ def csolve(self) -> None: self.spline = PPSplineDual2(4, t_posix, None) self.spline.csolve(tau_posix, y, left_n, right_n, False) # type: ignore[attr-defined] + self.clear_cache() def shift( self, @@ -2584,13 +2585,12 @@ def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes def _get_node_vector(self) -> Arr1dObj | Arr1dF64: raise NotImplementedError("Instances of CompositeCurve do not have solvable variables.") - @property - def _cache_id_associate(self) -> int: - return hash(sum(curve._cache_id for curve in self.curves)) + def _composited_hashes(self) -> int: + return hash(sum(hash(curve) for curve in self.curves)) def _clear_cache(self) -> None: """ - Clear the cache of values on a *CompositeCurve* type. + Clear the cache of values on a *CompositeCurve* type, and update the state id. Returns ------- @@ -2599,14 +2599,14 @@ def _clear_cache(self) -> None: Notes ----- This method is called automatically when any of the composited curves - are detected to have been mutated, via their ``_cache_id``, which therefore + are detected to have been mutated, via their ``_state_id``, which therefore invalidates the cache on a composite curve. """ self._cache: dict[datetime, DualTypes] = dict() - self._cache_id_store = self._cache_id_associate + self._state_id = self._composited_hashes() def _validate_cache(self) -> None: - if self._cache_id != self._cache_id_associate: + if hash(self) != self._composited_hashes(): # If any of the associated curves have been mutated then the cache is invalidated self._clear_cache() @@ -2975,6 +2975,11 @@ def __init__( def id(self): return self._id # overloads Curve getting id from CurveObj + def __hash__(self) -> int: + # ProxyCurve is directly associated with its FXForwards object + self.fx_forwards._validate_cache() + return hash(self.fx_forwards) + def __getitem__(self, date: datetime) -> DualTypes: self.fx_forwards._validate_cache() # manually handle cache check @@ -2985,12 +2990,6 @@ def __getitem__(self, date: datetime) -> DualTypes: _3: DualTypes = self.fx_forwards.fx_curves[self.coll_pair][date] return _1 / _2 * _3 - @property - def _cache_id(self) -> int: - # the state of the cache for a ProxyCurve is fully dependent on the state of the - # cache of its contained FXForwards object, which is what derives its calculations. - return self.fx_forwards._cache_id - def to_json(self) -> str: # pragma: no cover # type: ignore """ Not implemented for :class:`~rateslib.fx.ProxyCurve` s. diff --git a/python/rateslib/fx/fx_forwards.py b/python/rateslib/fx/fx_forwards.py index 0559ace2..455629f2 100644 --- a/python/rateslib/fx/fx_forwards.py +++ b/python/rateslib/fx/fx_forwards.py @@ -161,7 +161,7 @@ def update(self, fx_rates: list[dict[str, float]] | NoInput = NoInput(0)) -> Non for fxr_obj, fxr_up in zip(self_fx_rates, fx_rates, strict=True): fxr_obj.update(fxr_up) self._calculate_immediate_rates(base=self.base, init=False) - self._cache_id = self._cache_id_associate + self._state_id = self._composited_hashes() def __init__( self, @@ -173,19 +173,20 @@ def __init__( self._validate_fx_curves(fx_curves) self.fx_rates: FXRates | list[FXRates] = fx_rates self._calculate_immediate_rates(base, init=True) - self._cache_id = self._cache_id_associate - pass - - @property - def _cache_id_associate(self) -> int: - self_fx_rates = self.fx_rates if isinstance(self.fx_rates, list) else [self.fx_rates] - return hash( - sum(curve._cache_id for curve in self.fx_curves.values()) - + sum(fxr._state_id for fxr in self_fx_rates) + self._state_id = self._composited_hashes() + + def __hash__(self) -> int: + return self._state_id + + def _composited_hashes(self) -> int: + self_fx_rates = [self.fx_rates] if not isinstance(self.fx_rates, list) else self.fx_rates + total = sum(hash(curve) for curve in self.fx_curves.values()) + sum( + hash(fxr) for fxr in self_fx_rates ) + return hash(total) def _validate_cache(self) -> None: - if self._cache_id != self._cache_id_associate: + if hash(self) != self._composited_hashes(): self.update() def _validate_fx_curves(self, fx_curves: dict[str, Curve]) -> None: @@ -1050,7 +1051,7 @@ def _set_ad_order(self, order: int) -> None: else: self.fx_rates._set_ad_order(order) self.fx_rates_immediate._set_ad_order(order) - self._cache_id = self._cache_id_associate # update the cache id after changing values + self._state_id = self._composited_hashes() # update the cache id after changing values @_validate_caches def to_json(self) -> str: diff --git a/python/rateslib/fx_volatility.py b/python/rateslib/fx_volatility.py index e18fdfdf..cd2eb373 100644 --- a/python/rateslib/fx_volatility.py +++ b/python/rateslib/fx_volatility.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta from datetime import datetime as dt +from os import urandom from uuid import uuid4 import numpy as np @@ -118,6 +119,28 @@ def __init__( self._set_ad_order(ad) # includes csolve + def clear_cache(self) -> None: + """ + Clear the cache of values on a *Smile* type. + + Returns + ------- + None + + Notes + ----- + This should be used if any modification has been made to the *Smile*. + Users are advised against making direct modification to *Curve* classes once + constructed to avoid the issue of un-cleared caches returning erroneous values. + + Alternatively the curve caching as a feature can be set to *False* in ``defaults``. + """ + self._cache: dict[float, DualTypes] = dict() + self._state_id: int = hash(urandom(8)) # 64-bit entropy + + def __hash__(self) -> int: + return self._state_id + def __iter__(self): raise TypeError("`FXDeltaVolSmile` is not iterable.") @@ -436,9 +459,7 @@ def csolve(self) -> None: self.spline = Spline(4, self.t, None) self.spline.csolve(tau, y, left_n, right_n, False) - - # self._create_approx_spline_conversions(Spline) - return None + self.clear_cache() # def _build_datatable(self): # """ @@ -524,7 +545,7 @@ def csolve(self) -> None: # self.spline_u_delta_approx.csolve(u, delta.tolist()[::-1], 0, 0, False) # return None - def _set_ad_order(self, order: int): + def _set_ad_order(self, order: int) -> None: if order == getattr(self, "ad", None): return None elif order not in [0, 1, 2]: @@ -535,8 +556,7 @@ def _set_ad_order(self, order: int): k: set_order_convert(v, order, [f"{self.id}{i}"]) for i, (k, v) in enumerate(self.nodes.items()) } - self.csolve() - return None + self.csolve() # also clears cache def plot( self, @@ -625,6 +645,7 @@ def _set_node_vector(self, vector, ad): *DualArgs[1:], ) self.csolve() + self.clear_cache() def _get_node_vector(self): """Get a 1d array of variables associated with nodes of this object updated by Solver""" diff --git a/python/tests/test_curves.py b/python/tests/test_curves.py index d137b396..1837f720 100644 --- a/python/tests/test_curves.py +++ b/python/tests/test_curves.py @@ -1244,11 +1244,30 @@ def test_cache_id_update(self, method, args): }, id="sofr", ) - original = curve._cache_id + original = hash(curve) getattr(curve, method)(*args) - new = curve._cache_id + new = hash(curve) assert new != original + def test_csolve_clear_cache(self): + c = Curve( + nodes={dt(2000, 1, 1): 1.0, dt(2002, 1, 1): 0.99}, + t=[ + dt(2000, 1, 1), + dt(2000, 1, 1), + dt(2000, 1, 1), + dt(2000, 1, 1), + dt(2002, 1, 1), + dt(2002, 1, 1), + dt(2002, 1, 1), + dt(2002, 1, 1), + ], + ) + before = hash(c) + c.csolve() + after = hash(c) + assert before != after + class TestLineCurve: def test_repr(self): @@ -1287,9 +1306,9 @@ def test_cache_id_update(self, method, args): }, id="sofr", ) - original = curve._cache_id + original = hash(curve) getattr(curve, method)(*args) - new = curve._cache_id + new = hash(curve) assert new != original @@ -1354,7 +1373,7 @@ def test_typing_as_curve(self): @pytest.mark.parametrize( ("method", "args"), [("clear_cache", tuple()), ("_set_node_vector", ([0.99], 1))] ) - def test_cache_id_update(self, method, args): + def test_state_id_update(self, method, args): curve = IndexCurve( nodes={ dt(2022, 1, 1): 1.0, @@ -1363,9 +1382,9 @@ def test_cache_id_update(self, method, args): id="sofr", index_base=200.0, ) - original = curve._cache_id + original = hash(curve) getattr(curve, method)(*args) - new = curve._cache_id + new = hash(curve) assert new != original @@ -2062,10 +2081,10 @@ def test_cache_is_validated_on_getitem(self): curve = fxf.curve("cad", "eur") fxr1.update({"usdeur": 100000000.0}) fxf.curve("eur", "eur")._set_node_vector([0.5], 1) - prior_id = fxf._cache_id + before = hash(fxf) curve[dt(2022, 1, 9)] - new_id = fxf._cache_id - assert prior_id != new_id + after = hash(fxf) + assert before != after class TestPlotCurve: diff --git a/python/tests/test_fx.py b/python/tests/test_fx.py index 418a8aaa..9ade24ec 100644 --- a/python/tests/test_fx.py +++ b/python/tests/test_fx.py @@ -1252,7 +1252,7 @@ class TestFXForwards: ("to_json", tuple()), ], ) - def test_cache_id_update_on_fxr_update(self, method, args): + def test_hash_update_on_fxr_update(self, method, args): # test validate cache works correctly on various methods after FXRates update fxr1 = FXRates({"eurusd": 1.05}, settlement=dt(2022, 1, 3)) fxr2 = FXRates({"usdcad": 1.1}, settlement=dt(2022, 1, 2)) @@ -1270,15 +1270,15 @@ def test_cache_id_update_on_fxr_update(self, method, args): }, ) - original_id = fxf._cache_id + before = hash(fxf) getattr(fxf, method)(*args) # no cache update is necessary - assert original_id == fxf._cache_id + assert before == hash(fxf) fxr1.update({"eurusd": 2.0}) getattr(fxf, method)(*args) # cache update should have occurred - assert original_id != fxf._cache_id + assert before != hash(fxf) @pytest.mark.parametrize( ("method", "args"), @@ -1291,7 +1291,7 @@ def test_cache_id_update_on_fxr_update(self, method, args): ("to_json", tuple()), ], ) - def test_cache_id_update_on_curve_update(self, method, args): + def test_hash_update_on_curve_update(self, method, args): # test validate cache works correctly on various methods after Curve update fxr1 = FXRates({"eurusd": 1.05}, settlement=dt(2022, 1, 3)) fxr2 = FXRates({"usdcad": 1.1}, settlement=dt(2022, 1, 2)) @@ -1309,12 +1309,12 @@ def test_cache_id_update_on_curve_update(self, method, args): }, ) - original_id = fxf._cache_id + before = hash(fxf) getattr(fxf, method)(*args) # no cache update is necessary - assert original_id == fxf._cache_id + assert before == hash(fxf) fxf.curve("eur", "eur")._set_node_vector([0.998], 1) getattr(fxf, method)(*args) # cache update should have occurred - assert original_id != fxf._cache_id + assert before != hash(fxf) diff --git a/python/tests/test_fx_volatility.py b/python/tests/test_fx_volatility.py index 429d2db3..53817912 100644 --- a/python/tests/test_fx_volatility.py +++ b/python/tests/test_fx_volatility.py @@ -341,6 +341,33 @@ def test_iter_raises(self) -> None: with pytest.raises(TypeError, match="`FXDeltaVolSmile` is not iterable."): fxvs.__iter__() + def test_hash(self): + fxvs = FXDeltaVolSmile( + nodes={0.25: 10.0, 0.5: 10.0, 0.75: 11.0}, + delta_type="forward", + eval_date=dt(2023, 3, 16), + expiry=dt(2023, 6, 16), + id="vol", + ) + assert hash(fxvs) == fxvs._state_id + + @pytest.mark.parametrize( + ("method", "args"), + [("csolve", tuple()), ("_set_ad_order", (1,)), ("_set_node_vector", ([9.9, 9.8, 9.9], 1))], + ) + def test_clear_cache(self, method, args): + fxvs = FXDeltaVolSmile( + nodes={0.25: 10.0, 0.5: 10.0, 0.75: 11.0}, + delta_type="forward", + eval_date=dt(2023, 3, 16), + expiry=dt(2023, 6, 16), + id="vol", + ) + before = hash(fxvs) + getattr(fxvs, method)(*args) + after = hash(fxvs) + assert before != after + class TestFXDeltaVolSurface: def test_expiry_before_eval(self) -> None: