Skip to content

Commit

Permalink
TYP: improve period typing (#557)
Browse files Browse the repository at this point in the history
Co-authored-by: JHM Darbyshire (M1) <[email protected]>
  • Loading branch information
attack68 and attack68 authored Dec 15, 2024
1 parent f090fc9 commit c03e16a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 63 deletions.
4 changes: 2 additions & 2 deletions python/rateslib/instruments/rates_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
ZeroIndexLeg,
)
from rateslib.periods import (
_disc_from_curve,
_disc_required_maybe_from_curve,
_get_fx_and_base,
_maybe_local,
_trim_df_by_index,
Expand Down Expand Up @@ -2423,7 +2423,7 @@ def analytic_delta(
For arguments see :meth:`~rateslib.periods.BasePeriod.analytic_delta`.
"""
disc_curve_: Curve = _disc_from_curve(curve, disc_curve)
disc_curve_: Curve = _disc_required_maybe_from_curve(curve, disc_curve)
fx, base = _get_fx_and_base(self.leg1.currency, fx, base)
rate = self.rate([curve])
_ = self.leg1.notional * self.leg1.periods[0].dcf * disc_curve_[self._payment_date] / 10000
Expand Down
6 changes: 3 additions & 3 deletions python/rateslib/legs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
IndexCashflow,
IndexFixedPeriod,
IndexMixin,
_disc_from_curve,
_disc_maybe_from_curve,
_disc_required_maybe_from_curve,
_get_fx_and_base,
_validate_float_args,
)
Expand Down Expand Up @@ -1274,14 +1274,14 @@ def npv(
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
local: bool = False,
):
) -> dict[str, DualTypes] | DualTypes:
"""
Return the NPV of the *ZeroFloatLeg* via summing all periods.
For arguments see
:meth:`BasePeriod.npv()<rateslib.periods.BasePeriod.npv>`.
"""
disc_curve_: Curve = _disc_from_curve(curve, disc_curve)
disc_curve_: Curve = _disc_required_maybe_from_curve(curve, disc_curve)
fx, base = _get_fx_and_base(self.currency, fx, base)
value = (
self.rate(curve)
Expand Down
113 changes: 55 additions & 58 deletions python/rateslib/periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

def _get_fx_and_base(
currency: str,
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
fx: DualTypes | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
) -> tuple[DualTypes, str | NoInput]:
"""
Expand Down Expand Up @@ -169,6 +169,20 @@ def _disc_maybe_from_curve(
return _


def _disc_required_maybe_from_curve(
curve: Curve | NoInput | dict[str, Curve],
disc_curve: Curve | NoInput,
) -> Curve:
"""Return a discount curve, pointed as the `curve` if not provided and if suitable Type."""
_: Curve | NoInput = _disc_maybe_from_curve(curve, disc_curve)
if isinstance(_, NoInput):
raise TypeError(
"`curves` have not been supplied correctly. "
"A `disc_curve` is required to perform function."
)
return _


class BasePeriod(metaclass=ABCMeta):
"""
Abstract base class with common parameters for all ``Period`` subclasses.
Expand Down Expand Up @@ -225,9 +239,9 @@ def __init__(
raise ValueError("`end` cannot be before `start`.")
self.start, self.end, self.payment = start, end, payment
self.frequency = frequency.upper()
self.notional = defaults.notional if notional is NoInput.blank else notional
self.currency = defaults.base_currency if currency is NoInput.blank else currency.lower()
self.convention = defaults.convention if convention is NoInput.blank else convention
self.notional = _drb(defaults.notional, notional)
self.currency = _drb(defaults.base_currency, currency).lower()
self.convention = _drb(defaults.convention, convention)
self.termination = termination
self.freq_months = defaults.frequency_months[self.frequency]
self.stub = stub
Expand Down Expand Up @@ -313,17 +327,17 @@ def analytic_delta(
period.analytic_delta(curve, curve, fxr)
period.analytic_delta(curve, curve, fxr, "gbp")
""" # noqa: E501
disc_curve_: Curve | NoInput = _disc_maybe_from_curve(curve, disc_curve)
fx, base = _get_fx_and_base(self.currency, fx, base)
_ = fx * self.notional * self.dcf * disc_curve_[self.payment] / 10000
return _
disc_curve_: Curve = _disc_required_maybe_from_curve(curve, disc_curve)
fx_, _ = _get_fx_and_base(self.currency, fx, base)
ret: DualTypes = fx_ * self.notional * self.dcf * disc_curve_[self.payment] / 10000
return ret

@abstractmethod
def cashflows(
self,
curve: Curve | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
fx: DualTypes | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -497,14 +511,15 @@ def analytic_delta(self, *args: Any, **kwargs: Any) -> DualTypes:
return super().analytic_delta(*args, **kwargs)

@property
def cashflow(self) -> float | None:
def cashflow(self) -> DualTypes | None:
"""
float, Dual or Dual2 : The calculated value from rate, dcf and notional.
"""
if isinstance(self.fixed_rate, NoInput):
return None
else:
return -self.notional * self.dcf * self.fixed_rate / 100
_: DualTypes = -self.notional * self.dcf * self.fixed_rate / 100
return _

# Licence: Creative Commons - Attribution-NonCommercial-NoDerivatives 4.0 International
# Commercial use of this code, and/or copying and redistribution is prohibited.
Expand All @@ -517,21 +532,17 @@ def npv(
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
local: bool = False,
) -> DualTypes:
) -> dict[str, DualTypes] | DualTypes:
"""
Return the NPV of the *FixedPeriod*.
See :meth:`BasePeriod.npv()<rateslib.periods.BasePeriod.npv>`
"""
disc_curve_: Curve | NoInput = _disc_maybe_from_curve(curve, disc_curve)
disc_curve_: Curve = _disc_required_maybe_from_curve(curve, disc_curve)
try:
value: DualTypes = self.cashflow * disc_curve_[self.payment] # type: ignore[operator, index]
value: DualTypes = self.cashflow * disc_curve_[self.payment] # type: ignore[operator]
except TypeError as e:
# either fixed rate is None or curve is None hence mypy error
if isinstance(disc_curve_, NoInput):
raise TypeError(
"`curves` have not been supplied correctly. `disc_curve` not found."
)
elif isinstance(self.fixed_rate, NoInput):
# either fixed rate is None
if isinstance(self.fixed_rate, NoInput):
raise TypeError("`fixed_rate` must be set on the Period for an `npv`.")
else:
raise e
Expand All @@ -541,31 +552,32 @@ def cashflows(
self,
curve: Curve | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
fx: DualTypes | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
) -> dict[str, Any]:
"""
Return the cashflows of the *FixedPeriod*.
See :meth:`BasePeriod.cashflows()<rateslib.periods.BasePeriod.cashflows>`
"""
disc_curve_: Curve | NoInput = _disc_maybe_from_curve(curve, disc_curve)
fx, base = _get_fx_and_base(self.currency, fx, base)
fx_, base_ = _get_fx_and_base(self.currency, fx, base)

if disc_curve_ is NoInput.blank or self.fixed_rate is NoInput.blank:
if isinstance(disc_curve_, NoInput) or isinstance(self.fixed_rate, NoInput):
npv = None
npv_fx = None
else:
npv = float(self.npv(curve, disc_curve_))
npv_fx = npv * float(fx)
npv_dual: DualTypes = self.npv(curve, disc_curve_, local=False) # type: ignore[assignment]
npv = _dual_float(npv_dual)
npv_fx = npv * _dual_float(fx_)

cashflow = None if self.cashflow is None else float(self.cashflow)
cashflow = None if self.cashflow is None else _dual_float(self.cashflow)
return {
**super().cashflows(curve, disc_curve_, fx, base),
**super().cashflows(curve, disc_curve_, fx_, base_),
defaults.headers["rate"]: self.fixed_rate,
defaults.headers["spread"]: None,
defaults.headers["cashflow"]: cashflow,
defaults.headers["npv"]: npv,
defaults.headers["fx"]: float(fx),
defaults.headers["fx"]: _dual_float(fx_),
defaults.headers["npv_fx"]: npv_fx,
}

Expand All @@ -582,10 +594,7 @@ def _validate_float_args(
-------
tuple
"""
if fixing_method is NoInput.blank:
fixing_method_: str = defaults.fixing_method
else:
fixing_method_ = fixing_method.lower()
fixing_method_: str = _drb(defaults.fixing_method, fixing_method).lower()
if fixing_method_ not in [
"ibor",
"rfr_payment_delay",
Expand All @@ -603,10 +612,7 @@ def _validate_float_args(
f"got '{fixing_method_}'.",
)

if method_param is NoInput.blank:
method_param_: int = defaults.fixing_method_param[fixing_method_]
else:
method_param_ = method_param
method_param_: int = _drb(defaults.fixing_method_param[fixing_method_], method_param)
if method_param_ != 0 and fixing_method_ == "rfr_payment_delay":
raise ValueError(
"`method_param` should not be used (or a value other than 0) when "
Expand All @@ -619,10 +625,9 @@ def _validate_float_args(
f'`method_param` must be >0 for "rfr_lockout" `fixing_method`, ' f"got {method_param_}",
)

if spread_compound_method is NoInput.blank:
spread_compound_method_: str = defaults.spread_compound_method
else:
spread_compound_method_ = spread_compound_method.lower()
spread_compound_method_: str = _drb(
defaults.spread_compound_method, spread_compound_method
).lower()
if spread_compound_method_ not in [
"none_simple",
"isda_compounding",
Expand Down Expand Up @@ -940,15 +945,15 @@ def cashflows(
self,
curve: Curve | dict[str, Curve] | NoInput = NoInput(0),
disc_curve: Curve | NoInput = NoInput(0),
fx: float | FXRates | FXForwards | NoInput = NoInput(0),
fx: DualTypes | FXRates | FXForwards | NoInput = NoInput(0),
base: str | NoInput = NoInput(0),
) -> dict[str, Any]:
"""
Return the cashflows of the *FloatPeriod*.
See
:meth:`BasePeriod.cashflows()<rateslib.periods.BasePeriod.cashflows>`
"""
fx, base = _get_fx_and_base(self.currency, fx, base)
fx_, base_ = _get_fx_and_base(self.currency, fx, base)
disc_curve_: Curve | NoInput = _disc_maybe_from_curve(curve, disc_curve)

try:
Expand All @@ -962,19 +967,19 @@ def cashflows(
else:
rate = 100 * cashflow / (-self.notional * self.dcf)

if disc_curve_ is not NoInput.blank:
if not isinstance(disc_curve_, NoInput):
npv = self.npv(curve, disc_curve_)
npv_fx = npv * float(fx)
npv_fx = npv * _dual_float(fx_)
else:
npv, npv_fx = None, None

return {
**super().cashflows(curve, disc_curve_, fx, base),
**super().cashflows(curve, disc_curve_, fx_, base_),
defaults.headers["rate"]: _float_or_none(rate),
defaults.headers["spread"]: float(self.float_spread),
defaults.headers["spread"]: _dual_float(self.float_spread),
defaults.headers["cashflow"]: _float_or_none(cashflow),
defaults.headers["npv"]: _float_or_none(npv),
defaults.headers["fx"]: float(fx),
defaults.headers["fx"]: _dual_float(fx_),
defaults.headers["npv_fx"]: npv_fx,
}

Expand Down Expand Up @@ -2845,7 +2850,7 @@ def npv(
Return the cashflows of the *IndexPeriod*.
See :meth:`BasePeriod.npv()<rateslib.periods.BasePeriod.npv>`
"""
disc_curve_: Curve = _disc_from_curve(curve, disc_curve)
disc_curve_: Curve = _disc_required_maybe_from_curve(curve, disc_curve)
if not isinstance(disc_curve, Curve) and curve is NoInput.blank:
raise TypeError("`curves` have not been supplied correctly.")
value = self.cashflow(curve) * disc_curve_[self.payment]
Expand Down Expand Up @@ -4314,19 +4319,11 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


def _float_or_none(val):
def _float_or_none(val: DualTypes | None) -> float | None:
if val is None:
return None
else:
return float(val)


def _disc_from_curve(curve: Curve, disc_curve: Curve | NoInput) -> Curve:
if disc_curve is NoInput.blank:
_: Curve = curve
else:
_ = disc_curve
return _
return _dual_float(val)


def _get_ibor_curve_from_dict(months, d):
Expand Down

0 comments on commit c03e16a

Please sign in to comment.