Skip to content

Commit

Permalink
Fixed spurious error when accessing methods during __init__.
Browse files Browse the repository at this point in the history
Drive-by: improved pretty-printing of dataclasses with unintialised fields.
  • Loading branch information
patrick-kidger committed Sep 22, 2023
1 parent 9cc2760 commit e210188
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 48 deletions.
100 changes: 55 additions & 45 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,47 @@ def __init__(self, method):
def __get__(self, instance, owner):
if instance is None:
return self.method
elif isinstance(instance, _Initable):
raise ValueError(
"""Cannot assign methods in __init__.
else:
# Why `inplace=True`?
# This is safe because the `BoundMethod` was only instantiated here.
# This is necessary so that `_method.__self__ is instance`, which is used
# as part of a no-cycle check in `_make_initable`.
_method = _module_update_wrapper(
BoundMethod(self.method, instance), None, inplace=True
)
return _method


_dummy_abstract = abc.abstractmethod(lambda self: 1)
_has_dataclass_init = weakref.WeakKeyDictionary()


_wrapper_field_names = {
"__module__",
"__name__",
"__qualname__",
"__doc__",
"__annotations__",
}


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
if wraps:
field_names = _wrapper_field_names
else:
field_names = {
field.name for field in dataclasses.fields(cls) # pyright: ignore
}

class _InitableModule(cls): # pyright: ignore
pass

def __setattr__(self, name, value):
if name in field_names:
if isinstance(value, BoundMethod) and value.__self__ is self:
raise ValueError(
"""Cannot assign methods in __init__.
That is, something like the following is not allowed:
```
Expand Down Expand Up @@ -326,49 +364,14 @@ def bar(self):
This is a check that was introduced in Equinox v0.11.0. Before this, the above error
went uncaught, possibly leading to silently wrong behaviour.
"""
)
)
else:
object.__setattr__(self, name, value)
else:
_method = module_update_wrapper(BoundMethod(self.method, instance))
return _method


_dummy_abstract = abc.abstractmethod(lambda self: 1)
_has_dataclass_init = weakref.WeakKeyDictionary()


_wrapper_field_names = {
"__module__",
"__name__",
"__qualname__",
"__doc__",
"__annotations__",
}


class _Initable:
pass


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
if wraps:
field_names = _wrapper_field_names
else:
field_names = {
field.name for field in dataclasses.fields(cls) # pyright: ignore
}

class _InitableModule(cls, _Initable): # pyright: ignore
pass
raise AttributeError(f"Cannot set attribute {name}")

# Done like this to avoid dataclasses complaining about overriding setattr on a
# frozen class.
def __setattr__(self, name, value):
if name in field_names:
object.__setattr__(self, name, value)
else:
raise AttributeError(f"Cannot set attribute {name}")

_InitableModule.__setattr__ = __setattr__
# Make beartype happy
_InitableModule.__init__ = cls.__init__ # pyright: ignore
Expand Down Expand Up @@ -637,16 +640,23 @@ def make_wrapper(fn):
A copy of `wrapper`, with the attributes `__module__`, `__name__`, `__qualname__`,
`__doc__`, and `__annotations__` copied over from the wrapped function.
"""
return _module_update_wrapper(wrapper, wrapped, inplace=False)


def _module_update_wrapper(
wrapper: Module, wrapped: Optional[Callable[_P, _T]], inplace: bool
) -> Callable[_P, _T]:
cls = wrapper.__class__
if not isinstance(getattr(cls, "__wrapped__", None), property):
raise ValueError("Wrapper module must supply `__wrapped__` as a property.")

if wrapped is None:
wrapped = wrapper.__wrapped__ # pyright: ignore

# Make a clone, to avoid mutating the original input.
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)
if not inplace:
# Make a clone, to avoid mutating the original input.
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)

initable_cls = _make_initable(cls, wraps=True)
object.__setattr__(wrapper, "__class__", initable_cls)
Expand Down
4 changes: 3 additions & 1 deletion equinox/_pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy as np
from jaxtyping import PyTree

from ._doc_utils import WithRepr


Dataclass = Any
NamedTuple = Any # workaround typeguard bug
Expand Down Expand Up @@ -140,7 +142,7 @@ def _pformat_dataclass(obj, **kwargs) -> pp.Doc:
# values to the module so the repr fails.
objs = named_objs(
[
(field.name, getattr(obj, field.name, "<uninitialised>"))
(field.name, getattr(obj, field.name, WithRepr(None, "<uninitialised>")))
for field in dataclasses.fields(obj)
if field.repr
],
Expand Down
16 changes: 14 additions & 2 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ class A(eqx.Module):
x: jax.Array

def __init__(self, x):
# foo is assigned before x! Therefore
# self.bar.instance.x` doesn't exist yet.
# foo is assigned before x! We have that `self.bar.__self__` is a copy of
# `self`, but for which self.bar.__self__.x` doesn't exist yet. Then later
# calling `self.foo()` would raise an error.
self.foo = self.bar
self.x = x

Expand Down Expand Up @@ -190,6 +191,17 @@ def _transform(self, x):
SubComponent()


def test_method_access_during_init():
class Foo(eqx.Module):
def __init__(self):
self.method()

def method(self):
pass

Foo()


@pytest.mark.parametrize("new", (False, True))
def test_static_field(new):
if new:
Expand Down

0 comments on commit e210188

Please sign in to comment.