Skip to content

Commit

Permalink
Merge branch 'main' into rustcurve
Browse files Browse the repository at this point in the history
# Conflicts:
#	python/rateslib/curves/curves.py
  • Loading branch information
attack68 committed Dec 27, 2024
2 parents 662b51a + f5e76f4 commit 7970e69
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 53 deletions.
33 changes: 16 additions & 17 deletions python/rateslib/curves/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from math import comb, floor
from os import urandom
from typing import TYPE_CHECKING, Any
from uuid import uuid4

Expand Down Expand Up @@ -252,9 +253,8 @@ def __init__( # type: ignore[no-untyped-def]

self._set_ad_order(order=ad)

@property
def _cache_id(self) -> int:
return self._cache_id_store
def __hash__(self) -> int:
return self._state_id

def __eq__(self, other: Any) -> bool:
"""Test two curves are identical"""
Expand Down Expand Up @@ -636,7 +636,7 @@ def clear_cache(self) -> None:
Alternatively the curve caching as a feature can be set to *False* in ``defaults``.
"""
self._cache: dict[datetime, DualTypes] = dict()
self._cache_id_store: int = hash(uuid4())
self._state_id: int = hash(urandom(8)) # 64-bit entropy

def _cached_value(self, date: datetime, val: DualTypes) -> DualTypes:
if defaults.curve_caching:
Expand Down Expand Up @@ -702,6 +702,7 @@ def csolve(self) -> None:
self.spline = PPSplineDual2(4, t_posix, None)

self.spline.csolve(tau_posix, y, left_n, right_n, False) # type: ignore[attr-defined]
self.clear_cache()

def shift(
self,
Expand Down Expand Up @@ -2584,13 +2585,12 @@ def index_value(self, date: datetime, interpolation: str = "daily") -> DualTypes
def _get_node_vector(self) -> Arr1dObj | Arr1dF64:
raise NotImplementedError("Instances of CompositeCurve do not have solvable variables.")

@property
def _cache_id_associate(self) -> int:
return hash(sum(curve._cache_id for curve in self.curves))
def _composited_hashes(self) -> int:
return hash(sum(hash(curve) for curve in self.curves))

def _clear_cache(self) -> None:
"""
Clear the cache of values on a *CompositeCurve* type.
Clear the cache of values on a *CompositeCurve* type, and update the state id.
Returns
-------
Expand All @@ -2599,14 +2599,14 @@ def _clear_cache(self) -> None:
Notes
-----
This method is called automatically when any of the composited curves
are detected to have been mutated, via their ``_cache_id``, which therefore
are detected to have been mutated, via their ``_state_id``, which therefore
invalidates the cache on a composite curve.
"""
self._cache: dict[datetime, DualTypes] = dict()
self._cache_id_store = self._cache_id_associate
self._state_id = self._composited_hashes()

def _validate_cache(self) -> None:
if self._cache_id != self._cache_id_associate:
if hash(self) != self._composited_hashes():
# If any of the associated curves have been mutated then the cache is invalidated
self._clear_cache()

Expand Down Expand Up @@ -2975,6 +2975,11 @@ def __init__(
def id(self):
return self._id # overloads Curve getting id from CurveObj

def __hash__(self) -> int:
# ProxyCurve is directly associated with its FXForwards object
self.fx_forwards._validate_cache()
return hash(self.fx_forwards)

def __getitem__(self, date: datetime) -> DualTypes:
self.fx_forwards._validate_cache() # manually handle cache check

Expand All @@ -2985,12 +2990,6 @@ def __getitem__(self, date: datetime) -> DualTypes:
_3: DualTypes = self.fx_forwards.fx_curves[self.coll_pair][date]
return _1 / _2 * _3

@property
def _cache_id(self) -> int:
# the state of the cache for a ProxyCurve is fully dependent on the state of the
# cache of its contained FXForwards object, which is what derives its calculations.
return self.fx_forwards._cache_id

def to_json(self) -> str: # pragma: no cover # type: ignore
"""
Not implemented for :class:`~rateslib.fx.ProxyCurve` s.
Expand Down
25 changes: 13 additions & 12 deletions python/rateslib/fx/fx_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def update(self, fx_rates: list[dict[str, float]] | NoInput = NoInput(0)) -> Non
for fxr_obj, fxr_up in zip(self_fx_rates, fx_rates, strict=True):
fxr_obj.update(fxr_up)
self._calculate_immediate_rates(base=self.base, init=False)
self._cache_id = self._cache_id_associate
self._state_id = self._composited_hashes()

def __init__(
self,
Expand All @@ -173,19 +173,20 @@ def __init__(
self._validate_fx_curves(fx_curves)
self.fx_rates: FXRates | list[FXRates] = fx_rates
self._calculate_immediate_rates(base, init=True)
self._cache_id = self._cache_id_associate
pass

@property
def _cache_id_associate(self) -> int:
self_fx_rates = self.fx_rates if isinstance(self.fx_rates, list) else [self.fx_rates]
return hash(
sum(curve._cache_id for curve in self.fx_curves.values())
+ sum(fxr._state_id for fxr in self_fx_rates)
self._state_id = self._composited_hashes()

def __hash__(self) -> int:
return self._state_id

def _composited_hashes(self) -> int:
self_fx_rates = [self.fx_rates] if not isinstance(self.fx_rates, list) else self.fx_rates
total = sum(hash(curve) for curve in self.fx_curves.values()) + sum(
hash(fxr) for fxr in self_fx_rates
)
return hash(total)

def _validate_cache(self) -> None:
if self._cache_id != self._cache_id_associate:
if hash(self) != self._composited_hashes():
self.update()

def _validate_fx_curves(self, fx_curves: dict[str, Curve]) -> None:
Expand Down Expand Up @@ -1050,7 +1051,7 @@ def _set_ad_order(self, order: int) -> None:
else:
self.fx_rates._set_ad_order(order)
self.fx_rates_immediate._set_ad_order(order)
self._cache_id = self._cache_id_associate # update the cache id after changing values
self._state_id = self._composited_hashes() # update the cache id after changing values

@_validate_caches
def to_json(self) -> str:
Expand Down
33 changes: 27 additions & 6 deletions python/rateslib/fx_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from datetime import datetime, timedelta
from datetime import datetime as dt
from os import urandom
from uuid import uuid4

import numpy as np
Expand Down Expand Up @@ -118,6 +119,28 @@ def __init__(

self._set_ad_order(ad) # includes csolve

def clear_cache(self) -> None:
"""
Clear the cache of values on a *Smile* type.
Returns
-------
None
Notes
-----
This should be used if any modification has been made to the *Smile*.
Users are advised against making direct modification to *Curve* classes once
constructed to avoid the issue of un-cleared caches returning erroneous values.
Alternatively the curve caching as a feature can be set to *False* in ``defaults``.
"""
self._cache: dict[float, DualTypes] = dict()
self._state_id: int = hash(urandom(8)) # 64-bit entropy

def __hash__(self) -> int:
return self._state_id

def __iter__(self):
raise TypeError("`FXDeltaVolSmile` is not iterable.")

Expand Down Expand Up @@ -436,9 +459,7 @@ def csolve(self) -> None:

self.spline = Spline(4, self.t, None)
self.spline.csolve(tau, y, left_n, right_n, False)

# self._create_approx_spline_conversions(Spline)
return None
self.clear_cache()

# def _build_datatable(self):
# """
Expand Down Expand Up @@ -524,7 +545,7 @@ def csolve(self) -> None:
# self.spline_u_delta_approx.csolve(u, delta.tolist()[::-1], 0, 0, False)
# return None

def _set_ad_order(self, order: int):
def _set_ad_order(self, order: int) -> None:
if order == getattr(self, "ad", None):
return None
elif order not in [0, 1, 2]:
Expand All @@ -535,8 +556,7 @@ def _set_ad_order(self, order: int):
k: set_order_convert(v, order, [f"{self.id}{i}"])
for i, (k, v) in enumerate(self.nodes.items())
}
self.csolve()
return None
self.csolve() # also clears cache

def plot(
self,
Expand Down Expand Up @@ -625,6 +645,7 @@ def _set_node_vector(self, vector, ad):
*DualArgs[1:],
)
self.csolve()
self.clear_cache()

def _get_node_vector(self):
"""Get a 1d array of variables associated with nodes of this object updated by Solver"""
Expand Down
39 changes: 29 additions & 10 deletions python/tests/test_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,11 +1244,30 @@ def test_cache_id_update(self, method, args):
},
id="sofr",
)
original = curve._cache_id
original = hash(curve)
getattr(curve, method)(*args)
new = curve._cache_id
new = hash(curve)
assert new != original

def test_csolve_clear_cache(self):
c = Curve(
nodes={dt(2000, 1, 1): 1.0, dt(2002, 1, 1): 0.99},
t=[
dt(2000, 1, 1),
dt(2000, 1, 1),
dt(2000, 1, 1),
dt(2000, 1, 1),
dt(2002, 1, 1),
dt(2002, 1, 1),
dt(2002, 1, 1),
dt(2002, 1, 1),
],
)
before = hash(c)
c.csolve()
after = hash(c)
assert before != after


class TestLineCurve:
def test_repr(self):
Expand Down Expand Up @@ -1287,9 +1306,9 @@ def test_cache_id_update(self, method, args):
},
id="sofr",
)
original = curve._cache_id
original = hash(curve)
getattr(curve, method)(*args)
new = curve._cache_id
new = hash(curve)
assert new != original


