From 4f02af8a99a9c23e6c0d7ec5ffc1aabe4156e432 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:05:37 -0700 Subject: [PATCH] dataclasses now have fields checked, not __init__. Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked. This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type. A typical example was ```python class Foo(eqx.Module): x: jax.Array = eqx.field(converter=jnp.ndarray) Foo(1) # 1 is not an array! But this code is valid. ``` After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run. Note that this requires https://github.com/patrick-kidger/equinox/pull/524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run). --- jaxtyping/__init__.py | 17 +++++ jaxtyping/_decorator.py | 74 ++++++++++++++++++- jaxtyping/_import_hook.py | 26 ++++++- test/import_hook_tester.py | 145 ++++++++++++++++++++++++++++++++++--- 4 files changed, 246 insertions(+), 16 deletions(-) diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 2c3fd73..0e6092c 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -19,6 +19,7 @@ import importlib.metadata import typing +import warnings # First import some things as normal from ._array_types import ( @@ -196,4 +197,20 @@ class PRNGKeyArray: del has_jax +check_equinox_version = True # easy-to-replace line with copybara +if check_equinox_version: + try: + eqx_version = importlib.metadata.version("equinox") + except importlib.metadata.PackageNotFoundError: + pass + else: + major, minor, patch = eqx_version.split(".") + equinox_version = (int(major), int(minor), int(patch)) + if equinox_version < (0, 11, 1): + warnings.warn( + "jaxtyping version >=0.2.23 should be used with Equinox version " + ">=0.11.1" + ) + + __version__ = importlib.metadata.version("jaxtyping") diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index a0c4f13..cdf8528 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -72,7 +72,7 @@ def batch_outer_product(x: Float32[Array, "b c1"], then the old one is returned to. For example, this means you could leave off the `@jaxtyped` decorator to enforce - that this function use the same axes sizes as the function it was called from. + that this function use the same axis sizes as the function it was called from. Likewise, this means you can use `isinstance` checks inside a function body and have them contribute to the same collection of consistency checks performed @@ -134,7 +134,51 @@ def wrapped_fn(*args, **kwargs): return wrapped_fn +@jaxtyped +def _check_dataclass_annotations(self, typechecker): + for field in dataclasses.fields(self): + for kls in self.__class__.__mro__: + try: + annotation = kls.__annotations__[field.name] + except KeyError: + pass + else: + break + else: + raise TypeError + if isinstance(annotation, str): + # Don't support stringified annotations. These are basically impossible to + # resolve correctly, so just skip them. + # This does mean that annotations like `type["Foo"]` will just fail. There + # doesn't seem to be any way to even detect a partially-stringified + # annotation. + continue + try: + value = getattr(self, field.name) + except AttributeError: + continue # allow uninitialised fields, which are allowed on dataclasses + + @typechecker + def typecheck(x: annotation): + pass + + typecheck(value) + + def _jaxtyped_typechecker(typechecker): + """A decorator added by the import hook to all classes. Only affects dataclasses. + + Will be called as + ``` + @_jaxtyped_typechecker(beartype.beartype) + @dataclasses.dataclass + class SomeDataclass: + ... + ``` + + After initialisation, this will check that all fields of the dataclass match their + specified type annotation. + """ # typechecker is expected to probably be either `typeguard.typechecked`, or # `beartype.beartype`, or `None`. @@ -144,8 +188,32 @@ def _jaxtyped_typechecker(typechecker): def _wrapper(kls): assert inspect.isclass(kls) if dataclasses.is_dataclass(kls): - init = jaxtyped(typechecker(kls.__init__)) - kls.__init__ = init + # This does not check that the arguments passed to `__init__` match the + # type annotations. There may be a custom user `__init__`, or a + # dataclass-generated `__init__` used alongside + # `equinox.field(converter=...)` + + init = kls.__init__ + + @ft.wraps(init) + def __init__(self, *args, **kwargs): + init(self, *args, **kwargs) + # `kls.__init__` is late-binding to the `__init__` function that we're + # in now. (Or to someone else's monkey-patch.) Either way, this checks + # that we're in the "top-level" `__init__`, and not one that is being + # called via `super()`. We don't want to trigger too early, before all + # fields have been assigned. + # + # We're not checking `if self.__class__ is kls` because Equinox replaces + # the with a defrozen version of itself during `__init__`, so the check + # wouldn't trigger. + # + # We're not doing this check by adding it to the end of the metaclass + # `__call__`, because Python doesn't allow you monkey-patch metaclasses. + if self.__class__.__init__ is kls.__init__: + _check_dataclass_annotations(self, typechecker) + + kls.__init__ = __init__ return kls return _wrapper diff --git a/jaxtyping/_import_hook.py b/jaxtyping/_import_hook.py index 0aec718..c6e490d 100644 --- a/jaxtyping/_import_hook.py +++ b/jaxtyping/_import_hook.py @@ -291,8 +291,8 @@ def install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optiona install_import_hook(["foo", "bar.baz"], ...) ``` - The import hook will automatically decorate all functions, and the `__init__` method - of dataclasses. + The import hook will automatically decorate all functions, and check the attributes + assigned to dataclasses. If the function already has any decorators on it, then both the `@jaxtyped` and the typechecker decorators will get added at the bottom of the decorator list, e.g. @@ -366,6 +366,28 @@ def f(x: Float32[Array, "batch channels"]): (This is the author's preferred approach to performing runtime type-checking with jaxtyping!) + + !!! warning + + Stringified dataclass annotations, e.g. + ```python + @dataclass() + class Foo: + x: "int" + ``` + will be silently skipped without checking them. This is because these are + essentially impossible to resolve at runtime. Such stringified annotations + typically occur either when using them for forward references, or when using + `from __future__ import annotations`. (You should never use the latter, it is + largely incompatible with runtime type checking.) + + Partially stringified dataclass annotations, e.g. + ```python + @dataclass() + class Foo: + x: tuple["int"] + ``` + will likely raise an error, and must not be used at all. """ # noqa: E501 if isinstance(modules, str): diff --git a/test/import_hook_tester.py b/test/import_hook_tester.py index e9f4369..42d488b 100644 --- a/test/import_hook_tester.py +++ b/test/import_hook_tester.py @@ -25,10 +25,12 @@ from helpers import ParamError, ReturnError import jaxtyping -from jaxtyping import Float32 +from jaxtyping import Float32, Int +# # Test that functions get checked +# def g(x: Float32[jnp.ndarray, " b"]): @@ -39,21 +41,101 @@ def g(x: Float32[jnp.ndarray, " b"]): with pytest.raises(ParamError): g(jnp.array(1)) + +# # Test that Equinox modules get checked +# -class M(eqx.Module): +# Dataclass `__init__`, no converter +class Mod1(eqx.Module): foo: int bar: Float32[jnp.ndarray, " a"] -M(1, jnp.array([1.0])) +Mod1(1, jnp.array([1.0])) with pytest.raises(ParamError): - M(1.0, jnp.array([1.0])) + Mod1(1.0, jnp.array([1.0])) with pytest.raises(ParamError): - M(1, jnp.array(1.0)) + Mod1(1, jnp.array(1.0)) + + +# Dataclass `__init__`, converter +class Mod2(eqx.Module): + a: jnp.ndarray = eqx.field(converter=jnp.asarray) + + +Mod2(1) # This will fail unless we run typechecking after conversion + + +class BadMod2(eqx.Module): + a: jnp.ndarray = eqx.field(converter=lambda x: x) + + +with pytest.raises(ParamError): + BadMod2(1) +with pytest.raises(ParamError): + BadMod2("asdf") + + +# Custom `__init__`, no converter +class Mod3(eqx.Module): + foo: int + bar: Float32[jnp.ndarray, " a"] + + def __init__(self, foo: str, bar: Float32[jnp.ndarray, " a"]): + self.foo = int(foo) + self.bar = bar + + +Mod3("1", jnp.array([1.0])) +with pytest.raises(ParamError): + Mod3(1, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod3("1", jnp.array(1.0)) + + +# Custom `__init__`, converter +class Mod4(eqx.Module): + a: Int[jnp.ndarray, ""] = eqx.field(converter=jnp.asarray) + + def __init__(self, a: str): + self.a = int(a) + + +Mod4("1") # This will fail unless we run typechecking after conversion + + +# Custom `__post_init__`, no converter +class Mod5(eqx.Module): + foo: int + bar: Float32[jnp.ndarray, " a"] + + def __post_init__(self): + pass + +Mod5(1, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod5(1.0, jnp.array([1.0])) +with pytest.raises(ParamError): + Mod5(1, jnp.array(1.0)) + + +# Dataclass `__init__`, converter +class Mod6(eqx.Module): + a: jnp.ndarray = eqx.field(converter=jnp.asarray) + + def __post_init__(self): + pass + + +Mod6(1) # This will fail unless we run typechecking after conversion + + +# # Test that dataclasses get checked +# @dataclasses.dataclass @@ -68,7 +150,10 @@ class D: with pytest.raises(ParamError): D(1, jnp.array(1.0)) + +# # Test that methods get checked +# class N(eqx.Module): @@ -93,17 +178,55 @@ def bar(self) -> jnp.ndarray: with pytest.raises(ReturnError): bad_n.bar() -# Test that converters work +# +# Test that we don't get called in `super()`. +# -class BadConverter(eqx.Module): - a: jnp.ndarray = eqx.field(converter=lambda x: x) +called = False + + +class Base(eqx.Module): + x: int + + def __init__(self): + self.x = "not an int" + global called + assert not called + called = True + + +class Derived(Base): + def __init__(self): + assert not called + super().__init__() + assert called + self.x = 2 + + +Derived() + + +# +# Test that stringified type annotations work + + +class Foo: + pass + + +class Bar(eqx.Module): + x: type[Foo] + y: "type[Foo]" + # We deliberately don't test partially stringified annotations, like type["Foo"], + # as these are unsupported. + + +Bar(Foo, Foo) with pytest.raises(ParamError): - BadConverter(1) -with pytest.raises(ParamError): - BadConverter("asdf") + Bar(1, Foo) # Record that we've finished our checks successfully