Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed spurious error when accessing methods during __init__. #508

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 55 additions & 45 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,47 @@ def __init__(self, method):
def __get__(self, instance, owner):
if instance is None:
return self.method
elif isinstance(instance, _Initable):
raise ValueError(
"""Cannot assign methods in __init__.
else:
# Why `inplace=True`?
# This is safe because the `BoundMethod` was only instantiated here.
# This is necessary so that `_method.__self__ is instance`, which is used
# as part of a no-cycle check in `_make_initable`.
_method = _module_update_wrapper(
BoundMethod(self.method, instance), None, inplace=True
)
return _method


_dummy_abstract = abc.abstractmethod(lambda self: 1)
_has_dataclass_init = weakref.WeakKeyDictionary()


_wrapper_field_names = {
"__module__",
"__name__",
"__qualname__",
"__doc__",
"__annotations__",
}


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
if wraps:
field_names = _wrapper_field_names
else:
field_names = {
field.name for field in dataclasses.fields(cls) # pyright: ignore
}

class _InitableModule(cls): # pyright: ignore
pass

def __setattr__(self, name, value):
if name in field_names:
if isinstance(value, BoundMethod) and value.__self__ is self:
raise ValueError(
"""Cannot assign methods in __init__.

That is, something like the following is not allowed:
```
Expand Down Expand Up @@ -326,49 +364,14 @@ def bar(self):
This is a check that was introduced in Equinox v0.11.0. Before this, the above error
went uncaught, possibly leading to silently wrong behaviour.
"""
)
)
else:
object.__setattr__(self, name, value)
else:
_method = module_update_wrapper(BoundMethod(self.method, instance))
return _method


_dummy_abstract = abc.abstractmethod(lambda self: 1)
_has_dataclass_init = weakref.WeakKeyDictionary()


_wrapper_field_names = {
"__module__",
"__name__",
"__qualname__",
"__doc__",
"__annotations__",
}


class _Initable:
pass


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
if wraps:
field_names = _wrapper_field_names
else:
field_names = {
field.name for field in dataclasses.fields(cls) # pyright: ignore
}

class _InitableModule(cls, _Initable): # pyright: ignore
pass
raise AttributeError(f"Cannot set attribute {name}")

# Done like this to avoid dataclasses complaining about overriding setattr on a
# frozen class.
def __setattr__(self, name, value):
if name in field_names:
object.__setattr__(self, name, value)
else:
raise AttributeError(f"Cannot set attribute {name}")

_InitableModule.__setattr__ = __setattr__
# Make beartype happy
_InitableModule.__init__ = cls.__init__ # pyright: ignore
Expand Down Expand Up @@ -637,16 +640,23 @@ def make_wrapper(fn):
A copy of `wrapper`, with the attributes `__module__`, `__name__`, `__qualname__`,
`__doc__`, and `__annotations__` copied over from the wrapped function.
"""
return _module_update_wrapper(wrapper, wrapped, inplace=False)


def _module_update_wrapper(
wrapper: Module, wrapped: Optional[Callable[_P, _T]], inplace: bool
) -> Callable[_P, _T]:
cls = wrapper.__class__
if not isinstance(getattr(cls, "__wrapped__", None), property):
raise ValueError("Wrapper module must supply `__wrapped__` as a property.")

if wrapped is None:
wrapped = wrapper.__wrapped__ # pyright: ignore

# Make a clone, to avoid mutating the original input.
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)
if not inplace:
# Make a clone, to avoid mutating the original input.
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)

initable_cls = _make_initable(cls, wraps=True)
object.__setattr__(wrapper, "__class__", initable_cls)
Expand Down
4 changes: 3 additions & 1 deletion equinox/_pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy as np
from jaxtyping import PyTree

from ._doc_utils import WithRepr


Dataclass = Any
NamedTuple = Any # workaround typeguard bug
Expand Down Expand Up @@ -140,7 +142,7 @@ def _pformat_dataclass(obj, **kwargs) -> pp.Doc:
# values to the module so the repr fails.
objs = named_objs(
[
(field.name, getattr(obj, field.name, "<uninitialised>"))
(field.name, getattr(obj, field.name, WithRepr(None, "<uninitialised>")))
for field in dataclasses.fields(obj)
if field.repr
],
Expand Down
16 changes: 14 additions & 2 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ class A(eqx.Module):
x: jax.Array

def __init__(self, x):
# foo is assigned before x! Therefore
# self.bar.instance.x` doesn't exist yet.
# foo is assigned before x! We have that `self.bar.__self__` is a copy of
# `self`, but for which self.bar.__self__.x` doesn't exist yet. Then later
# calling `self.foo()` would raise an error.
self.foo = self.bar
self.x = x

Expand Down Expand Up @@ -190,6 +191,17 @@ def _transform(self, x):
SubComponent()


def test_method_access_during_init():
class Foo(eqx.Module):
def __init__(self):
self.method()

def method(self):
pass

Foo()


@pytest.mark.parametrize("new", (False, True))
def test_static_field(new):
if new:
Expand Down