Expand Down Expand Up @@ -1354,7 +1373,7 @@ def test_typing_as_curve(self):
@pytest.mark.parametrize(
("method", "args"), [("clear_cache", tuple()), ("_set_node_vector", ([0.99], 1))]
)
def test_cache_id_update(self, method, args):
def test_state_id_update(self, method, args):
curve = IndexCurve(
nodes={
dt(2022, 1, 1): 1.0,
Expand All @@ -1363,9 +1382,9 @@ def test_cache_id_update(self, method, args):
id="sofr",
index_base=200.0,
)
original = curve._cache_id
original = hash(curve)
getattr(curve, method)(*args)
new = curve._cache_id
new = hash(curve)
assert new != original


Expand Down Expand Up @@ -2062,10 +2081,10 @@ def test_cache_is_validated_on_getitem(self):
curve = fxf.curve("cad", "eur")
fxr1.update({"usdeur": 100000000.0})
fxf.curve("eur", "eur")._set_node_vector([0.5], 1)
prior_id = fxf._cache_id
before = hash(fxf)
curve[dt(2022, 1, 9)]
new_id = fxf._cache_id
assert prior_id != new_id
after = hash(fxf)
assert before != after


class TestPlotCurve:
Expand Down
16 changes: 8 additions & 8 deletions python/tests/test_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ class TestFXForwards:
("to_json", tuple()),
],
)
def test_cache_id_update_on_fxr_update(self, method, args):
def test_hash_update_on_fxr_update(self, method, args):
# test validate cache works correctly on various methods after FXRates update
fxr1 = FXRates({"eurusd": 1.05}, settlement=dt(2022, 1, 3))
fxr2 = FXRates({"usdcad": 1.1}, settlement=dt(2022, 1, 2))
Expand All @@ -1270,15 +1270,15 @@ def test_cache_id_update_on_fxr_update(self, method, args):
},
)

