Skip to content

Commit

Permalink
dataclasses now have fields checked, not __init__.
Browse files Browse the repository at this point in the history
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.asarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
  • Loading branch information
patrick-kidger committed Oct 5, 2023
1 parent ef102f4 commit 69e7272
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 15 deletions.
17 changes: 17 additions & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import importlib.metadata
import typing
import warnings

# First import some things as normal
from ._array_types import (
Expand Down Expand Up @@ -196,4 +197,20 @@ class PRNGKeyArray:
del has_jax


check_equinox_version = True  # easy-to-replace line with copybara


def _equinox_major_minor(version):
    """Return `(major, minor)` parsed from a version string, or `None`.

    Only the first two dot-separated components are inspected, so version strings
    with fewer or more than three components (e.g. "0.11" or "0.11.0.dev1") do not
    raise. `None` is returned when either component is missing or is not a plain
    integer, in which case the caller skips the compatibility warning.
    """
    components = version.split(".")
    try:
        return int(components[0]), int(components[1])
    except (IndexError, ValueError):
        return None


if check_equinox_version:
    try:
        eqx_version = importlib.metadata.version("equinox")
    except importlib.metadata.PackageNotFoundError:
        # Equinox is not installed, so there is nothing to check.
        pass
    else:
        equinox_version = _equinox_major_minor(eqx_version)
        # An unparseable version must never break `import jaxtyping`; just skip
        # the warning in that case.
        if equinox_version is not None and equinox_version < (0, 11):
            warnings.warn(
                "jaxtyping version >=0.2.23 should be used with Equinox version "
                ">=0.11.0"
            )


__version__ = importlib.metadata.version("jaxtyping")
67 changes: 64 additions & 3 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def batch_outer_product(x: Float32[Array, "b c1"],
then the old one is returned to.
For example, this means you could leave off the `@jaxtyped` decorator to enforce
that this function use the same axes sizes as the function it was called from.
that this function use the same axis sizes as the function it was called from.
Likewise, this means you can use `isinstance` checks inside a function body
and have them contribute to the same collection of consistency checks performed
Expand Down Expand Up @@ -134,7 +134,44 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn


@jaxtyped
def _check_dataclass_annotations(self, typechecker):
    """Check each initialised dataclass field of `self` against its annotation.

    For every field reported by `dataclasses.fields(self)`, the annotation is taken
    from the first class in the MRO whose `__annotations__` contains the field name.
    The field's current value is then passed to a throwaway function whose single
    parameter carries that annotation and which is decorated with `typechecker`, so
    any error the typechecker raises on a mismatch propagates to the caller.

    Fields that have not been assigned on the instance are skipped.

    NOTE(review): the `@jaxtyped` wrapper presumably makes all fields share one set
    of axis-size bindings -- confirm against `jaxtyped`'s documentation.

    Raises:
        TypeError: if no class in the MRO provides an annotation for a field.
    """
    for field in dataclasses.fields(self):
        for kls in self.__class__.__mro__:
            # Use `getattr` with a default: some classes in the MRO (notably
            # `object`) have no `__annotations__` attribute at all, and plain
            # attribute access would raise `AttributeError` before the `else`
            # clause below could report the problem.
            try:
                annotation = getattr(kls, "__annotations__", {})[field.name]
            except KeyError:
                pass
            else:
                break
        else:
            raise TypeError(
                f"Cannot find a type annotation for field `{field.name}` of "
                f"`{self.__class__.__name__}`."
            )
        try:
            value = getattr(self, field.name)
        except AttributeError:
            continue  # allow uninitialised fields, which are allowed on dataclasses

        @typechecker
        def typecheck(x: annotation):
            pass

        typecheck(value)


def _jaxtyped_typechecker(typechecker):
"""A decorator added by the import hook to all classes. Only affects dataclasses.
Will be called as
```
@_jaxtyped_typechecker(beartype.beartype)
@dataclasses.dataclass
class SomeDataclass:
...
```
After initialisation, this will check that all fields of the dataclass match their
specified type annotation.
"""
# typechecker is expected to probably be either `typeguard.typechecked`, or
# `beartype.beartype`, or `None`.

Expand All @@ -144,8 +181,32 @@ def _jaxtyped_typechecker(typechecker):
def _wrapper(kls):
assert inspect.isclass(kls)
if dataclasses.is_dataclass(kls):
init = jaxtyped(typechecker(kls.__init__))
kls.__init__ = init
# This does not check that the arguments passed to `__init__` match the
# type annotations. There may be a custom user `__init__`, or a
# dataclass-generated `__init__` used alongside
# `equinox.field(converter=...)`

init = kls.__init__

@ft.wraps(init)
def __init__(self, *args, **kwargs):
init(self, *args, **kwargs)
# `kls.__init__` is late-binding to the `__init__` function that we're
# in now. (Or to someone else's monkey-patch.) Either way, this checks
# that we're in the "top-level" `__init__`, and not one that is being
# called via `super()`. We don't want to trigger too early, before all
# fields have been assigned.
#
# We're not checking `if self.__class__ is kls` because Equinox replaces
# the class with a defrozen version of itself during `__init__`, so the check
# wouldn't trigger.
#
# We're not doing this check by adding it to the end of the metaclass
# `__call__`, because Python doesn't allow you to monkey-patch metaclasses.
if self.__class__.__init__ is kls.__init__:
_check_dataclass_annotations(self, typechecker)

kls.__init__ = __init__
return kls

return _wrapper
127 changes: 115 additions & 12 deletions test/import_hook_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from helpers import ParamError, ReturnError

import jaxtyping
from jaxtyping import Float32
from jaxtyping import Float32, Int


#
# Test that functions get checked
#


def g(x: Float32[jnp.ndarray, " b"]):
Expand All @@ -39,21 +41,101 @@ def g(x: Float32[jnp.ndarray, " b"]):
with pytest.raises(ParamError):
g(jnp.array(1))


#
# Test that Equinox modules get checked
#


# Dataclass `__init__`, no converter
class Mod1(eqx.Module):
    """Fixture: dataclass-generated `__init__`, no field converter.

    The statements following this class expect a `ParamError` when `foo` is not an
    `int` or when `bar` is not a rank-1 float32 array.
    """

    foo: int
    bar: Float32[jnp.ndarray, " a"]


Mod1(1, jnp.array([1.0]))
with pytest.raises(ParamError):
Mod1(1.0, jnp.array([1.0]))
with pytest.raises(ParamError):
Mod1(1, jnp.array(1.0))


# Dataclass `__init__`, converter
class Mod2(eqx.Module):
    """Fixture: dataclass-generated `__init__` with a `jnp.asarray` converter.

    Constructing this with a plain Python number only passes if the field check
    happens after the converter has run.
    """

    a: jnp.ndarray = eqx.field(converter=jnp.asarray)


Mod2(1) # This will fail unless we run typechecking after conversion


class BadMod2(eqx.Module):
    """Fixture: the converter is the identity, so non-array inputs stay non-arrays.

    The statements following this class expect a `ParamError` for such inputs.
    """

    a: jnp.ndarray = eqx.field(converter=lambda x: x)


with pytest.raises(ParamError):
BadMod2(1)
with pytest.raises(ParamError):
BadMod2("asdf")


# Custom `__init__`, no converter
class Mod3(eqx.Module):
    """Fixture: hand-written `__init__` and no field converter.

    The `__init__` parameter `foo` is annotated `str` while the field `foo` is
    annotated `int`; the constructor performs the conversion itself.
    """

    foo: int
    bar: Float32[jnp.ndarray, " a"]

    def __init__(self, foo: str, bar: Float32[jnp.ndarray, " a"]):
        # The two assignments are independent of each other.
        self.bar = bar
        self.foo = int(foo)


Mod3("1", jnp.array([1.0]))
with pytest.raises(ParamError):
Mod3(1, jnp.array([1.0]))
with pytest.raises(ParamError):
Mod3("1", jnp.array(1.0))


# Custom `__init__`, converter
class Mod4(eqx.Module):
    """Fixture: hand-written `__init__` combined with a `jnp.asarray` converter.

    `__init__` stores a plain `int`, which does not match the field annotation
    until the converter has been applied.
    """

    a: Int[jnp.ndarray, ""] = eqx.field(converter=jnp.asarray)

    def __init__(self, a: str):
        # Stored as a Python int here; the field's converter is what turns it into
        # an array.
        self.a = int(a)


class M(eqx.Module):
Mod4("1") # This will fail unless we run typechecking after conversion


# Custom `__post_init__`, no converter
class Mod5(eqx.Module):
foo: int
bar: Float32[jnp.ndarray, " a"]

def __post_init__(self):
pass


M(1, jnp.array([1.0]))
Mod5(1, jnp.array([1.0]))
with pytest.raises(ParamError):
M(1.0, jnp.array([1.0]))
Mod5(1.0, jnp.array([1.0]))
with pytest.raises(ParamError):
M(1, jnp.array(1.0))
Mod5(1, jnp.array(1.0))


# Custom `__post_init__`, converter
class Mod6(eqx.Module):
    """Fixture: `jnp.asarray` converter together with a no-op `__post_init__`."""

    a: jnp.ndarray = eqx.field(converter=jnp.asarray)

    def __post_init__(self):
        # Intentionally empty: only the presence of `__post_init__` matters here.
        pass


Mod6(1) # This will fail unless we run typechecking after conversion


#
# Test that dataclasses get checked
#


@dataclasses.dataclass
Expand All @@ -68,7 +150,10 @@ class D:
with pytest.raises(ParamError):
D(1, jnp.array(1.0))


#
# Test that methods get checked
#


class N(eqx.Module):
Expand All @@ -93,17 +178,35 @@ def bar(self) -> jnp.ndarray:
with pytest.raises(ReturnError):
bad_n.bar()

# Test that converters work

#
# Test that we don't get called in `super()`.
#

class BadConverter(eqx.Module):
a: jnp.ndarray = eqx.field(converter=lambda x: x)

called = False


class Base(eqx.Module):
    """Fixture base class whose `__init__` leaves `x` holding an invalid value.

    `x` is annotated `int` but is assigned a string here, so a field check run
    straight after this `__init__` would fail. The module-level flag `called`
    records that this `__init__` has run (and that it runs only once).
    """

    x: int

    def __init__(self):
        global called
        self.x = "not an int"
        assert not called
        called = True


class Derived(Base):
    """Fixture subclass that repairs `x` after delegating to `Base.__init__`.

    Constructing this only succeeds if the field check is not triggered by the
    `super().__init__()` call, i.e. if it runs once the outermost `__init__` has
    finished.
    """

    def __init__(self):
        # `Base.__init__` must not have run yet...
        assert not called
        super().__init__()
        # ...and must have run exactly now.
        assert called
        # Replace the invalid value written by `Base.__init__` with a valid `int`.
        self.x = 2


Derived()

with pytest.raises(ParamError):
BadConverter(1)
with pytest.raises(ParamError):
BadConverter("asdf")

# Record that we've finished our checks successfully

Expand Down

0 comments on commit 69e7272

Please sign in to comment.