Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved conversion to Foo.__init__ from MetaFoo.__call__. #524

Merged
merged 1 commit into from
Sep 28, 2023

Conversation

patrick-kidger
Copy link
Owner

This is necessary to allow downstream libraries, like jaxtyping, to
monkey-patch in their own checks.

patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Sep 27, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
This is necessary to allow downstream libraries, like jaxtyping, to
monkey-patch in their own checks.
@patrick-kidger patrick-kidger merged commit 89f3405 into dev Sep 28, 2023
2 checks passed
@patrick-kidger patrick-kidger deleted the conversion-in-init branch September 28, 2023 01:01
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Sep 28, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 4, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 4, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 4, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 5, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 5, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 5, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 9, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
patrick-kidger added a commit to patrick-kidger/jaxtyping that referenced this pull request Oct 10, 2023
Previously, using the import hook with dataclasses resulted in the `__init__` method of the dataclass being checked.
This was undesirable when using `eqx.field(converter=...)`, as the annotation didn't necessarily reflect the argument type.
A typical example was
```python
class Foo(eqx.Module):
    x: jax.Array = eqx.field(converter=jnp.ndarray)

Foo(1)  # 1 is not an array! But this code is valid.
```

After this change, we instead monkey-patch our checks to happen at the end of the `__init__` of the dataclass -- after conversion has run.

Note that this requires patrick-kidger/equinox#524. Otherwise, Equinox does conversion too late (in `_ModuleMeta.__call__`, after `__init__` has been run).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant