Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using both __init__ and __post_init__ now raises a warning, as these are typically mutually exclusive #514

Merged
merged 1 commit into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/module/advanced_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Positive(eqx.Module):

This method has three key differences compared to the `__post_init__` provided by dataclasses:

- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a silent bug:
- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a bug (Equinox will raise a warning if you do this):

```python
class Parent(eqx.Module):
Expand Down
37 changes: 36 additions & 1 deletion equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools as ft
import inspect
import types
import warnings
import weakref
from collections.abc import Callable
from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union
Expand Down Expand Up @@ -172,7 +173,37 @@ def __new__(mcs, name, bases, dict_, /, strict: bool = False, **kwargs):
assert name == "Module"
_init = True # eqx.Module itself
if _init:
# Dataclass-generated __init__
init_doc = cls.__init__.__doc__
if not _init:
# User-provided __init__
# _Initable check to avoid printing out another warning on initialisation.
if getattr(cls, "__post_init__", None) is not None and not issubclass(
cls, _Initable
):
warnings.warn(
f"Class `{cls.__module__}.{cls.__qualname__}` has both an "
"`__init__` method and a `__post_init__` method. This means that "
"the `__post_init__` method will not be run!\n"
"The reason for this is that `__post_init__` is intended to be "
"used with the automatically-generated `__init__` method provided "
"by Python dataclasses, which are generated of the form:\n"
"```\n"
"def __init__(self, field1, field2)\n"
" self.field1 = field1\n"
" self.field2 = field2\n"
" self.__post_init__()\n"
"```\n"
"and as such a user-provided `__init__` overrides both the setting "
"of fields, and the calling of `__post_init__`.\n"
"The above is purely how Python dataclasses work, and has nothing "
"to do with Equinox!\n"
"If you are using `__post_init__` to check that certain invariants "
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
)
# [Step 5] Register as a dataclass.
cls = dataclass(eq=False, repr=False, frozen=True, init=_init)(
cls # pyright: ignore
Expand Down Expand Up @@ -317,6 +348,10 @@ def __get__(self, instance, owner):
}


class _Initable:
pass


@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
if wraps:
Expand All @@ -326,7 +361,7 @@ def _make_initable(cls: _ModuleMeta, wraps: bool) -> _ModuleMeta:
field.name for field in dataclasses.fields(cls) # pyright: ignore
}

class _InitableModule(cls): # pyright: ignore
class _InitableModule(cls, _Initable): # pyright: ignore
pass

def __setattr__(self, name, value):
Expand Down
27 changes: 27 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,3 +552,30 @@ class Abstract3(eqx.Module, strict=True):
@abc.abstractmethod
def foo(self):
pass


def test_post_init_warning():
class A(eqx.Module):
called = False

def __post_init__(self):
type(self).called = True

with pytest.warns(
UserWarning, match="test_module.test_post_init_warning.<locals>.B"
):

class B(A):
def __init__(self):
pass

with pytest.warns(
UserWarning, match="test_module.test_post_init_warning.<locals>.C"
):

class C(B):
pass

B()
C()
assert not A.called
Loading