From 3df04c11769a939f4925871d0e14533db85c62d9 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:00:56 -0700 Subject: [PATCH] Updated FAQ on sharing layers to be mention eqx.nn.Shared --- docs/faq.md | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/faq.md b/docs/faq.md index d4fd692a..8e9b164f 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -31,17 +31,13 @@ class Module(eqx.Module): self.linear1 = shared_linear self.linear2 = shared_linear ``` -in which the same object is saved multiple times in the model. +in which the same object is saved multiple times in the model. However, after making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different. -Don't do this! +This is intended. In Equinox+JAX, models are Py*Trees*, not DAGs. (Directed acyclic graphs.) This is basically just an arbitrary choice JAX made a long time ago in its design, but it does generally make reasoning about your code fairly easy. (You never need to track if an object is used in multiple places.) -After making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different. +That said, it can sometimes happen that you really do want to tie together multiple nodes in your PyTree. If this is the case, then use [`equinox.nn.Shared`][], which provides this behaviour. (It stores things as a tree, and then inserts a reference to each node into the right place whenever you need it.) -Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear more in the tree than once. (If it did, then it would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times then it will unwittingly make independent copies of the object whenever it transforms the overall PyTree. - -The resolution is simple: just don't store the same object in multiple places in the PyTree. - -You can check for whether you have duplicate nodes by using the [`equinox.tree_check`][] function. +You can also check for whether you have duplicate nodes by using the [`equinox.tree_check`][] function. ## How do I input higher-order tensors (e.g. with batch dimensions) into my model?