From e312e876725a1bde33b36d823efc718340d6af68 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:15:34 -0700 Subject: [PATCH] Fixed eqx.field(metadata=...) resulting in static and converter being ignored. --- equinox/_module.py | 6 +++--- tests/test_module.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/equinox/_module.py b/equinox/_module.py index 29b53520..23a7182e 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -92,9 +92,9 @@ class MyModule(eqx.Module): to select only some fields. """ try: - metadata = dict(kwargs["metadata"]) + metadata = dict(kwargs.pop("metadata")) # safety copy except KeyError: - metadata = kwargs["metadata"] = {} + metadata = {} if "converter" in metadata: raise ValueError("Cannot use metadata with `static` already set.") if "static" in metadata: @@ -118,7 +118,7 @@ class MyModule(eqx.Module): metadata["converter"] = converter if static: metadata["static"] = True - return dataclasses.field(**kwargs) + return dataclasses.field(metadata=metadata, **kwargs) # diff --git a/tests/test_module.py b/tests/test_module.py index f9c8311b..3840a3d9 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -3,7 +3,7 @@ import functools as ft from collections.abc import Callable from dataclasses import InitVar -from typing import Any +from typing import Any, Optional import jax import jax.numpy as jnp @@ -972,3 +972,21 @@ def foo(self): assert len(leaves) == 0 y = jtu.tree_unflatten(treedef, leaves) assert y.foo == 1 + + +# https://github.com/patrick-kidger/equinox/issues/522 +def test_custom_field(): + def my_field(*, foo: Optional[bool] = None, **kwargs: Any): + metadata = kwargs.pop("metadata", {}) + if foo is not None: + metadata["foo"] = foo + return eqx.field(metadata=metadata, **kwargs) + + class ExampleModel(eqx.Module): + dynamic: jax.Array = my_field(foo=True) + static: int = my_field(foo=False, static=True) + + model = ExampleModel(dynamic=jnp.array(1), static=1) + dynamic_field, static_field = dataclasses.fields(model) + assert dynamic_field.metadata == dict(foo=True) + assert static_field.metadata == dict(foo=False, static=True)