diff --git a/examples/vision_transformer.ipynb b/examples/vision_transformer.ipynb
new file mode 100644
index 00000000..4bed7268
--- /dev/null
+++ b/examples/vision_transformer.ipynb
@@ -0,0 +1,538 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "x96qLutZxINh"
+   },
+   "source": [
+    "# Vision Transformer (ViT)\n",
+    "\n",
+    "This example builds a vision transformer model using Equinox, with an implementation based on the paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.\n",
+    "\n",
+    "This is a tutorial example; if you want to use a vision transformer Equinox model in your project, please refer to this implementation: [Eqxvision: Vision Transformer](https://eqxvision.readthedocs.io/en/latest/api/models/classification/vit/).\n",
+    "\n",
+    "!!! cite \"Reference\"\n",
+    "\n",
+    "    [arXiv link](https://arxiv.org/abs/2010.11929)\n",
+    "\n",
+    "    ```bibtex\n",
+    "    @article{dosovitskiy2020image,\n",
+    "        title={An image is worth 16x16 words: Transformers for image recognition at scale},\n",
+    "        author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},\n",
+    "        journal={arXiv preprint arXiv:2010.11929},\n",
+    "        year={2020}\n",
+    "    }\n",
+    "    ```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "id": "3-NddIhhxINj"
+   },
+   "outputs": [],
+   "source": [
+    "import functools\n",
+    "\n",
+    "from jaxtyping import PRNGKeyArray, Array, Float\n",
+    "\n",
+    "import numpy as np\n",
+    "\n",
+    "import einops  # https://github.com/arogozhnikov/einops\n",
+    "\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import jax.random as jr\n",
+    "import optax  # https://github.com/deepmind/optax\n",
+    "\n",
+    "# We'll use PyTorch to load the dataset.\n",
+    "import torch\n",
+    "import torchvision\n",
+    "import torchvision.transforms as transforms\n",
+    "\n",
+    "import equinox as eqx"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "id": "bYi-XlXRxINl"
+   },
+   "outputs": [],
+   "source": [
+    "# Hyperparameters\n",
+    "lr = 0.0001\n",
+    "dropout_rate = 0.1\n",
+    "beta1 = 0.9\n",
+    "beta2 = 0.999\n",
+    "batch_size = 64\n",
+    "patch_size = 4\n",
+    "num_patches = 64\n",
+    "num_steps = 100000\n",
+    "image_size = (32, 32, 3)\n",
+    "embedding_dim = 512\n",
+    "hidden_dim = 256\n",
+    "num_heads = 8\n",
+    "num_layers = 6\n",
+    "height, width, channels = image_size\n",
+    "num_classes = 10"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "ClIEhf1dBa8x"
+   },
+   "source": [
+    "Let's first load the CIFAR10 dataset using torchvision."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "Vcwi4un6CMu_",
+    "outputId": "fad94424-789b-46b2-f3df-41f004ddfaf7"
+   },
+   "outputs": [],
+   "source": [
+    "transform_train = transforms.Compose(\n",
+    "    [\n",
+    "        transforms.RandomCrop(32, padding=4),\n",
+    "        transforms.Resize((height, width)),\n",
+    "        transforms.RandomHorizontalFlip(),\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "transform_test = transforms.Compose(\n",
+    "    [\n",
+    "        transforms.Resize((height, width)),\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
+    "    ]\n",
+    ")\n",
+    "\n",
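+    "# The normalisation statistics above are the commonly used per-channel means and\n",
+    "# standard deviations for CIFAR-10.\n",
+    "\n",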
"train_dataset = torchvision.datasets.CIFAR10(\n", + " \"CIFAR\",\n", + " train=True,\n", + " download=True,\n", + " transform=transform_train,\n", + ")\n", + "\n", + "test_dataset = torchvision.datasets.CIFAR10(\n", + " \"CIFAR\",\n", + " train=False,\n", + " download=True,\n", + " transform=transform_test,\n", + ")\n", + "\n", + "trainloader = torch.utils.data.DataLoader(\n", + " train_dataset, batch_size=batch_size, shuffle=True, drop_last=True\n", + ")\n", + "\n", + "testloader = torch.utils.data.DataLoader(\n", + " test_dataset, batch_size=batch_size, shuffle=True, drop_last=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h-Q4A5H8OQLs" + }, + "source": [ + "Now Let's start by making the patch embeddings layer that will turn images into embedded patches to be processed then by the attention layers." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "SFo1GzZvxINl" + }, + "outputs": [], + "source": [ + "class PatchEmbedding(eqx.Module):\n", + " linear: eqx.nn.Embedding\n", + " patch_size: int\n", + "\n", + " def __init__(\n", + " self,\n", + " input_channels: int,\n", + " output_shape: int,\n", + " patch_size: int,\n", + " key: PRNGKeyArray,\n", + " ):\n", + " self.patch_size = patch_size\n", + "\n", + " self.linear = eqx.nn.Linear(\n", + " self.patch_size**2 * input_channels,\n", + " output_shape,\n", + " key=key,\n", + " )\n", + "\n", + " def __call__(\n", + " self, x: Float[Array, \"channels height width\"]\n", + " ) -> Float[Array, \"num_patches embedding_dim\"]:\n", + " x = einops.rearrange(\n", + " x,\n", + " \"c (h ph) (w pw) -> (h w) (c ph pw)\",\n", + " ph=self.patch_size,\n", + " pw=self.patch_size,\n", + " )\n", + " x = jax.vmap(self.linear)(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dOJp3mZNOjJW" + }, + "source": [ + "After that, we implement the attention block which is the core of the transformer architecture." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {
+    "id": "mDg-L_9ixINm"
+   },
+   "outputs": [],
+   "source": [
+    "class AttentionBlock(eqx.Module):\n",
+    "    layer_norm: eqx.nn.LayerNorm\n",
+    "    attention: eqx.nn.MultiheadAttention\n",
+    "    linear1: eqx.nn.Linear\n",
+    "    linear2: eqx.nn.Linear\n",
+    "    dropout: eqx.nn.Dropout\n",
+    "\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        input_shape: int,\n",
+    "        hidden_dim: int,\n",
+    "        num_heads: int,\n",
+    "        dropout_rate: float,\n",
+    "        key: PRNGKeyArray,\n",
+    "    ):\n",
+    "        key1, key2, key3 = jr.split(key, 3)\n",
+    "\n",
+    "        self.layer_norm = eqx.nn.LayerNorm(input_shape)\n",
+    "        self.attention = eqx.nn.MultiheadAttention(num_heads, input_shape, key=key1)\n",
+    "\n",
+    "        self.linear1 = eqx.nn.Linear(input_shape, hidden_dim, key=key2)\n",
+    "        self.dropout = eqx.nn.Dropout(dropout_rate)\n",
+    "        self.linear2 = eqx.nn.Linear(hidden_dim, input_shape, key=key3)\n",
+    "\n",
+    "    def __call__(\n",
+    "        self,\n",
+    "        x: Float[Array, \"num_patches embedding_dim\"],\n",
+    "        enable_dropout: bool,\n",
+    "        key: PRNGKeyArray,\n",
+    "    ) -> Float[Array, \"num_patches embedding_dim\"]:\n",
+    "        # Attention sub-block, with a residual connection.\n",
+    "        input_x = self.layer_norm(x)\n",
+    "        x = x + self.attention(input_x, input_x, input_x)\n",
+    "\n",
+    "        # MLP sub-block, with a residual connection.\n",
+    "        input_x = self.layer_norm(x)\n",
+    "        input_x = jax.vmap(self.linear1)(input_x)\n",
+    "        input_x = jax.nn.gelu(input_x)\n",
+    "\n",
+    "        key1, key2 = jr.split(key, num=2)\n",
+    "\n",
+    "        input_x = self.dropout(input_x, inference=not enable_dropout, key=key1)\n",
+    "        input_x = jax.vmap(self.linear2)(input_x)\n",
+    "        input_x = self.dropout(input_x, inference=not enable_dropout, key=key2)\n",
+    "\n",
+    "        x = x + input_x\n",
+    "\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "t_RB1ip0PEk4"
+   },
+   "source": [
+    "Lastly, we build the full Vision Transformer model, which is composed of embedding layers, a series of transformer blocks, and a classification head.\n",
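+    "\n",
+    "Following the ViT paper, a learnable classification (CLS) token is prepended to the sequence of patch embeddings, and a learnable positional embedding is added at each position. After the stack of attention blocks, only the CLS token's representation is kept and passed through a small layer-norm + linear head to produce the class logits."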
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {
+    "id": "nG6fLPhyQEBx"
+   },
+   "outputs": [],
+   "source": [
+    "class VisionTransformer(eqx.Module):\n",
+    "    patch_embedding: PatchEmbedding\n",
+    "    positional_embedding: jnp.ndarray\n",
+    "    cls_token: jnp.ndarray\n",
+    "    attention_blocks: list[AttentionBlock]\n",
+    "    dropout: eqx.nn.Dropout\n",
+    "    mlp: eqx.nn.Sequential\n",
+    "    num_layers: int\n",
+    "\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        embedding_dim: int,\n",
+    "        hidden_dim: int,\n",
+    "        num_heads: int,\n",
+    "        num_layers: int,\n",
+    "        dropout_rate: float,\n",
+    "        patch_size: int,\n",
+    "        num_patches: int,\n",
+    "        num_classes: int,\n",
+    "        key: PRNGKeyArray,\n",
+    "    ):\n",
+    "        key1, key2, key3, key4, key5 = jr.split(key, 5)\n",
+    "\n",
+    "        self.patch_embedding = PatchEmbedding(channels, embedding_dim, patch_size, key1)\n",
+    "\n",
+    "        self.positional_embedding = jr.normal(key2, (num_patches + 1, embedding_dim))\n",
+    "\n",
+    "        self.cls_token = jr.normal(key3, (1, embedding_dim))\n",
+    "\n",
+    "        self.num_layers = num_layers\n",
+    "\n",
+    "        # Use a different key per block, so that each block is initialised differently.\n",
+    "        self.attention_blocks = [\n",
+    "            AttentionBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, bkey)\n",
+    "            for bkey in jr.split(key4, self.num_layers)\n",
+    "        ]\n",
+    "\n",
+    "        self.dropout = eqx.nn.Dropout(dropout_rate)\n",
+    "\n",
+    "        self.mlp = eqx.nn.Sequential(\n",
+    "            [\n",
+    "                eqx.nn.LayerNorm(embedding_dim),\n",
+    "                eqx.nn.Linear(embedding_dim, num_classes, key=key5),\n",
+    "            ]\n",
+    "        )\n",
+    "\n",
+    "    def __call__(\n",
+    "        self,\n",
+    "        x: Float[Array, \"channels height width\"],\n",
+    "        enable_dropout: bool,\n",
+    "        key: PRNGKeyArray,\n",
+    "    ) -> Float[Array, \"num_classes\"]:\n",
+    "        x = self.patch_embedding(x)\n",
+    "\n",
+    "        x = jnp.concatenate((self.cls_token, x), axis=0)\n",
+    "\n",
+    "        x += self.positional_embedding[\n",
+    "            : x.shape[0]\n",
+    "        ]  # Slice to the same length as x, as the positional embedding may be longer.\n",
+    "\n",
+    "        dropout_key, *attention_keys = jr.split(key, num=self.num_layers + 1)\n",
+    "\n",
+    "        x = self.dropout(x, inference=not enable_dropout, key=dropout_key)\n",
+    "\n",
+    "        for block, attention_key in zip(self.attention_blocks, attention_keys):\n",
+    "            x = block(x, enable_dropout, key=attention_key)\n",
+    "\n",
+    "        x = x[0]  # Select the CLS token.\n",
+    "        x = self.mlp(x)\n",
+    "\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {
+    "id": "agBSRsXVxINn"
+   },
+   "outputs": [],
+   "source": [
+    "@eqx.filter_value_and_grad\n",
+    "def compute_grads(\n",
+    "    model: VisionTransformer, images: jnp.ndarray, labels: jnp.ndarray, key\n",
+    "):\n",
+    "    logits = jax.vmap(model, in_axes=(0, None, 0))(images, True, key)\n",
+    "    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)\n",
+    "\n",
+    "    return jnp.mean(loss)\n",
+    "\n",
+    "\n",
+    "@eqx.filter_jit\n",
+    "def step_model(\n",
+    "    model: VisionTransformer,\n",
+    "    optimizer: optax.GradientTransformation,\n",
+    "    state: optax.OptState,\n",
+    "    images: jnp.ndarray,\n",
+    "    labels: jnp.ndarray,\n",
+    "    key,\n",
+    "):\n",
+    "    loss, grads = compute_grads(model, images, labels, key)\n",
+    "    updates, new_state = optimizer.update(grads, state, model)\n",
+    "\n",
+    "    model = eqx.apply_updates(model, updates)\n",
+    "\n",
+    "    return model, new_state, loss\n",
+    "\n",
+    "\n",
+    "def train(\n",
+    "    model: VisionTransformer,\n",
+    "    optimizer: optax.GradientTransformation,\n",
+    "    state: optax.OptState,\n",
+    "    data_loader: torch.utils.data.DataLoader,\n",
+    "    num_steps: int,\n",
+    "    print_every: int = 1000,\n",
+    "    key=None,\n",
+    "):\n",
+    "    losses = []\n",
+    "\n",
+    "    def infinite_trainloader():\n",
+    "        while True:\n",
+    "            yield from data_loader\n",
+    "\n",
+    "    for step, batch in zip(range(num_steps), infinite_trainloader()):\n",
+    "        images, labels = batch\n",
+    "\n",
+    "        images = images.numpy()\n",
+    "        labels = labels.numpy()\n",
+    "\n",
+    "        key, *subkeys = jr.split(key, num=batch_size + 1)\n",
+    "        subkeys = jnp.array(subkeys)\n",
+    "\n",
+    "        (model, state, loss) = step_model(\n",
+    "            model, optimizer, state, images, labels, subkeys\n",
+    "        )\n",
+    "\n",
+    "        losses.append(loss)\n",
+    "\n",
+    "        if (step % print_every) == 0 or step == num_steps - 1:\n",
+    "            print(f\"Step: {step}/{num_steps}, Loss: {loss}.\")\n",
+    "\n",
+    "    return model, state, losses"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "y3Bm_Xln-rSp"
+   },
+   "outputs": [],
+   "source": [
+    "key = jr.PRNGKey(2003)\n",
+    "\n",
+    "model = VisionTransformer(\n",
+    "    embedding_dim=embedding_dim,\n",
+    "    hidden_dim=hidden_dim,\n",
+    "    num_heads=num_heads,\n",
+    "    num_layers=num_layers,\n",
+    "    dropout_rate=dropout_rate,\n",
+    "    patch_size=patch_size,\n",
+    "    num_patches=num_patches,\n",
+    "    num_classes=num_classes,\n",
+    "    key=key,\n",
+    ")\n",
+    "\n",
+    "optimizer = optax.adamw(\n",
+    "    learning_rate=lr,\n",
+    "    b1=beta1,\n",
+    "    b2=beta2,\n",
+    ")\n",
+    "\n",
+    "state = optimizer.init(eqx.filter(model, eqx.is_array))\n",
+    "\n",
+    "model, state, losses = train(model, optimizer, state, trainloader, num_steps, key=key)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "X4GPbpuMQEB1"
+   },
+   "source": [
+    "And now let's see how the vision transformer performs on the CIFAR10 dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "pZu5pMRZW3tF",
+    "outputId": "d74d06b7-0340-4e4f-b723-8e5d532aecb5",
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Accuracy: 76.2%\n"
+     ]
+    }
+   ],
+   "source": [
+    "accuracies = []\n",
+    "\n",
+    "for images, labels in testloader:\n",
+    "    logits = jax.vmap(functools.partial(model, enable_dropout=False))(\n",
+    "        images.numpy(), key=jax.random.split(key, num=batch_size)\n",
+    "    )\n",
+    "\n",
+    "    predictions = jnp.argmax(logits, axis=-1)\n",
+    "\n",
+    "    accuracy = jnp.mean(predictions == labels.numpy())\n",
+    "\n",
+    "    accuracies.append(accuracy)\n",
+    "\n",
+    "print(f\"Accuracy: {np.sum(accuracies) / len(accuracies) * 100}%\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "K6i8KZl4Ba87"
+   },
+   "source": [
+    "Of course, this is not the best accuracy you can get on CIFAR10, but with more training and hyperparameter tuning, you can get better results using the vision transformer."
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}