From d9a859d4377e023eb5fcb1390cd7cfa7d7a31233 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 22 Sep 2023 08:59:28 -0700 Subject: [PATCH] Fixed spurious error when accessing methods during __init__. Drive-by: improved pretty-printing of dataclasses with unintialised fields. --- equinox/_module.py | 100 +++++++++++++++++++++------------------ equinox/_pretty_print.py | 4 +- tests/test_module.py | 16 ++++++- 3 files changed, 72 insertions(+), 48 deletions(-) diff --git a/equinox/_module.py b/equinox/_module.py index e6b1e88f..a9e69085 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -293,9 +293,47 @@ def __init__(self, method): def __get__(self, instance, owner): if instance is None: return self.method - elif isinstance(instance, _Initable): - raise ValueError( - """Cannot assign methods in __init__. + else: + # Why `inplace=True`? + # This is safe because the `BoundMethod` was only instantiated here. + # This is necessary so that `_method.__self__ is instance`, which is used + # as part of a no-cycle check in `_make_initable`. + _method = _module_update_wrapper( + BoundMethod(self.method, instance), None, inplace=True + ) + return _method + + +_dummy_abstract = abc.abstractmethod(lambda self: 1) +_has_dataclass_init = weakref.WeakKeyDictionary() + + +_wrapper_field_names = { + "__module__", + "__name__", + "__qualname__", + "__doc__", + "__annotations__", +} + + +@ft.lru_cache(maxsize=128) +def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta: + if wraps: + field_names = _wrapper_field_names + else: + field_names = { + field.name for field in dataclasses.fields(cls) # pyright: ignore + } + + class _InitableModule(cls): # pyright: ignore + pass + + def __setattr__(self, name, value): + if name in field_names: + if isinstance(value, BoundMethod) and value.__self__ is self: + raise ValueError( + """Cannot assign methods in __init__. That is, something like the following is not allowed: ``` @@ -326,49 +364,14 @@ def bar(self): This is a check that was introduced in Equinox v0.11.0. Before this, the above error went uncaught, possibly leading to silently wrong behaviour. """ - ) + ) + else: + object.__setattr__(self, name, value) else: - _method = module_update_wrapper(BoundMethod(self.method, instance)) - return _method - - -_dummy_abstract = abc.abstractmethod(lambda self: 1) -_has_dataclass_init = weakref.WeakKeyDictionary() - - -_wrapper_field_names = { - "__module__", - "__name__", - "__qualname__", - "__doc__", - "__annotations__", -} - - -class _Initable: - pass - - -@ft.lru_cache(maxsize=128) -def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta: - if wraps: - field_names = _wrapper_field_names - else: - field_names = { - field.name for field in dataclasses.fields(cls) # pyright: ignore - } - - class _InitableModule(cls, _Initable): # pyright: ignore - pass + raise AttributeError(f"Cannot set attribute {name}") # Done like this to avoid dataclasses complaining about overriding setattr on a # frozen class. - def __setattr__(self, name, value): - if name in field_names: - object.__setattr__(self, name, value) - else: - raise AttributeError(f"Cannot set attribute {name}") - _InitableModule.__setattr__ = __setattr__ # Make beartype happy _InitableModule.__init__ = cls.__init__ # pyright: ignore @@ -637,6 +640,12 @@ def make_wrapper(fn): A copy of `wrapper`, with the attributes `__module__`, `__name__`, `__qualname__`, `__doc__`, and `__annotations__` copied over from the wrapped function. """ + return _module_update_wrapper(wrapper, wrapped, inplace=False) + + +def _module_update_wrapper( + wrapper: Module, wrapped: Optional[Callable[_P, _T]], inplace: bool +) -> Callable[_P, _T]: cls = wrapper.__class__ if not isinstance(getattr(cls, "__wrapped__", None), property): raise ValueError("Wrapper module must supply `__wrapped__` as a property.") @@ -644,9 +653,10 @@ def make_wrapper(fn): if wrapped is None: wrapped = wrapper.__wrapped__ # pyright: ignore - # Make a clone, to avoid mutating the original input. - leaves, treedef = jtu.tree_flatten(wrapper) - wrapper = jtu.tree_unflatten(treedef, leaves) + if not inplace: + # Make a clone, to avoid mutating the original input. + leaves, treedef = jtu.tree_flatten(wrapper) + wrapper = jtu.tree_unflatten(treedef, leaves) initable_cls = _make_initable(cls, wraps=True) object.__setattr__(wrapper, "__class__", initable_cls) diff --git a/equinox/_pretty_print.py b/equinox/_pretty_print.py index 76841c81..d922f921 100644 --- a/equinox/_pretty_print.py +++ b/equinox/_pretty_print.py @@ -10,6 +10,8 @@ import numpy as np from jaxtyping import PyTree +from ._doc_utils import WithRepr + Dataclass = Any NamedTuple = Any # workaround typeguard bug @@ -140,7 +142,7 @@ def _pformat_dataclass(obj, **kwargs) -> pp.Doc: # values to the module so the repr fails. objs = named_objs( [ - (field.name, getattr(obj, field.name, "")) + (field.name, getattr(obj, field.name, WithRepr(None, ""))) for field in dataclasses.fields(obj) if field.repr ], diff --git a/tests/test_module.py b/tests/test_module.py index cb983809..4d31282f 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -150,8 +150,9 @@ class A(eqx.Module): x: jax.Array def __init__(self, x): - # foo is assigned before x! Therefore - # self.bar.instance.x` doesn't exist yet. + # foo is assigned before x! We have that `self.bar.__self__` is a copy of + # `self`, but for which self.bar.__self__.x` doesn't exist yet. Then later + # calling `self.foo()` would raise an error. self.foo = self.bar self.x = x @@ -190,6 +191,17 @@ def _transform(self, x): SubComponent() +def test_method_access_during_init(): + class Foo(eqx.Module): + def __init__(self): + self.method() + + def method(self): + pass + + Foo() + + @pytest.mark.parametrize("new", (False, True)) def test_static_field(new): if new: