Skip to content

Commit

Permalink
Moved conversion to Foo.__init__ from MetaFoo.__call__.
Browse files Browse the repository at this point in the history
This is necessary to allow downstream libraries, like jaxtyping, to
monkey-patch in their own checks.
  • Loading branch information
patrick-kidger committed Sep 28, 2023
1 parent 7ea43cf commit 0103397
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 90 deletions.
244 changes: 157 additions & 87 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,81 +155,164 @@ def __new__(
for k, v in cls.__dict__.items():
if _not_magic(k) and inspect.isfunction(v):
setattr(cls, k, _wrap_method(v))
# [Step 3] Create a default `__init__` method if a user method isn't provided.

# [Step 3] Handle initialisation.
#
# If a superclass has a custom `__init__`, then don't create a default
# `__init__` here. (Otherwise e.g. if `B` has a custom init then
# `class A(B): pass` would set a dataclass init on `A`.)
# If a superclass has a default `__init__`, then do create a new default one
# here. (Dataclass default `__init__`s don't call `super()`, so they must be
# overridden directly.)
# For context, with any Python dataclass, there are three possible scenarios for
# initialisation:
# (a) a user-provided `__init__` method is supplied;
# (b) dataclasses creates `__init__`, without a user-provided `__post_init__`;
# (c) dataclasses creates `__init__`, with a user-provided `__post_init__`.

# Create a dataclass `__init__` method if a user method isn't provided.
# If a user passed one on this class, then we definitely have a custom __init__.
# Else just use whatever our superclass does. Note that this is different to
# default dataclass behaviour. Given
# ```
# @dataclass
# class Foo: def __init__(...): ...
# @dataclass
# class Bar(Foo): pass
# ```
# then `Bar` will end up with a dataclass-provided `__init__`. That ends up
# being ergonomically very annoying, so we disable it.
added_custom_init = "__init__" in cls.__dict__
if added_custom_init:
_dataclass_init = False
has_dataclass_init = False
else:
for kls in cls.__mro__:
for kls in cls.__mro__[1:-1]:
try:
_dataclass_init = _has_dataclass_init[kls]
has_dataclass_init = _has_dataclass_init[kls]
except KeyError:
# Non-Module superclasses.
pass
if kls.__init__ is not object.__init__:
has_dataclass_init = False
break
else:
break
else:
assert name == "Module"
_dataclass_init = True # eqx.Module itself
if _dataclass_init: # Using a dataclass-generated `__init__`.
has_dataclass_init = True # eqx.Module itself

# Check for a common error. (Check for `_Initable` to avoid duplicate warnings.)
if (
not has_dataclass_init
and hasattr(cls, "__post_init__")
and not issubclass(cls, _Initable)
):
warnings.warn(
f"Class `{cls.__module__}.{cls.__qualname__}` has both an "
"`__init__` method and a `__post_init__` method. This means that "
"the `__post_init__` method will not be run!\n"
"The reason for this is that `__post_init__` is intended to be "
"used with the automatically-generated `__init__` method provided "
"by Python dataclasses, which are generated of the form:\n"
"```\n"
"def __init__(self, field1, field2)\n"
" self.field1 = field1\n"
" self.field2 = field2\n"
" self.__post_init__()\n"
"```\n"
"and as such a user-provided `__init__` overrides both the setting "
"of fields, and the calling of `__post_init__`.\n"
"The above is how Python dataclasses work, and has nothing "
"to do with Equinox!\n"
"If you are using `__post_init__` to check that certain invariants "
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
)

# Add support for `eqx.field(converter=...)` when using `__post_init__`.
# (Scenario (c) above. Scenarios (a) and (b) are handled later.)
if has_dataclass_init and hasattr(cls, "__post_init__"):
post_init = cls.__post_init__

@ft.wraps(post_init) # pyright: ignore
def __post_init__(self, *args, **kwargs):
# This `if` is to handle `super()` correctly.
# We want to only convert once, at the top level.
#
# This check is basically testing whether or not the function we're in
# now (`cls.__post_init__`) is at the top level
# (`self.__class__.__post_init__`). If we are, do conversion. If we're
# not, it's presumably because someone is calling us via `super()` in
# the middle of their own `__post_init__`. No conversion then; their own
# version of this wrapper will do it at the appropriate time instead.
#
# One small foible: we write `cls.__post_init__`, rather than just
# `__post_init__`, to refer to this function. This allows someone else
# to also monkey-patch `cls.__post_init__` if they wish, and this won't
# remove conversion. (Conversion is an at-the-top-level thing, not a
# this-particular-function thing.)
#
# This top-level business means that this is very nearly the same as
# doing conversion in `_ModuleMeta.__call__`. The differences are that
# (a) that wouldn't allow us to convert fields before the user-provided
# `__post_init__`, and (b) it allows other libraries (i.e. jaxtyping)
# to later monkey-patch `__init__`, and we have our converter run before
# their own monkey-patched-in code.
if self.__class__.__post_init__ is cls.__post_init__:
# Convert all fields currently available.
_convert_fields(self, init=True)
post_init(self, *args, **kwargs) # pyright: ignore
if self.__class__.__post_init__ is cls.__post_init__:
# Convert all the fields filled in by `__post_init__` as well.
_convert_fields(self, init=False)

cls.__post_init__ = __post_init__ # pyright: ignore
else:
post_init = None

# Fairly common to write `Superclass.__init__.__doc__ = "..."` with
# dataclass-provided inits; here we look through the class hierarchy and will
# copy this doc forward.
if has_dataclass_init:
init_doc = cls.__init__.__doc__
if not _dataclass_init: # Using a user-provided `__init__`.
# Check `_Initable` to avoid printing out duplicate warnings.
if getattr(cls, "__post_init__", None) is not None and not issubclass(
cls, _Initable
):
warnings.warn(
f"Class `{cls.__module__}.{cls.__qualname__}` has both an "
"`__init__` method and a `__post_init__` method. This means that "
"the `__post_init__` method will not be run!\n"
"The reason for this is that `__post_init__` is intended to be "
"used with the automatically-generated `__init__` method provided "
"by Python dataclasses, which are generated of the form:\n"
"```\n"
"def __init__(self, field1, field2)\n"
" self.field1 = field1\n"
" self.field2 = field2\n"
" self.__post_init__()\n"
"```\n"
"and as such a user-provided `__init__` overrides both the setting "
"of fields, and the calling of `__post_init__`.\n"
"The above is how Python dataclasses work, and has nothing "
"to do with Equinox!\n"
"If you are using `__post_init__` to check that certain invariants "
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
)

# [Step 4] Register as a dataclass.
cls = dataclass(eq=False, repr=False, frozen=True, init=_dataclass_init)(
cls = dataclass(eq=False, repr=False, frozen=True, init=has_dataclass_init)(
cls # pyright: ignore
)
# [Step 3b] -- finish off the business of default `__init__` methods.
# (This part has to happen after dataclass registration.)
_has_dataclass_init[cls] = _dataclass_init
if _dataclass_init:
# Assign `__doc__` in case it's been manually overridden:
# ```
# class Foo(eqx.Module):
# x: int
#
# Foo.__init__.__doc__ = "Foo should be called with an integer `x`."
#
# class Bar(Foo):
# pass
# ```
# E.g. `Bar.__init__.__doc__` may be used during documentation generation.
# [Step 3b] -- finish off building `__init__` methods. Until we had done
# the dataclass'ification, we didn't necessarily have our `__init__` method.

# Registering here records that the `dataclass(...)` call has happened.
_has_dataclass_init[cls] = has_dataclass_init

# Now handle conversion for cases (a) and (b) above, in which there is no
# `__post_init__`.
if post_init is None:
init = cls.__init__

@ft.wraps(init) # pyright: ignore
def __init__(self, *args, **kwargs):
init(self, *args, **kwargs)
# Same `if` trick as with `__post_init__`.
if self.__class__.__init__ is cls.__init__:
_convert_fields(self, init=True)
_convert_fields(self, init=False)

cls.__init__ = __init__

# Assign `__doc__` in case it has been manually overridden:
# ```
# class Foo(eqx.Module):
# x: int
#
# Foo.__init__.__doc__ = "Foo should be called with an integer `x`."
#
# class Bar(Foo):
# pass
#
# # Now we try to access `Bar.__init__.__doc__`. (E.g. during docgen.)
# ```
if has_dataclass_init:
cls.__init__.__doc__ = init_doc # pyright: ignore
# TODO: is this next line still necessary?
cls.__init__.__module__ = cls.__module__

# [Step 5] We support an optional `strict` mode for Rust-like strictness in the
# type checking.
# In practice this is probably too much for your average user, but it's a great
Expand Down Expand Up @@ -347,17 +430,15 @@ def __new__(
# This method is called whenever you initialise a module: `MyModule(...)`
def __call__(cls, *args, **kwargs):
if _is_force_abstract[cls]:
# Any other is-abstract checks will be handled in super().__call__.
raise TypeError("Cannot instantiate abstract `equinox.Module`.")
# [Step 1] Modules are immutable -- except during construction. So defreeze
# before init.
if _has_dataclass_init[cls]:
post_init = getattr(cls, "__post_init__", None)
else:
post_init = None
initable_cls = _make_initable(cls, post_init, wraps=False)
# [Step 2] Instantiate the class as normal. (`__init__` and `__post_init__`)
# and then re-freeze.
post_init = getattr(cls, "__post_init__", None)
initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False)
# [Step 2] Instantiate the class as normal.
self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
assert not _is_abstract(cls)
# [Step 3] Check that all fields are occupied.
missing_names = {
field.name
Expand All @@ -371,14 +452,10 @@ def __call__(cls, *args, **kwargs):
f"The following fields were not initialised during __init__: "
f"{missing_names}"
)
# [Step 4] Run any custom converters.
if post_init is None:
# `if post_init is not None` then conversion happened in the override of
# `__post_init__` in `initable_cls`.
_convert_fields(self, init=True)
_convert_fields(self, init=False)
# Freeze.
object.__setattr__(self, "__class__", cls)
# [Step 5] Run any custom validators.
# [Step 4] Run any custom validators. (After freezing; as they run
# unconditionally across the whole MRO, they aren't allowed to mutate.)
for kls in cls.__mro__:
try:
check = kls.__dict__["__check_init__"]
Expand Down Expand Up @@ -478,13 +555,14 @@ class _Initable:
pass


# We pass `post_init` as an argument, rather than just looking up `cls.__post_init__`,
# in case someone instantiates a Module, then monkey-patches `__post_init__`, then
# instantiates another Module.
# Which is a crazy thing to do, but never let it be said that Equinox doesn't handle
# your edge cases.
@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, post_init, wraps: bool) -> _ModuleMeta:
def _make_initable(cls: _ModuleMeta, init, post_init, wraps: bool) -> _ModuleMeta:
# Used as part of the key. Don't cache if these have changed.
# In practice, monkey-patching these on the class -- after you've already
# instantiated it somewhere! -- is an *ahem*, adventurous, thing to do. But never
# let it be said that Equinox doesn't support you in your questionable life choices!
del init, post_init

if wraps:
field_names = _wrapper_field_names
else:
Expand All @@ -493,15 +571,7 @@ def _make_initable(cls: _ModuleMeta, post_init, wraps: bool) -> _ModuleMeta:
}

class _InitableModule(cls, _Initable): # pyright: ignore
if post_init is not None:

@ft.wraps(post_init)
def __post_init__(self, *args, **kwargs):
# Do conversion before `__post_init__`, as a user convenience.
_convert_fields(self, init=True)
post_init(self, *args, **kwargs)
# Now convert all the fields filled in by `__post_init__` as well.
_convert_fields(self, init=False)
pass

def __setattr__(self, name, value):
if name in field_names:
Expand Down Expand Up @@ -856,7 +926,7 @@ def _module_update_wrapper(
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)

initable_cls = _make_initable(cls, None, wraps=True)
initable_cls = _make_initable(cls, None, None, wraps=True)
object.__setattr__(wrapper, "__class__", initable_cls)
try:
# Like `ft.update_wrapper(wrapper, wrapped, updated=())`.
Expand Down
Loading

0 comments on commit 0103397

Please sign in to comment.