Added an example for a Vision Transformer (ViT) #483

Merged: 4 commits into patrick-kidger:dev on Sep 23, 2023

Conversation

ahmed-alllam
Contributor

This PR adds a practical example for a Vision Transformer (ViT) in Equinox, based on the paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.

@ASEM000
Contributor

ASEM000 commented Sep 8, 2023

Cool work @ahmed-alllam! AFAIK, you can use keras_core with the JAX backend to fetch MNIST. The benefit is that you don't need torch, which is a rather heavy package. For reference on this point, check here.

@patrick-kidger
Owner

patrick-kidger commented Sep 8, 2023

Thanks for the contribution! Some quick comments:

  • I actually prefer the use of PyTorch to Keras here. Mainly because we're already using it in various spots, e.g. for DataLoaders, so better to be consistent.
  • I think MNIST is probably too simple a benchmark to really be worth applying a ViT to. Could you choose something more sophisticated?
  • Can you add a reference to eqxvision's implementation of a ViT? It'd be worth pointing people towards a featureful implementation.

@ahmed-alllam
Contributor Author

Hi @patrick-kidger! Any updates on this PR?

@patrick-kidger
Owner

Looking over this now. GitHub doesn't yet allow us to leave comments on .ipynb files inline, so feedback as follows. All of these comments are pretty nitty, as overall this looks very clean!

  • capitalise "equinox" -> "Equinox"
  • use just list instead of typing.List. As of Python 3.9, the latter is now deprecated.
  • use from jaxtyping import PRNGKeyArray to annotate random keys, rather than jr.PRNGKey. (The latter is a function, not a type.)
  • Can you add a comment linking to the GitHub repos for einops and optax, where they're imported? (Cf. the other examples.)
  • Can you add jaxtyping shape annotations to the various __call__ methods?
  • The annotations for AttentionBlock.linear{1,2} should be Linear, not Sequential.
  • prefer to do e.g. key1, key2, key3 = jr.split(key, 3), rather than indexing like keys[0]. Preferring unpacking over indexing is general best practice in Python.
  • In VisionTransformer.__call__, can you do one big split, rather than multiple little splits? One big split is much more efficient at runtime.
  • You're using a mix of jax.Array and jnp.ndarray in annotations. Can you pick just one?

@ahmed-alllam
Contributor Author

@patrick-kidger Addressed your feedback and made the necessary changes. Please review and merge if everything looks good. Thanks!

@patrick-kidger
Owner

patrick-kidger commented Sep 22, 2023

I think the positional_embedding looks wrong -- you've defined a whole matrix and then only ever index into it at a single static value.

I think the x = x[0] line could do with a comment attached. The first time you see it, it's a bit surprising that only a single patch gets passed to the MLP.

Other than that, I think this looks very tidily done.

@ahmed-alllam
Contributor Author

ahmed-alllam commented Sep 22, 2023

Thank you for pointing that out!

You're right about the positional_embedding. It was a small oversight on my part. The intention was indeed to slice the positional_embedding array to fit the number of patches in the image. I've corrected it to:

x += self.positional_embedding[: x.shape[0]]    # Slice to the same length as x, as the positional embedding may be longer.

I've also added a comment to clarify the x = x[0] line:

x = x[0]    # Select the CLS token.

Also, do you have any insights into why the checks are failing? They passed successfully for all previous commits, and this was just a minor change. I've gone through the logs, and it appears there might be a dependency issue stemming from PyRight.
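
To illustrate the two fixes discussed above, here is a rough, self-contained sketch in plain JAX. It assumes the CLS token is prepended at index 0 and that the positional embedding table is allocated for a maximum sequence length. The sizes and the names max_tokens, embed_dim, num_patches, cls_token and embed are hypothetical, not the values used in the notebook.

```python
import jax.numpy as jnp
import jax.random as jr

# Hypothetical sizes: the table has room for more tokens than this image needs.
max_tokens, embed_dim, num_patches = 65, 32, 16

key_pos, key_cls, key_patches = jr.split(jr.PRNGKey(0), 3)
positional_embedding = jr.normal(key_pos, (max_tokens, embed_dim))
cls_token = jr.normal(key_cls, (1, embed_dim))
patches = jr.normal(key_patches, (num_patches, embed_dim))


def embed(patches):
    # Prepend the CLS token so that it sits at index 0 of the sequence.
    x = jnp.concatenate([cls_token, patches], axis=0)
    # Slice the table to the sequence length: every token gets its own row,
    # rather than indexing the table at a single static position.
    return x + positional_embedding[: x.shape[0]]


x = embed(patches)
# (The transformer blocks would be applied to x here.)
cls_output = x[0]  # Select the CLS token; only this row is passed to the MLP head.
```

Because the slice is taken along the token axis, the addition is elementwise over matching shapes (17, 32) here, with no broadcasting of a single embedding row across all tokens.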

@patrick-kidger
Owner

Try rebasing against dev. (And making that the target branch.) This was a change in JAX; now fixed in both dev and in the JAX release (jax-ml/jax#17684).

@ahmed-alllam ahmed-alllam changed the base branch from main to dev September 22, 2023 23:51
@patrick-kidger patrick-kidger merged commit d9b018a into patrick-kidger:dev Sep 23, 2023
2 checks passed
@patrick-kidger
Owner

Alright, LGTM! Thank you for the example. This will appear in the docs for the next release of Equinox. (Once dev -> main.)

patrick-kidger pushed a commit that referenced this pull request Sep 29, 2023
* Added an example for a vision transformer (vit)

* Changed dataset to CIFAR10, added reference to eqxvision's ViT module

* Refactored the Vision Transformer example for improved code structure and readability.

* Fixed a small issue in positional embeddings
3 participants