From e2782c72a755faa939cf007dafc4c15636c87680 Mon Sep 17 00:00:00 2001
From: Ben Salmon
Date: Thu, 22 Aug 2024 18:11:34 +0000
Subject: [PATCH] smaller model

---
 03_COSDD/solution.ipynb | 88 +++++++++++++++++++++++++++--------------
 1 file changed, 59 insertions(+), 29 deletions(-)

diff --git a/03_COSDD/solution.ipynb b/03_COSDD/solution.ipynb
index 81a6035..88d9dd1 100755
--- a/03_COSDD/solution.ipynb
+++ b/03_COSDD/solution.ipynb
@@ -25,7 +25,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -53,7 +53,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -120,7 +120,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
 "metadata": {
 "tags": [
 "solution"
 ]
 },
@@ -359,12 +359,20 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "real_batch_size = 4\n",
- "n_grad_batches = 4\n",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Effective batch size: 16\n"
+ ]
+ }
+ ],
+ "source": [
+ "real_batch_size = 16\n",
+ "n_grad_batches = 1\n",
 "print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n",
 "crop_size = (256, 256)\n",
 "train_split = 0.9\n",
@@ -464,9 +472,9 @@
 "outputs": [],
 "source": [
 "dimensions = ... ### Insert a value here\n",
- "s_code_channels = 64\n",
+ "s_code_channels = 16\n",
 "\n",
- "n_layers = 6\n",
+ "n_layers = 4\n",
 "z_dims = [s_code_channels // 2] * n_layers\n",
 "downsampling = [1] * n_layers\n",
 "lvae = LadderVAE(\n",
@@ -484,9 +492,9 @@
 " s_code_channels=s_code_channels,\n",
 " kernel_size=5,\n",
 " noise_direction=... ### Insert a value here\n",
- " n_filters=64,\n",
- " n_layers=4,\n",
- " n_gaussians=5,\n",
+ " n_filters=16,\n",
+ " n_layers=3,\n",
+ " n_gaussians=4,\n",
 " dimensions=dimensions,\n",
 ")\n",
 "\n",
@@ -518,24 +526,35 @@
 " data_mean=low_snr.mean(),\n",
 " data_std=low_snr.std(),\n",
 " n_grad_batches=n_grad_batches,\n",
- " checkpointed=True,\n",
+ " checkpointed=False,\n",
 ")"
 ]
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
 "metadata": {
 "tags": [
 "solution"
 ]
 },
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'vae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae'])`.\n",
+ "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ar_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ar_decoder'])`.\n",
+ "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 's_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['s_decoder'])`.\n",
+ "/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'direct_denoiser' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['direct_denoiser'])`.\n"
+ ]
+ }
+ ],
 "source": [
 "dimensions = 2 ### Insert a value here\n",
- "s_code_channels = 64\n",
+ "s_code_channels = 16\n",
 "\n",
- "n_layers = 6\n",
+ "n_layers = 4\n",
 "z_dims = [s_code_channels // 2] * n_layers\n",
 "downsampling = [1] * n_layers\n",
 "lvae = LadderVAE(\n",
@@ -553,8 +572,8 @@
 " s_code_channels=s_code_channels,\n",
 " kernel_size=5,\n",
 " noise_direction=\"x\", ### Insert a value here\n",
- " n_filters=64,\n",
- " n_layers=4,\n",
+ " n_filters=16,\n",
+ " n_layers=3,\n",
 " n_gaussians=4,\n",
 " dimensions=dimensions,\n",
 ")\n",
@@ -587,7 +606,7 @@
 " data_mean=low_snr.mean(),\n",
 " data_std=low_snr.std(),\n",
 " n_grad_batches=n_grad_batches,\n",
- " checkpointed=True,\n",
+ " checkpointed=False,\n",
 ")"
 ]
 },
@@ -613,7 +632,7 @@
 "3. Enter `tensorboard --logdir 05_image_restoration/03_COSDD/checkpoints`\n",
 "4. Finally, open a browser and enter localhost:6006 in the address bar.\n",
 "\n",
- "Once you're in tensorboard, you'll see the training logs of your model and the logs of a model that's been trained for 3.5 hours.\n",
+ "Once you're in tensorboard, you'll see the training logs of your model and the logs of a model that's already been trained for 3.5 hours.\n",
 ""
 ]
 },
@@ -702,7 +721,7 @@
 " devices=1,\n",
 " max_epochs=max_epochs,\n",
 " max_time=max_time, # Remove this time limit to train the model fully\n",
- " log_every_n_steps=len(train_set) // (4 * real_batch_size),\n",
+ " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
 " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
 " precision=\"bf16-mixed\",\n",
 ")"
 ]
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
 "metadata": {
 "tags": [
 "solution"
 ]
 },
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using bfloat16 Automatic Mixed Precision (AMP)\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ }
+ ],
 "source": [
 "model_name = \"mito-confocal\" ### Insert a value here\n",
 "checkpoint_path = os.path.join(\"checkpoints\", model_name)\n",
@@ -732,7 +762,7 @@
 " devices=1,\n",
 " max_epochs=max_epochs,\n",
 " max_time=max_time, # Remove this time limit to train the model fully\n",
- " log_every_n_steps=len(train_set) // (4 * real_batch_size),\n",
+ " log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
 " callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
 " precision=\"bf16-mixed\",\n",
 ")"
 ]
 },
@@ -783,7 +813,7 @@
 "metadata": {},
 "source": [
 "### 2.1. Load test data\n",
- "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference."
+ "The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference."
 ]
 },
@@ -793,7 +823,7 @@
 "outputs": [],
 "source": [
 "lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n",
- "n_test_images = 10\n",
+ "n_test_images = 2\n",
 "# load the data\n",
 "test_set = tifffile.imread(lowsnr_path)\n",
 "test_set = test_set[:n_test_images, np.newaxis]\n",