Skip to content

Commit

Permalink
Using both __init__ and __post_init__ now raises a warning, as these …
Browse files Browse the repository at this point in the history
…are typically mutually exclusive
  • Loading branch information
patrick-kidger committed Sep 25, 2023
1 parent d9b018a commit a3e9531
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/api/module/advanced_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Positive(eqx.Module):

This method has three key differences compared to the `__post_init__` provided by dataclasses:

- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a silent bug:
- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a bug (Equinox will raise a warning if you do this):

```python
class Parent(eqx.Module):
Expand Down
37 changes: 36 additions & 1 deletion equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools as ft
import inspect
import types
import warnings
import weakref
from collections.abc import Callable
from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union
Expand Down Expand Up @@ -172,7 +173,37 @@ def __new__(mcs, name, bases, dict_, /, strict: bool = False, **kwargs):
assert name == "Module"
_init = True # eqx.Module itself
if _init:
# Dataclass-generated __init__
init_doc = cls.__init__.__doc__
if not _init:
# User-provided __init__
# _Initable check to avoid printing out another warning on initialisation.
if getattr(cls, "__post_init__", None) is not None and not issubclass(
cls, _Initable
):
warnings.warn(
f"Class `{cls.__module__}.{cls.__qualname__}` has both an "
"`__init__` method and a `__post_init__` method. This means that "
"the `__post_init__` method will not be run!\n"
"The reason for this is that `__post_init__` is intended to be "
"used with the automatically-generated `__init__` method provided "
"by Python dataclasses, which are generated of the form:\n"
"```\n"
"def __init__(self, field1, field2)\n"
" self.field1 = field1\n"
" self.field2 = field2\n"
" self.__post_init__()\n"
"```\n"
"and as such a user-provided `__init__` overrides both the setting "
"of fields, and the calling of `__post_init__`.\n"
"The above is purely how Python dataclasses work, and has nothing "
"to do with Equinox!\n"
"If you are using `__post_init__` to check that certain invariants "
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
)
# [Step 5] Register as a dataclass.
cls = dataclass(eq=False, repr=False, frozen=True, init=_init)(
cls # pyright: ignore
Expand Down Expand Up @@ -317,6 +348,10 @@ def __get__(self, instance, owner):
}


class _Initable:
pass


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
if wraps:
Expand All @@ -326,7 +361,7 @@ def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
field.name for field in dataclasses.fields(cls) # pyright: ignore
}

class _InitableModule(cls): # pyright: ignore
class _InitableModule(cls, _Initable): # pyright: ignore
pass

def __setattr__(self, name, value):
Expand Down
27 changes: 27 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,3 +552,30 @@ class Abstract3(eqx.Module, strict=True):
@abc.abstractmethod
def foo(self):
pass


def test_post_init_warning():
class A(eqx.Module):
called = False

def __post_init__(self):
type(self).called = True

with pytest.warns(
UserWarning, match="test_module.test_post_init_warning.<locals>.B"
):

class B(A):
def __init__(self):
pass

with pytest.warns(
UserWarning, match="test_module.test_post_init_warning.<locals>.C"
):

class C(B):
pass

B()
C()
assert not A.called

0 comments on commit a3e9531

Please sign in to comment.