From f0eea2e2abf8f0d722da86dee9052305b70dd9bd Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:34:58 -0700 Subject: [PATCH] Moved conversion to Foo.__init__ from MetaFoo.__call__. This is necessary to allow downstream libraries, like jaxtyping, to monkey-patch in their own checks. --- equinox/_module.py | 244 ++++++++++++++++++++++++++++--------------- tests/test_module.py | 84 ++++++++++++++- 2 files changed, 238 insertions(+), 90 deletions(-) diff --git a/equinox/_module.py b/equinox/_module.py index d4e1b822..29b53520 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -155,81 +155,164 @@ def __new__( for k, v in cls.__dict__.items(): if _not_magic(k) and inspect.isfunction(v): setattr(cls, k, _wrap_method(v)) - # [Step 3] Create a default `__init__` method if a user method isn't provided. + + # [Step 3] Handle initialisation. # - # If a superclass has a custom `__init__`, then don't create a default - # `__init__` here. (Otherwise e.g. if `B` has a custom init then - # `class A(B): pass` would set a dataclass init on `A`.) - # If a superclass has a default `__init__`, then do create a new default one - # here. (Dataclass default `__init__`s don't call `super()`, so they must be - # overriden directly.) + # For context, with any Python dataclass, there are three possible scenarios for + # initialisation: + # (a) a user-provided `__init__` method is supplied; + # (b) dataclasses creates `__init__` , without a user-provided `__post_init__` + # (c) dataclasses creates `__init__` , with a user-provided `__post_init__` + + # Create a dataclass `__init__` method if a user method isn't provided. + # If a user passed one on this class, then we definitely have a custom __init__. + # Else just use whatever our superclass does. Note that this is different to + # default dataclass behaviour. Given + # ``` + # @dataclass + # class Foo: def __init__(...): ... + # @dataclass + # class Bar(Foo): pass + # ``` + # then `Bar` will end up with a dataclass-provided `__init__`. That ends up + # being ergonomically very annoying, so we disable it. added_custom_init = "__init__" in cls.__dict__ if added_custom_init: - _dataclass_init = False + has_dataclass_init = False else: - for kls in cls.__mro__: + for kls in cls.__mro__[1:-1]: try: - _dataclass_init = _has_dataclass_init[kls] + has_dataclass_init = _has_dataclass_init[kls] except KeyError: # Non-Module superclasses. - pass + if kls.__init__ is not object.__init__: + has_dataclass_init = False + break else: break else: assert name == "Module" - _dataclass_init = True # eqx.Module itself - if _dataclass_init: # Using a dataclass-generated `__init__`. + has_dataclass_init = True # eqx.Module itself + + # Check for a common error. (Check for `_Initable` to avoid duplicate warnings.) + if ( + not has_dataclass_init + and hasattr(cls, "__post_init__") + and not issubclass(cls, _Initable) + ): + warnings.warn( + f"Class `{cls.__module__}.{cls.__qualname__}` has both an " + "`__init__` method and a `__post_init__` method. This means that " + "the `__post_init__` method will not be run!\n" + "The reason for this is that `__post_init__` is intended to be " + "used with the automatically-generated `__init__` method provided " + "by Python dataclasses, which are generated of the form:\n" + "```\n" + "def __init__(self, field1, field2)\n" + " self.field1 = field1\n" + " self.field2 = field2\n" + " self.__post_init__()\n" + "```\n" + "and as such a user-provided `__init__` overrides both the setting " + "of fields, and the calling of `__post_init__`.\n" + "The above is how Python dataclasses work, and has nothing " + "to do with Equinox!\n" + "If you are using `__post_init__` to check that certain invariants " + "hold, then consider using `__check_init__` instead. This is an " + "Equinox-specific extension that is always ran. See here for more " + "details: " + "https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501 + ) + + # Add support for `eqx.field(converter=...)` when using `__post_init__`. + # (Scenario (c) above. Scenarios (a) and (b) are handled later.) + if has_dataclass_init and hasattr(cls, "__post_init__"): + post_init = cls.__post_init__ + + @ft.wraps(post_init) # pyright: ignore + def __post_init__(self, *args, **kwargs): + # This `if` is to handle `super()` correctly. + # We want to only convert once, at the top level. + # + # This check is basically testing whether or not the function we're in + # now (`cls.__post_init__`) is at the top level + # (`self.__class__.__post_init__`). If we are, do conversion. If we're + # not, it's presumably because someone is calling us via `super()` in + # the middle of their own `__post_init__`. No conversion then; their own + # version of this wrapper will do it at the appropriate time instead. + # + # One small foible: we write `cls.__post_init__`, rather than just + # `__post_init__`, to refer to this function. This allows someone else + # to also monkey-patch `cls.__post_init__` if they wish, and this won't + # remove conversion. (Conversion is a at-the-top-level thing, not a + # this-particular-function thing.) + # + # This top-level business means that this is very nearly the same as + # doing conversion in `_ModuleMeta.__call__`. The differences are that + # (a) that wouldn't allow us to convert fields before the user-provided + # `__post_init__`, and (b) it allows other libraries (i.e. jaxtyping) + # to later monkey-patch `__init__`, and we have our converter run before + # their own monkey-patched-in code. + if self.__class__.__post_init__ is cls.__post_init__: + # Convert all fields currently available. + _convert_fields(self, init=True) + post_init(self, *args, **kwargs) # pyright: ignore + if self.__class__.__post_init__ is cls.__post_init__: + # Convert all the fields filled in by `__post_init__` as well. + _convert_fields(self, init=False) + + cls.__post_init__ = __post_init__ # pyright: ignore + else: + post_init = None + + # Fairly common to write `Superclass.__init__.__doc__ = "..."` with + # dataclass-provided inits; here we look through the class hierarchy and will + # copy this doc forward. + if has_dataclass_init: init_doc = cls.__init__.__doc__ - if not _dataclass_init: # Using a user-provided `__init__`. - # Check `_Initable` to avoid printing out duplicate warnings. - if getattr(cls, "__post_init__", None) is not None and not issubclass( - cls, _Initable - ): - warnings.warn( - f"Class `{cls.__module__}.{cls.__qualname__}` has both an " - "`__init__` method and a `__post_init__` method. This means that " - "the `__post_init__` method will not be run!\n" - "The reason for this is that `__post_init__` is intended to be " - "used with the automatically-generated `__init__` method provided " - "by Python dataclasses, which are generated of the form:\n" - "```\n" - "def __init__(self, field1, field2)\n" - " self.field1 = field1\n" - " self.field2 = field2\n" - " self.__post_init__()\n" - "```\n" - "and as such a user-provided `__init__` overrides both the setting " - "of fields, and the calling of `__post_init__`.\n" - "The above is how Python dataclasses work, and has nothing " - "to do with Equinox!\n" - "If you are using `__post_init__` to check that certain invariants " - "hold, then consider using `__check_init__` instead. This is an " - "Equinox-specific extension that is always ran. See here for more " - "details: " - "https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501 - ) + # [Step 4] Register as a dataclass. - cls = dataclass(eq=False, repr=False, frozen=True, init=_dataclass_init)( + cls = dataclass(eq=False, repr=False, frozen=True, init=has_dataclass_init)( cls # pyright: ignore ) - # [Step 3b] -- finish off the business of default `__init__` methods. - # (This part has to happen after dataclass registration.) - _has_dataclass_init[cls] = _dataclass_init - if _dataclass_init: - # Assign `__doc__` in case its been manually overriden: - # ``` - # class Foo(eqx.Module): - # x: int - # - # Foo.__init__.__doc__ = "Foo should be called with with an integer `x`." - # - # class Bar(Foo): - # pass - # ``` - # E.g. `Bar.__init__.__doc__` may be used during documentation generation. + # [Step 3b] -- finish off building `__init__` methods. Until we'd done + # dataclass'ification then we didn't necessarily have our `__init__` method. + + # Registering here records that the `dataclass(...)` call has happened. + _has_dataclass_init[cls] = has_dataclass_init + + # Now handle conversion for cases (a) and (b) above, in which there is no + # `__post_init__`. + if post_init is None: + init = cls.__init__ + + @ft.wraps(init) # pyright: ignore + def __init__(self, *args, **kwargs): + init(self, *args, **kwargs) + # Same `if` trick as with `__post_init__`. + if self.__class__.__init__ is cls.__init__: + _convert_fields(self, init=True) + _convert_fields(self, init=False) + + cls.__init__ = __init__ + + # Assign `__doc__` in case it has been manually overriden: + # ``` + # class Foo(eqx.Module): + # x: int + # + # Foo.__init__.__doc__ = "Foo should be called with with an integer `x`." + # + # class Bar(Foo): + # pass + # + # # Now we try to access `Bar.__init__.__doc__`. (E.g. during docgen.) + # ``` + if has_dataclass_init: cls.__init__.__doc__ = init_doc # pyright: ignore # TODO: is this next line still necessary? cls.__init__.__module__ = cls.__module__ + # [Step 5] We support an optional `strict` mode for Rust-like strictness in the # type checking. # In practice this is probably too much for your average user, but it's a great @@ -347,17 +430,15 @@ def __new__( # This method is called whenever you initialise a module: `MyModule(...)` def __call__(cls, *args, **kwargs): if _is_force_abstract[cls]: + # Any other is-abstract checks will be handled in super().__call__. raise TypeError("Cannot instantiate abstract `equinox.Module`.") # [Step 1] Modules are immutable -- except during construction. So defreeze # before init. - if _has_dataclass_init[cls]: - post_init = getattr(cls, "__post_init__", None) - else: - post_init = None - initable_cls = _make_initable(cls, post_init, wraps=False) - # [Step 2] Instantiate the class as normal. (`__init__` and `__post_init__`) - # and then re-freeze. + post_init = getattr(cls, "__post_init__", None) + initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False) + # [Step 2] Instantiate the class as normal. self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs) + assert not _is_abstract(cls) # [Step 3] Check that all fields are occupied. missing_names = { field.name @@ -371,14 +452,10 @@ def __call__(cls, *args, **kwargs): f"The following fields were not initialised during __init__: " f"{missing_names}" ) - # [Step 4] Run any custom converters. - if post_init is None: - # `if post_init is not None` then conversion happened in the override of - # `__post_init__` in `initable_cls`. - _convert_fields(self, init=True) - _convert_fields(self, init=False) + # Freeze. object.__setattr__(self, "__class__", cls) - # [Step 5] Run any custom validators. + # [Step 4] Run any custom validators. (After freezing; as they run + # unconditionally across the whole MRO, they aren't allowed to mutate.) for kls in cls.__mro__: try: check = kls.__dict__["__check_init__"] @@ -478,13 +555,14 @@ class _Initable: pass -# We pass `post_init` as an argument, rather than just looking up `cls.__post_init__`, -# in case someone instantiates a Module, then monkey-patches `__post_init__`, then -# instantiates another Module. -# Which is a crazy thing to do, but never let it be said that Equinox doesn't handle -# your edge cases. @ft.lru_cache(maxsize=128) -def _make_initable(cls: _ModuleMeta, post_init, wraps: bool) -> _ModuleMeta: +def _make_initable(cls: _ModuleMeta, init, post_init, wraps: bool) -> _ModuleMeta: + # Used as part of the key. Don't cache if these have changed. + # In practice, monkey-patching these on the class -- after you've already + # instantiated it somewhere! -- is an *ahem*, adventurous, thing to do. But never + # let it be said that Equinox doesn't support you in your questionable life choices! + del init, post_init + if wraps: field_names = _wrapper_field_names else: @@ -493,15 +571,7 @@ def _make_initable(cls: _ModuleMeta, post_init, wraps: bool) -> _ModuleMeta: } class _InitableModule(cls, _Initable): # pyright: ignore - if post_init is not None: - - @ft.wraps(post_init) - def __post_init__(self, *args, **kwargs): - # Do conversion before `__post_init__`, as a user convenience. - _convert_fields(self, init=True) - post_init(self, *args, **kwargs) - # Now convert all the fields filled in by `__post_init__` as well. - _convert_fields(self, init=False) + pass def __setattr__(self, name, value): if name in field_names: @@ -856,7 +926,7 @@ def _module_update_wrapper( leaves, treedef = jtu.tree_flatten(wrapper) wrapper = jtu.tree_unflatten(treedef, leaves) - initable_cls = _make_initable(cls, None, wraps=True) + initable_cls = _make_initable(cls, None, None, wraps=True) object.__setattr__(wrapper, "__class__", initable_cls) try: # Like `ft.update_wrapper(wrapper, wrapped, updated=())`. diff --git a/tests/test_module.py b/tests/test_module.py index 2c181fde..f9c8311b 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -255,8 +255,26 @@ def __post_init__(self): assert called -# Please don't actually do this in real code. -# +def test_converter_monkeypatched_init(): + class Foo(eqx.Module): + field: jax.Array = eqx.field(converter=jnp.asarray) + + assert shaped_allclose(Foo(1.0).field, jnp.array(1.0)) # pyright: ignore + + called = False + init = Foo.__init__ + + def __init__(self, *args, **kwargs): + nonlocal called + assert not called + called = True + init(self, *args, **kwargs) + + Foo.__init__ = __init__ + assert shaped_allclose(Foo(1.0).field, jnp.array(1.0)) # pyright: ignore + assert called + + # Note that `Foo` had to start with a `__post_init__` method for this to work. # Dataclasses check for the presence of a `__post_init__` method when the class is # created, and at that time creates a flag declaring whether to run `__post_init__` at @@ -277,16 +295,76 @@ def __post_init__(self): assert called1 called2 = False + post_init = Foo.__post_init__ def __post_init__(self): + nonlocal called1 nonlocal called2 assert not called2 + called1 = False called2 = True - assert shaped_allclose(self.field, jnp.array(1.0)) + post_init(self) Foo.__post_init__ = __post_init__ # pyright: ignore assert shaped_allclose(Foo(1.0).field, jnp.array(1.0)) # pyright: ignore assert called2 + assert called1 + + +@pytest.mark.parametrize("base_is_module", (False, True)) +def test_converter_init_hierarchy(base_is_module): + class A(eqx.Module if base_is_module else object): # pyright: ignore + def __init__(self, x): + nonlocal called + assert not called + called = True + self.x = x + + class B(eqx.Module): + x: jax.Array = eqx.field(converter=jnp.asarray) + + class C(A, B): + # Use `A.__init__` + pass + + class D(B, A): + # Use the autogenerated `B.__init__` + pass + + # In either case, conversion should happen. + + called = False + assert shaped_allclose(C(1).x, jnp.array(1)) + assert called + + assert shaped_allclose(D(1).x, jnp.array(1)) # pyright: ignore + # No `called` check, we're not using `A.__init__`. + + +@pytest.mark.parametrize("base_is_module", (False, True)) +def test_converter_post_init_hierarchy(base_is_module): + class A(eqx.Module if base_is_module else object): # pyright: ignore + def __post_init__(self): + nonlocal called + assert not called + called = True + + class B(eqx.Module): + x: jax.Array = eqx.field(converter=jnp.asarray) + + class C(A, B): + pass + + class D(B, A): + pass + + called = False + assert shaped_allclose(C(1).x, jnp.array(1)) # pyright: ignore + assert called + + called = False + assert shaped_allclose(D(1).x, jnp.array(1)) # pyright: ignore + assert called def test_init_and_postinit():