original_id = fxf._cache_id
before = hash(fxf)
getattr(fxf, method)(*args)
# no cache update is necessary
assert original_id == fxf._cache_id
assert before == hash(fxf)

fxr1.update({"eurusd": 2.0})
getattr(fxf, method)(*args)
# cache update should have occurred
assert original_id != fxf._cache_id
assert before != hash(fxf)

@pytest.mark.parametrize(
("method", "args"),
Expand All @@ -1291,7 +1291,7 @@ def test_cache_id_update_on_fxr_update(self, method, args):
("to_json", tuple()),
],
)
def test_cache_id_update_on_curve_update(self, method, args):
def test_hash_update_on_curve_update(self, method, args):
# test validate cache works correctly on various methods after Curve update
fxr1 = FXRates({"eurusd": 1.05}, settlement=dt(2022, 1, 3))
fxr2 = FXRates({"usdcad": 1.1}, settlement=dt(2022, 1, 2))
Expand All @@ -1309,12 +1309,12 @@ def test_cache_id_update_on_curve_update(self, method, args):
},
)

original_id = fxf._cache_id
before = hash(fxf)
getattr(fxf, method)(*args)
# no cache update is necessary
assert original_id == fxf._cache_id
assert before == hash(fxf)

fxf.curve("eur", "eur")._set_node_vector([0.998], 1)
getattr(fxf, method)(*args)
# cache update should have occurred
assert original_id != fxf._cache_id
assert before != hash(fxf)
27 changes: 27 additions & 0 deletions python/tests/test_fx_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,33 @@ def test_iter_raises(self) -> None:
with pytest.raises(TypeError, match="`FXDeltaVolSmile` is not iterable."):
fxvs.__iter__()

def test_hash(self):
fxvs = FXDeltaVolSmile(
nodes={0.25: 10.0, 0.5: 10.0, 0.75: 11.0},
delta_type="forward",
eval_date=dt(2023, 3, 16),
expiry=dt(2023, 6, 16),
id="vol",
)
assert hash(fxvs) == fxvs._state_id

@pytest.mark.parametrize(
("method", "args"),
[("csolve", tuple()), ("_set_ad_order", (1,)), ("_set_node_vector", ([9.9, 9.8, 9.9], 1))],
)
def test_clear_cache(self, method, args):
fxvs = FXDeltaVolSmile(
nodes={0.25: 10.0, 0.5: 10.0, 0.75: 11.0},
delta_type="forward",
eval_date=dt(2023, 3, 16),
expiry=dt(2023, 6, 16),
id="vol",
)
before = hash(fxvs)
getattr(fxvs, method)(*args)
after = hash(fxvs)
assert before != after


class TestFXDeltaVolSurface:
def test_expiry_before_eval(self) -> None:
Expand Down

0 comments on commit 7970e69

Please sign in to comment.