diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 2c3fd73..e799f29 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -19,6 +19,7 @@ import importlib.metadata import typing +import warnings # First import some things as normal from ._array_types import ( @@ -196,4 +197,20 @@ class PRNGKeyArray: del has_jax +try: + import equinox +except Exception: + pass +else: + check_equinox_version = True # easy-to-replace line with copybara + if check_equinox_version: + major, minor, _ = equinox.__version__ + equinox_version = (int(major), int(minor)) + if equinox_version < (0, 11): + warnings.warn( + "jaxtyping version >=0.2.23 should be used with Equinox version " + ">=0.11.0" + ) + + __version__ = importlib.metadata.version("jaxtyping") diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index a0c4f13..d443885 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -72,7 +72,7 @@ def batch_outer_product(x: Float32[Array, "b c1"], then the old one is returned to. For example, this means you could leave off the `@jaxtyped` decorator to enforce - that this function use the same axes sizes as the function it was called from. + that this function use the same axis sizes as the function it was called from. Likewise, this means you can use `isinstance` checks inside a function body and have them contribute to the same collection of consistency checks performed @@ -134,7 +134,44 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn +@jaxtyped +def _check_dataclass_annotations(self, typechecker): + for field in dataclasses.fields(self): + for kls in self.__class__.__mro__: + try: + annotation = kls.__annotations__[field.name] + except KeyError: + pass + else: + break + else: + raise TypeError + try: + value = getattr(self, field.name) + except AttributeError: + continue # allow uninitialised fields, which are allowed on dataclasses + + @typechecker + def typecheck(x: annotation): + pass + + typecheck(value) + + def _jaxtyped_typechecker(typechecker): + """A decorator added by the import hook to all classes. Only affects dataclasses. + + Will be called as + ``` + @_jaxtyped_typechecker(beartype.beartype) + @dataclasses.dataclass + class SomeDataclass: + ... + ``` + + After initialisation, this will check that all fields of the dataclass match their + specified type annotation. + """ # typechecker is expected to probably be either `typeguard.typechecked`, or # `beartype.beartype`, or `None`. @@ -144,8 +181,32 @@ def _jaxtyped_typechecker(typechecker): def _wrapper(kls): assert inspect.isclass(kls) if dataclasses.is_dataclass(kls): - init = jaxtyped(typechecker(kls.__init__)) - kls.__init__ = init + # This does not check that the arguments passed to `__init__` match the + # type annotations. There may be a custom user `__init__`, or a + # dataclass-generated `__init__` used alongside + # `equinox.field(converter=...)` + + init = kls.__init__ + + @ft.wraps(init) + def __init__(self, *args, **kwargs): + init(self, *args, **kwargs) + # `kls.__init__` is late-binding to the `__init__` function that we're + # in now. (Or to someone else's monkey-patch.) Either way, this checks + # that we're in the "top-level" `__init__`, and not one that is being + # called via `super()`. We don't want to trigger too early, before all + # fields have been assigned. + # + # We're not checking `if self.__class__ is kls` because Equinox replaces + # the with a defrozen version of itself during `__init__`, so the check + # wouldn't trigger. + # + # We're not doing this check by adding it to the end of the metaclass + # `__call__`, because Python doesn't allow you monkey-patch metaclasses. + if self.__class__.__init__ is kls.__init__: + _check_dataclass_annotations(self, typechecker) + + kls.__init__ = __init__ return kls return _wrapper diff --git a/test/import_hook_tester.py b/test/import_hook_tester.py index e9f4369..8f5fa4b 100644 --- a/test/import_hook_tester.py +++ b/test/import_hook_tester.py @@ -25,10 +25,12 @@ from helpers import ParamError, ReturnError import jaxtyping -from jaxtyping import Float32 +from jaxtyping import Float32, Int +# # Test that functions get checked +# def g(x: Float32[jnp.ndarray, " b"]): @@ -39,21 +41,100 @@ def g(x: Float32[jnp.ndarray, " b"]): with pytest.raises(ParamError): g(jnp.array(1)) + +# # Test that Equinox modules get checked +# + + +# Dataclass `__init__`, no converter +class Mod1(eqx.Module): + foo: int + bar: Float32[jnp.ndarray, " a"] + + +Mod1(1, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod1(1.0, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod1(1, jnp.array(1.0)) + + +# Dataclass `__init__`, converter +class Mod2(eqx.Module): + a: jnp.ndarray = eqx.field(converter=jnp.asarray) + +Mod2(1) # This will fail unless we run typechecking after conversion -class M(eqx.Module): + +class BadMod2(eqx.Module): + a: jnp.ndarray = eqx.field(converter=lambda x: x) + + +with pytest.raises(ParamError): + BadMod2(1) +with pytest.raises(ParamError): + BadMod2("asdf") + +# Custom `__init__`, no converter +class Mod3(eqx.Module): foo: int bar: Float32[jnp.ndarray, " a"] + def __init__(self, foo: str, bar: Float32[jnp.ndarray, " a"]): + self.foo = int(foo) + self.bar = bar + -M(1, jnp.array([1.0])) +Mod3("1", jnp.array([1.0])) with pytest.raises(ParamError): - M(1.0, jnp.array([1.0])) + Mod3(1, jnp.array([1.0])) with pytest.raises(ParamError): - M(1, jnp.array(1.0)) + Mod3("1", jnp.array(1.0)) + +# Custom `__init__`, converter +class Mod4(eqx.Module): + a: Int[jnp.ndarray, ""] = eqx.field(converter=jnp.asarray) + + def __init__(self, a: str): + self.a = int(a) + + +Mod4("1") # This will fail unless we run typechecking after conversion + + +# Custom `__post_init__`, no converter +class Mod5(eqx.Module): + foo: int + bar: Float32[jnp.ndarray, " a"] + + def __post_init__(self): + pass + + +Mod5(1, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod5(1.0, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod5(1, jnp.array(1.0)) + + +# Dataclass `__init__`, converter +class Mod6(eqx.Module): + a: jnp.ndarray = eqx.field(converter=jnp.asarray) + + def __post_init__(self): + pass + + +Mod6(1) # This will fail unless we run typechecking after conversion + + +# # Test that dataclasses get checked +# @dataclasses.dataclass @@ -68,7 +149,10 @@ class D: with pytest.raises(ParamError): D(1, jnp.array(1.0)) + +# # Test that methods get checked +# class N(eqx.Module): @@ -93,17 +177,35 @@ def bar(self) -> jnp.ndarray: with pytest.raises(ReturnError): bad_n.bar() -# Test that converters work +# +# Test that we don't get called in `super()`. +# -class BadConverter(eqx.Module): - a: jnp.ndarray = eqx.field(converter=lambda x: x) +called = False + + +class Base(eqx.Module): + x: int + + def __init__(self): + self.x = "not an int" + global called + assert not called + called = True + + +class Derived(Base): + def __init__(self): + assert not called + super().__init__() + assert called + self.x = 2 + + +Derived() -with pytest.raises(ParamError): - BadConverter(1) -with pytest.raises(ParamError): - BadConverter("asdf") # Record that we've finished our checks successfully