Skip to content

Commit

Permalink
Fixed eqx.field(metadata=...) resulting in static and converter being…
Browse files Browse the repository at this point in the history
… ignored.
  • Loading branch information
patrick-kidger committed Sep 28, 2023
1 parent 0c06b17 commit 9528aee
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
6 changes: 3 additions & 3 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ class MyModule(eqx.Module):
to select only some fields.
"""
try:
metadata = dict(kwargs["metadata"])
metadata = dict(kwargs.pop("metadata")) # safety copy
except KeyError:
metadata = kwargs["metadata"] = {}
metadata = {}
if "converter" in metadata:
raise ValueError("Cannot use metadata with `static` already set.")
if "static" in metadata:
Expand All @@ -118,7 +118,7 @@ class MyModule(eqx.Module):
metadata["converter"] = converter
if static:
metadata["static"] = True
return dataclasses.field(**kwargs)
return dataclasses.field(metadata=metadata, **kwargs)


#
Expand Down
20 changes: 19 additions & 1 deletion tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools as ft
from collections.abc import Callable
from dataclasses import InitVar
from typing import Any
from typing import Any, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -972,3 +972,21 @@ def foo(self):
assert len(leaves) == 0
y = jtu.tree_unflatten(treedef, leaves)
assert y.foo == 1


# https://github.com/patrick-kidger/equinox/issues/522
def test_custom_field():
def my_field(*, foo: Optional[bool] = None, **kwargs: Any):
metadata = kwargs.pop("metadata", {})
if foo is not None:
metadata["foo"] = foo
return eqx.field(metadata=metadata, **kwargs)

class ExampleModel(eqx.Module):
dynamic: jax.Array = my_field(foo=True)
static: int = my_field(foo=False, static=True)

model = ExampleModel(dynamic=jnp.array(1), static=1)
dynamic_field, static_field = dataclasses.fields(model)
assert dynamic_field.metadata == dict(foo=True)
assert static_field.metadata == dict(foo=False, static=True)

0 comments on commit 9528aee

Please sign in to comment.