Skip to content

Commit

Permalink
dataclasses now have fields checked, not __init__.
Browse files Browse the repository at this point in the history
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
  • Loading branch information
patrick-kidger committed Oct 5, 2023
1 parent ef102f4 commit 90263d0
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 16 deletions.
17 changes: 17 additions & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import importlib.metadata
import typing
import warnings

# First import some things as normal
from ._array_types import (
Expand Down Expand Up @@ -196,4 +197,20 @@ class PRNGKeyArray:
del has_jax


# The next line is kept verbatim: per its own comment it is rewritten by external
# tooling (copybara).
check_equinox_version = True # easy-to-replace line with copybara
if check_equinox_version:
    # Emit a warning (never an error) when an old Equinox is installed alongside
    # this version of jaxtyping.
    try:
        eqx_version = importlib.metadata.version("equinox")
    except importlib.metadata.PackageNotFoundError:
        # Equinox is not installed, so there is nothing to check.
        pass
    else:
        # Only the leading `major.minor.patch` components are compared. Version
        # strings may carry extra or non-numeric components (e.g. "0.11.0.dev1",
        # "0.11.0rc1") or fewer than three components (e.g. "0.11"); none of these
        # may be allowed to crash `import jaxtyping`.
        try:
            release = [int(part) for part in eqx_version.split(".")[:3]]
        except ValueError:
            # Unparseable version: skip the check rather than guess.
            equinox_version = None
        else:
            release += [0] * (3 - len(release))  # pad e.g. "0.11" to (0, 11, 0)
            equinox_version = tuple(release)
        # NOTE(review): the threshold is 0.11.0 whilst the message recommends
        # >=0.11.1 -- confirm this difference is intended.
        if equinox_version is not None and equinox_version < (0, 11, 0):
            warnings.warn(
                "jaxtyping version >=0.2.23 should be used with Equinox version "
                ">=0.11.1"
            )


__version__ = importlib.metadata.version("jaxtyping")
83 changes: 80 additions & 3 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import threading
import types
import weakref
from typing import get_args, get_origin


try:
Expand Down Expand Up @@ -72,7 +73,7 @@ def batch_outer_product(x: Float32[Array, "b c1"],
then the old one is returned to.
For example, this means you could leave off the `@jaxtyped` decorator to enforce
that this function use the same axes sizes as the function it was called from.
that this function use the same axis sizes as the function it was called from.
Likewise, this means you can use `isinstance` checks inside a function body
and have them contribute to the same collection of consistency checks performed
Expand Down Expand Up @@ -134,7 +135,59 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn


@jaxtyped
def _check_dataclass_annotations(self, typechecker):
    """Checks each initialised field of the dataclass instance `self` against its
    type annotation, using `typechecker`.

    For every field reported by `dataclasses.fields(self)`, the annotation is looked
    up on the first class in the MRO that declares it. A throwaway one-argument
    function annotated with that type is then decorated with `typechecker` and
    called on the field's current value, so any failure is raised by `typechecker`
    itself.

    Skipped without checking:

    - fields whose annotation is a string;
    - fields annotated as `type["SomeString"]`;
    - fields that have not been assigned on `self`.

    Raises `TypeError` if no class in the MRO annotates a field.

    NOTE(review): the whole function is wrapped in `@jaxtyped`, presumably so that
    all fields are checked within a single jaxtyping context -- confirm.
    """
    for field in dataclasses.fields(self):
        # Find the annotation on the first class in the MRO that declares the field.
        for kls in self.__class__.__mro__:
            try:
                annotation = kls.__annotations__[field.name]
            except (KeyError, AttributeError):
                # KeyError: this class does not annotate the field.
                # AttributeError: this class has no `__annotations__` at all, as is
                # the case for `object` and other builtin bases. Keep searching
                # rather than aborting the lookup.
                continue
            else:
                break
        else:
            raise TypeError(
                f"Could not find a type annotation for dataclass field "
                f"`{field.name}` on `{self.__class__.__name__}` or any of its base "
                "classes."
            )
        if isinstance(annotation, str):
            # Don't support stringified annotations. These are basically impossible to
            # resolve correctly, so just skip them.
            # This does mean that partially-stringified annotations like
            # `tuple["Foo"]` will just fail. There doesn't seem to be any way to even
            # detect a partially-stringified annotation in general.
            continue
        if get_origin(annotation) is type:
            args = get_args(annotation)
            if len(args) == 1 and isinstance(args[0], str):
                # We also special-case this one kind of partially-stringified type
                # annotation, so as to support Equinox <v0.11.1.
                # This was fixed in Equinox in
                # https://github.com/patrick-kidger/equinox/pull/543
                continue
        try:
            value = getattr(self, field.name)
        except AttributeError:
            continue  # allow uninitialised fields, which are allowed on dataclasses

        # Build a tiny function whose only parameter carries the field's annotation,
        # and let the typechecker validate the value as that function's argument.
        @typechecker
        def typecheck(x: annotation):
            pass

        typecheck(value)


def _jaxtyped_typechecker(typechecker):
"""A decorator added by the import hook to all classes. Only affects dataclasses.
Will be called as
```
@_jaxtyped_typechecker(beartype.beartype)
@dataclasses.dataclass
class SomeDataclass:
...
```
After initialisation, this will check that all fields of the dataclass match their
specified type annotation.
"""
# typechecker is expected to probably be either `typeguard.typechecked`, or
# `beartype.beartype`, or `None`.

Expand All @@ -144,8 +197,32 @@ def _jaxtyped_typechecker(typechecker):
def _wrapper(kls):
assert inspect.isclass(kls)
if dataclasses.is_dataclass(kls):
init = jaxtyped(typechecker(kls.__init__))
kls.__init__ = init
# This does not check that the arguments passed to `__init__` match the
# type annotations. There may be a custom user `__init__`, or a
# dataclass-generated `__init__` used alongside
# `equinox.field(converter=...)`

init = kls.__init__

@ft.wraps(init)
def __init__(self, *args, **kwargs):
# Run the original `__init__` first, so that fields have been assigned before
# they are checked.
init(self, *args, **kwargs)
# `kls.__init__` is late-binding to the `__init__` function that we're
# in now. (Or to someone else's monkey-patch.) Either way, this checks
# that we're in the "top-level" `__init__`, and not one that is being
# called via `super()`. We don't want to trigger too early, before all
# fields have been assigned.
#
# We're not checking `if self.__class__ is kls` because Equinox replaces
# the class with a defrozen version of itself during `__init__`, so the
# check wouldn't trigger.
#
# We're not doing this check by adding it to the end of the metaclass
# `__call__`, because Python doesn't allow you to monkey-patch metaclasses.
if self.__class__.__init__ is kls.__init__:
_check_dataclass_annotations(self, typechecker)

kls.__init__ = __init__
return kls

return _wrapper
26 changes: 24 additions & 2 deletions jaxtyping/_import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ def install_import_hook(modules: Union[str, Sequence[str]], typechecker: Optiona
install_import_hook(["foo", "bar.baz"], ...)
```
The import hook will automatically decorate all functions, and the `__init__` method
of dataclasses.
The import hook will automatically decorate all functions, and check the attributes
assigned to dataclasses.
If the function already has any decorators on it, then both the `@jaxtyped` and the
typechecker decorators will get added at the bottom of the decorator list, e.g.
Expand Down Expand Up @@ -366,6 +366,28 @@ def f(x: Float32[Array, "batch channels"]):
(This is the author's preferred approach to performing runtime type-checking
with jaxtyping!)
!!! warning
Stringified dataclass annotations, e.g.
```python
@dataclass()
class Foo:
x: "int"
```
will be silently skipped without checking them. This is because these are
essentially impossible to resolve at runtime. Such stringified annotations
typically occur either when using them for forward references, or when using
`from __future__ import annotations`. (You should never use the latter, it is
largely incompatible with runtime type checking.)
Partially stringified dataclass annotations, e.g.
```python
@dataclass()
class Foo:
x: tuple["int"]
```
will likely raise an error, and must not be used at all.
""" # noqa: E501

if isinstance(modules, str):
Expand Down
146 changes: 135 additions & 11 deletions test/import_hook_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from helpers import ParamError, ReturnError

import jaxtyping
from jaxtyping import Float32
from jaxtyping import Float32, Int


#
# Test that functions get checked
#


def g(x: Float32[jnp.ndarray, " b"]):
Expand All @@ -39,21 +41,101 @@ def g(x: Float32[jnp.ndarray, " b"]):
with pytest.raises(ParamError):
g(jnp.array(1))


#
# Test that Equinox modules get checked
#


class M(eqx.Module):
# Dataclass `__init__`, no converter
class Mod1(eqx.Module):
foo: int
bar: Float32[jnp.ndarray, " a"]


M(1, jnp.array([1.0]))
Mod1(1, jnp.array([1.0]))
with pytest.raises(ParamError):
M(1.0, jnp.array([1.0]))
Mod1(1.0, jnp.array([1.0]))
with pytest.raises(ParamError):
M(1, jnp.array(1.0))
Mod1(1, jnp.array(1.0))


# Dataclass `__init__`, converter
class Mod2(eqx.Module):
a: jnp.ndarray = eqx.field(converter=jnp.asarray)


Mod2(1) # This will fail unless we run typechecking after conversion


class BadMod2(eqx.Module):
a: jnp.ndarray = eqx.field(converter=lambda x: x)


with pytest.raises(ParamError):
BadMod2(1)
with pytest.raises(ParamError):
BadMod2("asdf")


# Custom `__init__`, no converter
class Mod3(eqx.Module):
foo: int
bar: Float32[jnp.ndarray, " a"]

def __init__(self, foo: str, bar: Float32[jnp.ndarray, " a"]):
self.foo = int(foo)
self.bar = bar


Mod3("1", jnp.array([1.0]))
with pytest.raises(ParamError):
Mod3(1, jnp.array([1.0]))
with pytest.raises(ParamError):
Mod3("1", jnp.array(1.0))


# Custom `__init__`, converter
class Mod4(eqx.Module):
a: Int[jnp.ndarray, ""] = eqx.field(converter=jnp.asarray)

def __init__(self, a: str):
self.a = int(a)


Mod4("1") # This will fail unless we run typechecking after conversion


# Custom `__post_init__`, no converter
class Mod5(eqx.Module):
foo: int
bar: Float32[jnp.ndarray, " a"]

def __post_init__(self):
pass


Mod5(1, jnp.array([1.0]))
with pytest.raises(ParamError):
Mod5(1.0, jnp.array([1.0]))
with pytest.raises(ParamError):
Mod5(1, jnp.array(1.0))


# Custom `__post_init__`, converter
class Mod6(eqx.Module):
a: jnp.ndarray = eqx.field(converter=jnp.asarray)

def __post_init__(self):
pass


Mod6(1) # This will fail unless we run typechecking after conversion


#
# Test that dataclasses get checked
#


@dataclasses.dataclass
Expand All @@ -68,7 +150,10 @@ class D:
with pytest.raises(ParamError):
D(1, jnp.array(1.0))


#
# Test that methods get checked
#


class N(eqx.Module):
Expand All @@ -93,17 +178,56 @@ def bar(self) -> jnp.ndarray:
with pytest.raises(ReturnError):
bad_n.bar()

# Test that converters work

#
# Test that we don't get called in `super()`.
#

class BadConverter(eqx.Module):
a: jnp.ndarray = eqx.field(converter=lambda x: x)

called = False


class Base(eqx.Module):
x: int

def __init__(self):
# Deliberately violates the `x: int` annotation. If the field check ran at the
# end of this `__init__` when it is reached via `super().__init__()`, then
# constructing `Derived` would raise.
self.x = "not an int"
global called
# Record that this `__init__` ran, and that it ran only once.
assert not called
called = True


class Derived(Base):
def __init__(self):
assert not called
super().__init__()
# `Base.__init__` has now run and left `x` as a string.
assert called
# Only from here on does `x` satisfy the inherited `x: int` annotation, so the
# field check has to happen after this top-level `__init__` finishes.
self.x = 2


Derived()


#
# Test that stringified type annotations work


class Foo:
pass


class Bar(eqx.Module):
x: type[Foo]
y: "type[Foo]"
# Note that this is the *only* kind of partially-stringified type annotation that
# is supported. This is for compatibility with older Equinox versions.
z: type["Foo"]


Bar(Foo, Foo, Foo)

with pytest.raises(ParamError):
BadConverter(1)
with pytest.raises(ParamError):
BadConverter("asdf")
Bar(1, Foo, Foo)

# Record that we've finished our checks successfully

Expand Down

0 comments on commit 90263d0

Please sign in to comment.