diff --git a/docs/api/module/advanced_fields.md b/docs/api/module/advanced_fields.md index 7343aa80..2835e242 100644 --- a/docs/api/module/advanced_fields.md +++ b/docs/api/module/advanced_fields.md @@ -33,7 +33,7 @@ class Positive(eqx.Module): This method has three key differences compared to the `__post_init__` provided by dataclasses: -- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a silent bug: +- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a bug (Equinox will raise a warning if you do this): ```python class Parent(eqx.Module): diff --git a/equinox/_module.py b/equinox/_module.py index a9e69085..e5c48684 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -8,6 +8,7 @@ import functools as ft import inspect import types +import warnings import weakref from collections.abc import Callable from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union @@ -172,7 +173,37 @@ def __new__(mcs, name, bases, dict_, /, strict: bool = False, **kwargs): assert name == "Module" _init = True # eqx.Module itself if _init: + # Dataclass-generated __init__ init_doc = cls.__init__.__doc__ + if not _init: + # User-provided __init__ + # _Initable check to avoid printing out another warning on initialisation. + if getattr(cls, "__post_init__", None) is not None and not issubclass( + cls, _Initable + ): + warnings.warn( + f"Class `{cls.__module__}.{cls.__qualname__}` has both an " + "`__init__` method and a `__post_init__` method. This means that " + "the `__post_init__` method will not be run!\n" + "The reason for this is that `__post_init__` is intended to be " + "used with the automatically-generated `__init__` method provided " + "by Python dataclasses, which are generated of the form:\n" + "```\n" + "def __init__(self, field1, field2)\n" + " self.field1 = field1\n" + " self.field2 = field2\n" + " self.__post_init__()\n" + "```\n" + "and as such a user-provided `__init__` overrides both the setting " + "of fields, and the calling of `__post_init__`.\n" + "The above is purely how Python dataclasses work, and has nothing " + "to do with Equinox!\n" + "If you are using `__post_init__` to check that certain invariants " + "hold, then consider using `__check_init__` instead. This is an " + "Equinox-specific extension that is always ran. See here for more " + "details: " + "https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501 + ) # [Step 5] Register as a dataclass. cls = dataclass(eq=False, repr=False, frozen=True, init=_init)( cls # pyright: ignore @@ -317,6 +348,10 @@ def __get__(self, instance, owner): } +class _Initable: + pass + + @ft.lru_cache(maxsize=128) def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta: if wraps: @@ -326,7 +361,7 @@ def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta: field.name for field in dataclasses.fields(cls) # pyright: ignore } - class _InitableModule(cls): # pyright: ignore + class _InitableModule(cls, _Initable): # pyright: ignore pass def __setattr__(self, name, value): diff --git a/tests/test_module.py b/tests/test_module.py index 4d31282f..83f50360 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -552,3 +552,30 @@ class Abstract3(eqx.Module, strict=True): @abc.abstractmethod def foo(self): pass + + +def test_post_init_warning(): + class A(eqx.Module): + called = False + + def __post_init__(self): + type(self).called = True + + with pytest.warns( + UserWarning, match="test_module.test_post_init_warning..B" + ): + + class B(A): + def __init__(self): + pass + + with pytest.warns( + UserWarning, match="test_module.test_post_init_warning..C" + ): + + class C(B): + pass + + B() + C() + assert not A.called