Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated FAQ on sharing layers to be mention eqx.nn.Shared #529

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@ class Module(eqx.Module):
self.linear1 = shared_linear
self.linear2 = shared_linear
```
in which the same object is saved multiple times in the model.
in which the same object is saved multiple times in the model. However, after making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.

Don't do this!
This is intended. In Equinox+JAX, models are Py*Trees*, not DAGs. (Directed acyclic graphs.) This is basically just an arbitrary choice JAX made a long time ago in its design, but it does generally make reasoning about your code fairly easy. (You never need to track if an object is used in multiple places.)

After making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.
That said, it can sometimes happen that you really do want to tie together multiple nodes in your PyTree. If this is the case, then use [`equinox.nn.Shared`][], which provides this behaviour. (It stores things as a tree, and then inserts a reference to each node into the right place whenever you need it.)

Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear more in the tree than once. (If it did, then it would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times then it will unwittingly make independent copies of the object whenever it transforms the overall PyTree.

The resolution is simple: just don't store the same object in multiple places in the PyTree.

You can check for whether you have duplicate nodes by using the [`equinox.tree_check`][] function.
You can also check for whether you have duplicate nodes by using the [`equinox.tree_check`][] function.

## How do I input higher-order tensors (e.g. with batch dimensions) into my model?

Expand Down
Loading