Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved conversion to Foo.__init__ from MetaFoo.__call__. #524

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 157 additions & 87 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,81 +155,164 @@ def __new__(
for k, v in cls.__dict__.items():
if _not_magic(k) and inspect.isfunction(v):
setattr(cls, k, _wrap_method(v))
# [Step 3] Create a default `__init__` method if a user method isn't provided.

# [Step 3] Handle initialisation.
#
# If a superclass has a custom `__init__`, then don't create a default
# `__init__` here. (Otherwise e.g. if `B` has a custom init then
# `class A(B): pass` would set a dataclass init on `A`.)
# If a superclass has a default `__init__`, then do create a new default one
# here. (Dataclass default `__init__`s don't call `super()`, so they must be
# overriden directly.)
# For context, with any Python dataclass, there are three possible scenarios for
# initialisation:
# (a) a user-provided `__init__` method is supplied;
# (b) dataclasses creates `__init__` , without a user-provided `__post_init__`
# (c) dataclasses creates `__init__` , with a user-provided `__post_init__`

# Create a dataclass `__init__` method if a user method isn't provided.
# If a user passed one on this class, then we definitely have a custom __init__.
# Else just use whatever our superclass does. Note that this is different to
# default dataclass behaviour. Given
# ```
# @dataclass
# class Foo: def __init__(...): ...
# @dataclass
# class Bar(Foo): pass
# ```
# then `Bar` will end up with a dataclass-provided `__init__`. That ends up
# being ergonomically very annoying, so we disable it.
added_custom_init = "__init__" in cls.__dict__
if added_custom_init:
_dataclass_init = False
has_dataclass_init = False
else:
for kls in cls.__mro__:
for kls in cls.__mro__[1:-1]:
try:
_dataclass_init = _has_dataclass_init[kls]
has_dataclass_init = _has_dataclass_init[kls]
except KeyError:
# Non-Module superclasses.
pass
if kls.__init__ is not object.__init__:
has_dataclass_init = False
break
else:
break
else:
assert name == "Module"
_dataclass_init = True # eqx.Module itself
if _dataclass_init: # Using a dataclass-generated `__init__`.
has_dataclass_init = True # eqx.Module itself

# Check for a common error. (Check for `_Initable` to avoid duplicate warnings.)
if (
not has_dataclass_init
and hasattr(cls, "__post_init__")
and not issubclass(cls, _Initable)
):
warnings.warn(
f"Class `{cls.__module__}.{cls.__qualname__}` has both an "
"`__init__` method and a `__post_init__` method. This means that "
"the `__post_init__` method will not be run!\n"
"The reason for this is that `__post_init__` is intended to be "
"used with the automatically-generated `__init__` method provided "
"by Python dataclasses, which are generated of the form:\n"
"```\n"
"def __init__(self, field1, field2)\n"
" self.field1 = field1\n"
" self.field2 = field2\n"
" self.__post_init__()\n"
"```\n"
"and as such a user-provided `__init__` overrides both the setting "
"of fields, and the calling of `__post_init__`.\n"
"The above is how Python dataclasses work, and has nothing "
"to do with Equinox!\n"
"If you are using `__post_init__` to check that certain invariants "
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
)

# Add support for `eqx.field(converter=...)` when using `__post_init__`.
# (Scenario (c) above. Scenarios (a) and (b) are handled later.)
if has_dataclass_init and hasattr(cls, "__post_init__"):
post_init = cls.__post_init__

@ft.wraps(post_init) # pyright: ignore
def __post_init__(self, *args, **kwargs):
# This `if` is to handle `super()` correctly.
# We want to only convert once, at the top level.
#
# This check is basically testing whether or not the function we're in
# now (`cls.__post_init__`) is at the top level
# (`self.__class__.__post_init__`). If we are, do conversion. If we're
# not, it's presumably because someone is calling us via `super()` in
# the middle of their own `__post_init__`. No conversion then; their own
# version of this wrapper will do it at the appropriate time instead.
#
# One small foible: we write `cls.__post_init__`, rather than just
# `__post_init__`, to refer to this function. This allows someone else
# to also monkey-patch `cls.__post_init__` if they wish, and this won't
# remove conversion. (Conversion is a at-the-top-level thing, not a
# this-particular-function thing.)
#
# This top-level business means that this is very nearly the same as
# doing conversion in `_ModuleMeta.__call__`. The differences are that
# (a) that wouldn't allow us to convert fields before the user-provided
# `__post_init__`, and (b) it allows other libraries (i.e. jaxtyping)
# to later monkey-patch `__init__`, and we have our converter run before
# their own monkey-patched-in code.
if self.__class__.__post_init__ is cls.__post_init__:
# Convert all fields currently available.
_convert_fields(self, init=True)
post_init(self, *args, **kwargs) # pyright: ignore
if self.__class__.__post_init__ is cls.__post_init__:
# Convert all the fields filled in by `__post_init__` as well.
_convert_fields(self, init=False)

cls.__post_init__ = __post_init__ # pyright: ignore
else:
post_init = None

# Fairly common to write `Superclass.__init__.__doc__ = "..."` with
# dataclass-provided inits; here we look through the class hierarchy and will
# copy this doc forward.
if has_dataclass_init:
init_doc = cls.__init__.__doc__
if not _dataclass_init: # Using a user-provided `__init__`.
# Check `_Initable` to avoid printing out duplicate warnings.
if getattr(cls, "__post_init__", None) is not None and not issubclass(
cls, _Initable
):
warnings.warn(
f"Class `{cls.__module__}.{cls.__qualname__}` has both an "
"`__init__` method and a `__post_init__` method. This means that "
"the `__post_init__` method will not be run!\n"
"The reason for this is that `__post_init__` is intended to be "
"used with the automatically-generated `__init__` method provided "
"by Python dataclasses, which are generated of the form:\n"
"```\n"
"def __init__(self, field1, field2)\n"
" self.field1 = field1\n"
" self.field2 = field2\n"
" self.__post_init__()\n"
"```\n"
"and as such a user-provided `__init__` overrides both the setting "
"of fields, and the calling of `__post_init__`.\n"
"The above is how Python dataclasses work, and has nothing "
"to do with Equinox!\n"
"If you are using `__post_init__` to check that certain invariants "
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
)

# [Step 4] Register as a dataclass.
cls = dataclass(eq=False, repr=False, frozen=True, init=_dataclass_init)(
cls = dataclass(eq=False, repr=False, frozen=True, init=has_dataclass_init)(
cls # pyright: ignore
)
# [Step 3b] -- finish off the business of default `__init__` methods.
# (This part has to happen after dataclass registration.)
_has_dataclass_init[cls] = _dataclass_init
if _dataclass_init:
# Assign `__doc__` in case its been manually overriden:
# ```
# class Foo(eqx.Module):
# x: int
#
# Foo.__init__.__doc__ = "Foo should be called with with an integer `x`."
#
# class Bar(Foo):
# pass
# ```
# E.g. `Bar.__init__.__doc__` may be used during documentation generation.
# [Step 3b] -- finish off building `__init__` methods. Until we'd done
# dataclass'ification then we didn't necessarily have our `__init__` method.

# Registering here records that the `dataclass(...)` call has happened.
_has_dataclass_init[cls] = has_dataclass_init

# Now handle conversion for cases (a) and (b) above, in which there is no
# `__post_init__`.
if post_init is None:
init = cls.__init__

@ft.wraps(init) # pyright: ignore
def __init__(self, *args, **kwargs):
init(self, *args, **kwargs)
# Same `if` trick as with `__post_init__`.
if self.__class__.__init__ is cls.__init__:
_convert_fields(self, init=True)
_convert_fields(self, init=False)

cls.__init__ = __init__

# Assign `__doc__` in case it has been manually overriden:
# ```
# class Foo(eqx.Module):
# x: int
#
# Foo.__init__.__doc__ = "Foo should be called with with an integer `x`."
#
# class Bar(Foo):
# pass
#
# # Now we try to access `Bar.__init__.__doc__`. (E.g. during docgen.)
# ```
if has_dataclass_init:
cls.__init__.__doc__ = init_doc # pyright: ignore
# TODO: is this next line still necessary?
cls.__init__.__module__ = cls.__module__

# [Step 5] We support an optional `strict` mode for Rust-like strictness in the
# type checking.
# In practice this is probably too much for your average user, but it's a great
Expand Down Expand Up @@ -347,17 +430,15 @@ def __new__(
# This method is called whenever you initialise a module: `MyModule(...)`
def __call__(cls, *args, **kwargs):
if _is_force_abstract[cls]:
# Any other is-abstract checks will be handled in super().__call__.
raise TypeError("Cannot instantiate abstract `equinox.Module`.")
# [Step 1] Modules are immutable -- except during construction. So defreeze
# before init.
if _has_dataclass_init[cls]:
post_init = getattr(cls, "__post_init__", None)
else:
post_init = None
initable_cls = _make_initable(cls, post_init, wraps=False)
# [Step 2] Instantiate the class as normal. (`__init__` and `__post_init__`)
# and then re-freeze.
post_init = getattr(cls, "__post_init__", None)
initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False)
# [Step 2] Instantiate the class as normal.
self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
assert not _is_abstract(cls)
# [Step 3] Check that all fields are occupied.
missing_names = {
field.name
Expand All @@ -371,14 +452,10 @@ def __call__(cls, *args, **kwargs):
f"The following fields were not initialised during __init__: "
f"{missing_names}"
)
# [Step 4] Run any custom converters.
if post_init is None:
# `if post_init is not None` then conversion happened in the override of
# `__post_init__` in `initable_cls`.
_convert_fields(self, init=True)
_convert_fields(self, init=False)
# Freeze.
object.__setattr__(self, "__class__", cls)
# [Step 5] Run any custom validators.
# [Step 4] Run any custom validators. (After freezing; as they run
# unconditionally across the whole MRO, they aren't allowed to mutate.)
for kls in cls.__mro__:
try:
check = kls.__dict__["__check_init__"]
Expand Down Expand Up @@ -478,13 +555,14 @@ class _Initable:
pass


# We pass `post_init` as an argument, rather than just looking up `cls.__post_init__`,
# in case someone instantiates a Module, then monkey-patches `__post_init__`, then
# instantiates another Module.
# Which is a crazy thing to do, but never let it be said that Equinox doesn't handle
# your edge cases.
@ft.lru_cache(maxsize=128)
def _make_initable(cls: _ModuleMeta, post_init, wraps: bool) -> _ModuleMeta:
def _make_initable(cls: _ModuleMeta, init, post_init, wraps: bool) -> _ModuleMeta:
# Used as part of the key. Don't cache if these have changed.
# In practice, monkey-patching these on the class -- after you've already
# instantiated it somewhere! -- is an *ahem*, adventurous, thing to do. But never
# let it be said that Equinox doesn't support you in your questionable life choices!
del init, post_init

if wraps:
field_names = _wrapper_field_names
else:
Expand All @@ -493,15 +571,7 @@ def _make_initable(cls: _ModuleMeta, post_init, wraps: bool) -> _ModuleMeta:
}

class _InitableModule(cls, _Initable): # pyright: ignore
if post_init is not None:

@ft.wraps(post_init)
def __post_init__(self, *args, **kwargs):
# Do conversion before `__post_init__`, as a user convenience.
_convert_fields(self, init=True)
post_init(self, *args, **kwargs)
# Now convert all the fields filled in by `__post_init__` as well.
_convert_fields(self, init=False)
pass

def __setattr__(self, name, value):
if name in field_names:
Expand Down Expand Up @@ -856,7 +926,7 @@ def _module_update_wrapper(
leaves, treedef = jtu.tree_flatten(wrapper)
wrapper = jtu.tree_unflatten(treedef, leaves)

initable_cls = _make_initable(cls, None, wraps=True)
initable_cls = _make_initable(cls, None, None, wraps=True)
object.__setattr__(wrapper, "__class__", initable_cls)
try:
# Like `ft.update_wrapper(wrapper, wrapped, updated=())`.
Expand Down
Loading
Loading