diff --git a/examples/example_training_LVAE_split.ipynb b/examples/example_training_LVAE_split.ipynb new file mode 100644 index 00000000..d83d5f75 --- /dev/null +++ b/examples/example_training_LVAE_split.ipynb @@ -0,0 +1,439 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import socket\n", + "from pathlib import Path\n", + "from typing import Optional, Literal, Union\n", + "\n", + "import numpy as np\n", + "import torch \n", + "from torch.utils.data import Dataset, DataLoader\n", + "from pytorch_lightning import Trainer\n", + "from pytorch_lightning.loggers import WandbLogger\n", + "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping\n", + "\n", + "from careamics.config import VAEAlgorithmConfig\n", + "from careamics.config.architectures import LVAEModel\n", + "from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel\n", + "from careamics.config.likelihood_model import (\n", + " GaussianLikelihoodConfig,\n", + " NMLikelihoodConfig,\n", + ")\n", + "from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig\n", + "from careamics.lightning import VAEModule\n", + "from careamics.models.lvae.noise_models import noise_model_factory" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set some parameters for the current training simulation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_size: int = 64\n", + "\"\"\"Spatial size of the input image.\"\"\"\n", + "target_channels: int = 2\n", + "\"\"\"Number of channels in the target image.\"\"\"\n", + "multiscale_count: int = 5\n", + "\"\"\"The number of LC inputs plus one (the actual input).\"\"\"\n", + "predict_logvar: Optional[Literal[\"pixelwise\"]] = \"pixelwise\"\n", + "\"\"\"Whether to compute also the log-variance as LVAE output.\"\"\"\n", + "loss_type: Optional[Literal[\"musplit\", \"denoisplit\", \"denoisplit_musplit\"]] = \"musplit\"\n", + "\"\"\"The type of reconstruction loss (i.e., likelihood) to use.\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Create `Dataset` and `Dataloader`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.1. Dummy Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class DummyDataset(Dataset):\n", + " def __init__(\n", + " self, \n", + " img_size: int = 64, \n", + " target_ch: int = 1,\n", + " multiscale_count: int = 1,\n", + " ):\n", + " self.num_samples = 100\n", + " self.img_size = img_size\n", + " self.target_ch = target_ch\n", + " self.multiscale_count = multiscale_count\n", + " \n", + " def __len__(self):\n", + " return self.num_samples\n", + " \n", + " def __getitem__(self, idx: int):\n", + " input_ = torch.randn(self.multiscale_count, self.img_size, self.img_size)\n", + " target = torch.randn(self.target_ch, self.img_size, self.img_size)\n", + " return input_, target\n", + "\n", + "def dummy_dataloader(\n", + " batch_size: int = 1,\n", + " img_size: int = 64,\n", + " target_ch: int = 1,\n", + " multiscale_count: int = 1,\n", + "):\n", + " dataset = DummyDataset(\n", + " img_size=img_size,\n", + " target_ch=target_ch,\n", + " multiscale_count=multiscale_count,\n", + " )\n", + " return DataLoader(dataset, batch_size=batch_size, num_workers=3, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dloader = dummy_dataloader(\n", + " img_size=img_size,\n", + " target_ch=target_channels,\n", + " multiscale_count=multiscale_count,\n", + ")\n", + "input_, target = next(iter(dloader))\n", + "input_.shape, target.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2. Real Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Instantiate the lightning module" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_dummy_noise_model(\n", + " save_path: Optional[Union[Path, str]] = None,\n", + " n_gaussians: int = 3,\n", + " n_coeffs: int = 3,\n", + ") -> Path:\n", + " weights = np.random.rand(3*n_gaussians, n_coeffs)\n", + " nm_dict = {\n", + " \"trained_weight\": weights,\n", + " \"min_signal\": np.array([0]),\n", + " \"max_signal\": np.array([2**16 - 1]),\n", + " \"min_sigma\": 0.125,\n", + " }\n", + " out_path = Path(save_path) / \"dummy_noise_model.npz\"\n", + " np.savez(out_path, **nm_dict)\n", + " return out_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_split_lightning_model(\n", + " algorithm: str,\n", + " loss_type: str,\n", + " multiscale_count: int = 1,\n", + " predict_logvar: Optional[Literal[\"pixelwise\"]] = None,\n", + " target_ch: int = 1,\n", + " NM_path: Optional[Path] = None,\n", + ") -> VAEModule:\n", + " \"\"\"Instantiate the muSplit lightining model.\"\"\"\n", + " lvae_config = LVAEModel(\n", + " architecture=\"LVAE\",\n", + " input_shape=64,\n", + " multiscale_count=multiscale_count,\n", + " z_dims=[128, 128, 128, 128],\n", + " output_channels=target_ch,\n", + " predict_logvar=predict_logvar,\n", + " )\n", + "\n", + " # gaussian likelihood\n", + " if loss_type in [\"musplit\", \"denoisplit_musplit\"]:\n", + " gaussian_lik_config = GaussianLikelihoodConfig(\n", + " predict_logvar=predict_logvar,\n", + " logvar_lowerbound=0.0,\n", + " )\n", + " else:\n", + " gaussian_lik_config = None\n", + " # noise model likelihood\n", + " if loss_type in [\"denoisplit\", \"denoisplit_musplit\"]:\n", + " if NM_path is None:\n", + " NM_path = create_dummy_noise_model(Path(\"./\"), 3, 3)\n", + " gmm = GaussianMixtureNMConfig(\n", + " model_type=\"GaussianMixtureNoiseModel\",\n", + " path=NM_path,\n", + " )\n", + " noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * target_ch)\n", + " nm = noise_model_factory(noise_model_config)\n", + " nm_lik_config = NMLikelihoodConfig(noise_model=nm)\n", + " else:\n", + " noise_model_config = None\n", + " nm_lik_config = None\n", + "\n", + " vae_config = VAEAlgorithmConfig(\n", + " algorithm_type=\"vae\",\n", + " algorithm=algorithm,\n", + " loss=loss_type,\n", + " model=lvae_config,\n", + " gaussian_likelihood_model=gaussian_lik_config,\n", + " noise_model=noise_model_config,\n", + " noise_model_likelihood_model=nm_lik_config,\n", + " )\n", + "\n", + " return VAEModule(\n", + " algorithm_config=vae_config,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "algo = \"musplit\" if loss_type == \"musplit\" else \"denoisplit\"\n", + "lightning_model = create_split_lightning_model(\n", + " algorithm=algo,\n", + " loss_type=loss_type,\n", + " multiscale_count=multiscale_count,\n", + " predict_logvar=predict_logvar,\n", + " target_ch=target_channels,\n", + " NM_path=None\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Set utils for training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "\n", + "from careamics.lvae_training.train_utils import get_new_model_version\n", + "\n", + "def get_new_model_version(model_dir: Union[Path, str]) -> int:\n", + " \"\"\"Create a unique version ID for a new model run.\"\"\"\n", + " versions = []\n", + " for version_dir in os.listdir(model_dir):\n", + " try:\n", + " versions.append(int(version_dir))\n", + " except:\n", + " print(\n", + " f\"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed\"\n", + " )\n", + " exit()\n", + " if len(versions) == 0:\n", + " return \"0\"\n", + " return f\"{max(versions) + 1}\"\n", + "\n", + "def get_workdir(\n", + " root_dir: str,\n", + " model_name: str,\n", + ") -> tuple[Path, Path]:\n", + " \"\"\"Get the workdir for the current model.\n", + " \n", + " It has the following structure: \"root_dir/YYMM/model_name/version\"\n", + " \"\"\"\n", + " rel_path = datetime.now().strftime(\"%y%m\")\n", + " cur_workdir = os.path.join(root_dir, rel_path)\n", + " Path(cur_workdir).mkdir(exist_ok=True)\n", + "\n", + " rel_path = os.path.join(rel_path, model_name)\n", + " cur_workdir = os.path.join(root_dir, rel_path)\n", + " Path(cur_workdir).mkdir(exist_ok=True)\n", + "\n", + " rel_path = os.path.join(rel_path, get_new_model_version(cur_workdir))\n", + " cur_workdir = os.path.join(root_dir, rel_path)\n", + " try:\n", + " Path(cur_workdir).mkdir(exist_ok=False)\n", + " except FileExistsError:\n", + " print(\n", + " f\"Workdir {cur_workdir} already exists.\"\n", + " )\n", + " return cur_workdir, rel_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ROOT_DIR = \"/group/jug/federico/careamics_training/refac_v2/\"\n", + "workdir, exp_tag = get_workdir(ROOT_DIR, \"dummy_debugging\")\n", + "print(f\"Current workdir: {workdir}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the logger\n", + "custom_logger = WandbLogger(\n", + " name=os.path.join(socket.gethostname(), exp_tag),\n", + " save_dir=workdir,\n", + " project=\"careamics_debugging_LVAE\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define callbacks (e.g., ModelCheckpoint, EarlyStopping, etc.)\n", + "early_stopping_config = EarlyStoppingModel(\n", + " monitor=\"val_loss\",\n", + " min_delta=1e-6,\n", + " patience=10,\n", + " mode=\"min\",\n", + " verbose=True,\n", + ")\n", + "checkpoint_config = CheckpointModel(\n", + " monitor=\"val_loss\",\n", + " save_top_k=2,\n", + " mode=\"min\",\n", + ")\n", + "custom_callbacks = [\n", + " EarlyStopping(**early_stopping_config.model_dump()), \n", + " ModelCheckpoint(**checkpoint_config.model_dump()),\n", + " LearningRateMonitor(logging_interval=\"epoch\")\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save AlgorithmConfig\n", + "with open(os.path.join(workdir, \"algorithm_config.json\"), \"w\") as f:\n", + " f.write(lightning_model.algorithm_config.model_dump_json())\n", + "\n", + "custom_logger.experiment.config.update(\n", + " lightning_model.algorithm_config.model_dump() \n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " max_epochs=10,\n", + " accelerator=\"cpu\",\n", + " enable_progress_bar=True,\n", + " logger=custom_logger,\n", + " callbacks=custom_callbacks,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(\n", + " model=lightning_model,\n", + " train_dataloaders=dloader,\n", + " val_dataloaders=dloader,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "train_lvae", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/careamics/config/architectures/lvae_model.py b/src/careamics/config/architectures/lvae_model.py index f3fd507b..0f8c0f31 100644 --- a/src/careamics/config/architectures/lvae_model.py +++ b/src/careamics/config/architectures/lvae_model.py @@ -8,6 +8,7 @@ from .architecture_model import ArchitectureModel +# TODO: it is quite confusing to call this LVAEModel, as it is basically a config class LVAEModel(ArchitectureModel): """LVAE model.""" diff --git a/src/careamics/config/nm_model.py b/src/careamics/config/nm_model.py index ffcfa607..5e8c7f34 100644 --- a/src/careamics/config/nm_model.py +++ b/src/careamics/config/nm_model.py @@ -14,7 +14,10 @@ class GaussianMixtureNMConfig(BaseModel): """Gaussian mixture noise model.""" model_config = ConfigDict( - validate_assignment=True, arbitrary_types_allowed=True, extra="allow" + protected_namespaces=(), + validate_assignment=True, + arbitrary_types_allowed=True, + extra="allow", ) # model type model_type: Literal["GaussianMixtureNoiseModel"] diff --git a/src/careamics/config/vae_algorithm_model.py b/src/careamics/config/vae_algorithm_model.py index d9bddb36..fc021504 100644 --- a/src/careamics/config/vae_algorithm_model.py +++ b/src/careamics/config/vae_algorithm_model.py @@ -30,12 +30,18 @@ class VAEAlgorithmConfig(BaseModel): Attributes ---------- - algorithm : Literal["n2v", "custom"] + algorithm : algorithm: Literal["musplit", "denoisplit", "custom"] Algorithm to use. - loss : Literal["n2v", "mae", "mse"] + loss : Literal["musplit", "denoisplit", "denoisplit_musplit"] Loss function to use. - model : Union[UNetModel, LVAEModel, CustomModel] + model : Union[LVAEModel, CustomModel] Model architecture to use. + noise_model: Optional[MultiChannelNmModel] + Noise model to use. + noise_model_likelihood_model: Optional[NMLikelihoodModel] + Noise model likelihood model to use. + gaussian_likelihood_model: Optional[GaussianLikelihoodModel] + Gaussian likelihood model to use. optimizer : OptimizerModel, optional Optimizer to use. lr_scheduler : LrSchedulerModel, optional @@ -66,9 +72,10 @@ class VAEAlgorithmConfig(BaseModel): # - values can still be passed as strings and they will be cast to Enum algorithm_type: Literal["vae"] algorithm: Literal["musplit", "denoisplit", "custom"] - loss: Literal["musplit", "denoisplit"] + loss: Literal["musplit", "denoisplit", "denoisplit_musplit"] model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture") + # TODO: these are configs, change naming of attrs noise_model: Optional[MultiChannelNMConfig] = None noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None @@ -94,25 +101,63 @@ def algorithm_cross_validation(self: Self) -> Self: raise ValueError( f"Algorithm {self.algorithm} only supports loss `musplit`." ) - if self.model.predict_logvar != "pixelwise": - raise ValueError( - "Algorithm `musplit` only supports `predict_logvar` as `pixelwise`." - ) - # TODO add more checks if self.algorithm == SupportedAlgorithm.DENOISPLIT: - if self.loss != SupportedLoss.DENOISPLIT: + if self.loss not in [ + SupportedLoss.DENOISPLIT, + SupportedLoss.DENOISPLIT_MUSPLIT, + ]: raise ValueError( - f"Algorithm {self.algorithm} only supports loss `denoisplit`." + f"Algorithm {self.algorithm} only supports loss `denoisplit` " + "or `denoisplit_musplit." ) - if self.model.predict_logvar is not None: + if ( + self.loss == SupportedLoss.DENOISPLIT + and self.model.predict_logvar is not None + ): raise ValueError( - "Algorithm `denoisplit` only supports `predict_logvar` as `None`." + "Algorithm `denoisplit` with loss `denoisplit` only supports " + "`predict_logvar` as `None`." ) - if self.noise_model is None: raise ValueError("Algorithm `denoisplit` requires a noise model.") - # TODO: what if algorithm is not musplit or denoisplit + # TODO: what if algorithm is not musplit or denoisplit (HDN?) + return self + + @model_validator(mode="after") + def output_channels_validation(self: Self) -> Self: + """Validate the consistency between number of out channels and noise models. + + Returns + ------- + Self + The validated model. + """ + if self.noise_model is not None: + assert self.model.output_channels == len(self.noise_model.noise_models), ( + f"Number of output channels ({self.model.output_channels}) must match " + f"the number of noise models ({len(self.noise_model.noise_models)})." + ) + return self + + @model_validator(mode="after") + def predict_logvar_validation(self: Self) -> Self: + """Validate the consistency of `predict_logvar` throughout the model. + + Returns + ------- + Self + The validated model. + """ + if self.gaussian_likelihood_model is not None: + assert ( + self.model.predict_logvar + == self.gaussian_likelihood_model.predict_logvar + ), ( + f"Model `predict_logvar` ({self.model.predict_logvar}) must match " + "Gaussian likelihood model `predict_logvar` " + f"({self.gaussian_likelihood_model.predict_logvar}).", + ) return self def __str__(self) -> str: diff --git a/src/careamics/lightning/lightning_module.py b/src/careamics/lightning/lightning_module.py index d8619c6a..dc20ba7a 100644 --- a/src/careamics/lightning/lightning_module.py +++ b/src/careamics/lightning/lightning_module.py @@ -1,7 +1,8 @@ """CAREamics Lightning module.""" -from typing import Any, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union +import numpy as np import pytorch_lightning as L from torch import Tensor, nn @@ -27,6 +28,7 @@ ) from careamics.models.model_factory import model_factory from careamics.transforms import Denormalize, ImageRestorationTTA +from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr from careamics.utils.torch_utils import get_optimizer, get_scheduler NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel] @@ -225,7 +227,7 @@ class VAEModule(L.LightningModule): Parameters ---------- - algorithm_config : Union[AlgorithmModel, dict] + algorithm_config : Union[VAEAlgorithmConfig, dict] Algorithm configuration. Attributes @@ -261,6 +263,9 @@ def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None: else algorithm_config ) + # TODO: log algorithm config + # self.save_hyperparameters(self.algorithm_config.model_dump()) + # create model and loss function self.model: nn.Module = model_factory(self.algorithm_config.model) self.noise_model: NoiseModel = noise_model_factory( @@ -285,28 +290,40 @@ def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None: self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters - def forward(self, x: Any) -> Any: + # initialize running PSNR + self.running_psnr = [ + RunningPSNR() for _ in range(self.algorithm_config.model.output_channels) + ] + + def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]: """Forward pass. Parameters ---------- - x : Any - Input tensor. + x : Tensor + Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the + number of lateral inputs. Returns ------- - Any - Output tensor. + tuple[Tensor, dict[str, Any]] + A tuple with the output tensor and additional data from the top-down pass. """ return self.model(x) # TODO Different model can have more than one output - def training_step(self, batch: Tensor, batch_idx: Any) -> Any: + def training_step( + self, batch: tuple[Tensor, Tensor], batch_idx: Any + ) -> Optional[dict[str, Tensor]]: """Training step. Parameters ---------- - batch : Tensor - Input batch. + batch : tuple[Tensor, Tensor] + Input batch. It is a tuple with the input tensor and the target tensor. + The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the + number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), + where C is the number of target channels (e.g., 1 in HDN, >1 in + muSplit/denoiSplit). batch_idx : Any Batch index. @@ -315,9 +332,10 @@ def training_step(self, batch: Tensor, batch_idx: Any) -> Any: Any Loss value. """ - x, *aux = batch # TODO: check what is a `batch` + x, target = batch + + # Forward pass out = self.model(x) - target = aux[0] # Update loss parameters # TODO rethink loss parameters @@ -326,31 +344,49 @@ def training_step(self, batch: Tensor, batch_idx: Any) -> Any: # Compute loss loss = self.loss_func(out, target, self.loss_parameters) # TODO ugly ? - self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + # Logging + # TODO: implement a separate logging method? + self.log_dict(loss, on_step=True, on_epoch=True) + # self.log("lr", self, on_epoch=True) return loss - def validation_step(self, batch: Tensor, batch_idx: Any) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None: """Validation step. Parameters ---------- - batch : Tensor - Input batch. + batch : tuple[Tensor, Tensor] + Input batch. It is a tuple with the input tensor and the target tensor. + The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the + number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X), + where C is the number of target channels (e.g., 1 in HDN, >1 in + muSplit/denoiSplit). batch_idx : Any Batch index. """ - x, *aux = batch + x, target = batch + + # Forward pass out = self.model(x) - val_loss = self.loss_func(out, *aux) - # log validation loss - self.log_dict( - val_loss, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) + # Compute loss + loss = self.loss_func(out, target, self.loss_parameters) + + # Logging + # Rename val_loss dict + loss = {"_".join(["val", k]): v for k, v in loss.items()} + self.log_dict(loss, on_epoch=True, prog_bar=True) + curr_psnr = self.compute_val_psnr(out, target) + for i, psnr in enumerate(curr_psnr): + self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True) + + def on_validation_epoch_end(self) -> None: + """Validation epoch end.""" + psnr_ = self.reduce_running_psnr() + if psnr_ is not None: + self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True) + else: + self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True) def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: """Prediction step. @@ -420,7 +456,99 @@ def configure_optimizers(self) -> Any: "monitor": "val_loss", # otherwise triggers MisconfigurationException } + # TODO: find a way to move the following methods to a separate module + # TODO: this same operation is done in many other places, like in loss_func + # should we refactor LadderVAE so that it already outputs + # tuple(`mean`, `logvar`, `td_data`)? + def get_reconstructed_tensor( + self, model_outputs: tuple[Tensor, dict[str, Any]] + ) -> Tensor: + """Get the reconstructed tensor from the LVAE model outputs. + + Parameters + ---------- + model_outputs : tuple[Tensor, dict[str, Any]] + Model outputs. It is a tuple with a tensor representing the predicted mean + and (optionally) logvar, and the top-down data dictionary. + + Returns + ------- + Tensor + Reconstructed tensor, i.e., the predicted mean. + """ + predictions, _ = model_outputs + if self.model.predict_logvar is None: + return predictions + elif self.model.predict_logvar == "pixelwise": + return predictions.chunk(2, dim=1)[0] + + def compute_val_psnr( + self, + model_output: tuple[Tensor, dict[str, Any]], + target: Tensor, + psnr_func: Callable = scale_invariant_psnr, + ) -> list[float]: + """Compute the PSNR for the current validation batch. + + Parameters + ---------- + model_output : tuple[Tensor, dict[str, Any]] + Model output, a tuple with the predicted mean and (optionally) logvar, + and the top-down data dictionary. + target : Tensor + Target tensor. + psnr_func : Callable, optional + PSNR function to use, by default `scale_invariant_psnr`. + Returns + ------- + list[float] + PSNR for each channel in the current batch. + """ + out_channels = target.shape[1] + + # get the reconstructed image + recons_img = self.get_reconstructed_tensor(model_output) + + # update running psnr + for i in range(out_channels): + self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i]) + + # compute psnr for each channel in the current batch + # TODO: this doesn't need do be a method of this class + # and hence can be moved to a separate module + return [ + psnr_func( + gt=target[:, i].clone().detach().cpu().numpy(), + pred=recons_img[:, i].clone().detach().cpu().numpy(), + ) + for i in range(out_channels) + ] + + def reduce_running_psnr(self) -> Optional[float]: + """Reduce the running PSNR statistics and reset the running PSNR. + + Returns + ------- + Optional[float] + Running PSNR averaged over the different output channels. + """ + psnr_arr = [] # type: ignore + for i in range(len(self.running_psnr)): + psnr = self.running_psnr[i].get() + if psnr is None: + psnr_arr = None # type: ignore + break + psnr_arr.append(psnr.cpu().numpy()) + self.running_psnr[i].reset() + # TODO: this line forces it to be a method of this class + # alternative is returning also the reset `running_psnr` + if psnr_arr is not None: + psnr = np.mean(psnr_arr) + return psnr + + +# TODO: make this LVAE compatible (?) def create_careamics_module( algorithm_type: Literal["fcn"], algorithm: Union[SupportedAlgorithm, str], @@ -432,7 +560,7 @@ def create_careamics_module( lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau", lr_scheduler_parameters: Optional[dict] = None, ) -> Union[FCNModule, VAEModule]: - """Create a CAREamics Lithgning module. + """Create a CAREamics Lightning module. This function exposes parameters used to create an AlgorithmModel instance, triggering parameters validation. diff --git a/src/careamics/losses/__init__.py b/src/careamics/losses/__init__.py index f79497e5..d2372178 100644 --- a/src/careamics/losses/__init__.py +++ b/src/careamics/losses/__init__.py @@ -7,8 +7,9 @@ "n2v_loss", "denoisplit_loss", "musplit_loss", + "denoisplit_musplit_loss", ] from .fcn.losses import mae_loss, mse_loss, n2v_loss from .loss_factory import loss_factory -from .lvae.losses import denoisplit_loss, musplit_loss +from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss diff --git a/src/careamics/models/lvae/likelihoods.py b/src/careamics/models/lvae/likelihoods.py index 3771bc01..47d1a4ec 100644 --- a/src/careamics/models/lvae/likelihoods.py +++ b/src/careamics/models/lvae/likelihoods.py @@ -40,6 +40,9 @@ def likelihood_factory( nn.Module The likelihood module. """ + if config is None: + return None + if isinstance(config, GaussianLikelihoodConfig): return GaussianLikelihood( predict_logvar=config.predict_logvar, diff --git a/src/careamics/models/lvae/noise_models.py b/src/careamics/models/lvae/noise_models.py index deac604e..3c1d885f 100644 --- a/src/careamics/models/lvae/noise_models.py +++ b/src/careamics/models/lvae/noise_models.py @@ -36,24 +36,23 @@ def noise_model_factory( """ if model_config: noise_models = [] - for nm in model_config.noise_models: - if nm.path: - if nm.model_type == "GaussianMixtureNoiseModel": - noise_models.append(GaussianMixtureNoiseModel(nm)) + for nm_config in model_config.noise_models: + if nm_config.path: + if nm_config.model_type == "GaussianMixtureNoiseModel": + noise_models.append(GaussianMixtureNoiseModel(nm_config)) else: raise NotImplementedError( - f"Model {nm.model_type} is not implemented" + f"Model {nm_config.model_type} is not implemented" ) else: # TODO this means signal/obs are provided. Controlled in pydantic model # TODO train a new model. Config should always be provided? - if nm.model_type == "GaussianMixtureNoiseModel": - # TODO one model for each channel all make this choise inside the model? - trained_nm = train_gm_noise_model(nm) + if nm_config.model_type == "GaussianMixtureNoiseModel": + trained_nm = train_gm_noise_model(nm_config) noise_models.append(trained_nm) else: raise NotImplementedError( - f"Model {nm.model_type} is not implemented" + f"Model {nm_config.model_type} is not implemented" ) return MultiChannelNoiseModel(noise_models) return None @@ -77,7 +76,7 @@ def train_gm_noise_model( # TODO any training params ? Different channels ? noise_model = GaussianMixtureNoiseModel(model_config) # TODO revisit config unpacking - noise_model.train(noise_model.signal, noise_model.observation) + noise_model.train_noise_model(noise_model.signal, noise_model.observation) return noise_model @@ -459,7 +458,7 @@ def forward(self, x, y): return x, y # TODO taken from pn2v. Ashesh needs to clarify this - def train( + def train_noise_model( self, signal, observation, diff --git a/src/careamics/models/model_factory.py b/src/careamics/models/model_factory.py index 56a47958..c235a6bc 100644 --- a/src/careamics/models/model_factory.py +++ b/src/careamics/models/model_factory.py @@ -4,20 +4,34 @@ Model creation factory functions. """ -from typing import Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Union import torch -from ..config.architectures import CustomModel, UNetModel, get_custom_model -from ..config.support import SupportedArchitecture -from ..utils import get_logger -from .unet import UNet +from careamics.config.architectures import ( + CustomModel, + get_custom_model, +) +from careamics.config.support import SupportedArchitecture +from careamics.models.lvae import LadderVAE as LVAE +from careamics.models.unet import UNet +from careamics.utils import get_logger + +if TYPE_CHECKING: + from careamics.config.architectures import ( + CustomModel, + LVAEModel, + UNetModel, + ) + logger = get_logger(__name__) def model_factory( - model_configuration: Union[UNetModel, CustomModel], + model_configuration: Union[UNetModel, LVAEModel, CustomModel], ) -> torch.nn.Module: """ Deep learning model factory. @@ -41,6 +55,8 @@ def model_factory( """ if model_configuration.architecture == SupportedArchitecture.UNET: return UNet(**model_configuration.model_dump()) + elif model_configuration.architecture == SupportedArchitecture.LVAE: + return LVAE(**model_configuration.model_dump()) elif model_configuration.architecture == SupportedArchitecture.CUSTOM: assert isinstance(model_configuration, CustomModel) model = get_custom_model(model_configuration.name) diff --git a/src/careamics/utils/metrics.py b/src/careamics/utils/metrics.py index 2389316d..ad44914c 100644 --- a/src/careamics/utils/metrics.py +++ b/src/careamics/utils/metrics.py @@ -4,7 +4,7 @@ This module contains various metrics and a metrics tracking class. """ -from typing import Union +from typing import Optional, Union import numpy as np import torch @@ -113,3 +113,76 @@ def scale_invariant_psnr( range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt) gt_ = _zero_mean(gt) / np.std(gt) return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter) + + +class RunningPSNR: + """Compute the running PSNR during validation step in training. + + This class allows to compute the PSNR on the entire validation set + one batch at the time. + + Attributes + ---------- + N : int + Number of elements seen so far during the epoch. + mse_sum : float + Running sum of the MSE over the N elements seen so far. + max : float + Running max value of the N target images seen so far. + min : float + Running min value of the N target images seen so far. + """ + + def __init__(self): + """Constructor.""" + self.N = None + self.mse_sum = None + self.max = self.min = None + self.reset() + + def reset(self): + """Reset the running PSNR computation. + + Usually called at the end of each epoch. + """ + self.mse_sum = 0 + self.N = 0 + self.max = self.min = None + + def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None: + """Update the running PSNR statistics given a new batch. + + Parameters + ---------- + rec : torch.Tensor + Reconstructed batch. + tar : torch.Tensor + Target batch. + """ + ins_max = torch.max(tar).item() + ins_min = torch.min(tar).item() + if self.max is None: + assert self.min is None + self.max = ins_max + self.min = ins_min + else: + self.max = max(self.max, ins_max) + self.min = min(self.min, ins_min) + + mse = (rec - tar) ** 2 + elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1) + self.mse_sum += torch.nansum(elementwise_mse) + self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse)) + + def get(self) -> Optional[torch.Tensor]: + """Get the actual PSNR value given the running statistics. + + Returns + ------- + Optional[torch.Tensor] + PSNR value. + """ + if self.N == 0 or self.N is None: + return None + rmse = torch.sqrt(self.mse_sum / self.N) + return 20 * torch.log10((self.max - self.min) / rmse) diff --git a/tests/lightning/test_LVAE_lightning_module.py b/tests/lightning/test_LVAE_lightning_module.py new file mode 100644 index 00000000..985f206e --- /dev/null +++ b/tests/lightning/test_LVAE_lightning_module.py @@ -0,0 +1,579 @@ +from contextlib import nullcontext as does_not_raise +from pathlib import Path +from typing import Callable, Literal, Optional + +import numpy as np +import pytest +import torch +from pydantic import ValidationError +from pytorch_lightning import Trainer +from torch.utils.data import DataLoader, Dataset + +from careamics.config import VAEAlgorithmConfig +from careamics.config.architectures import LVAEModel +from careamics.config.likelihood_model import ( + GaussianLikelihoodConfig, + NMLikelihoodConfig, +) +from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig +from careamics.lightning import VAEModule +from careamics.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss +from careamics.models.lvae.likelihoods import GaussianLikelihood, NoiseModelLikelihood +from careamics.models.lvae.noise_models import ( + MultiChannelNoiseModel, + noise_model_factory, +) +from careamics.utils.metrics import RunningPSNR + + +# TODO: move to conftest.py as pytest.fixture +def create_dummy_noise_model( + tmp_path: Path, + n_gaussians: int = 3, + n_coeffs: int = 3, +) -> None: + weights = np.random.rand(3 * n_gaussians, n_coeffs) + nm_dict = { + "trained_weight": weights, + "min_signal": np.array([0]), + "max_signal": np.array([2**16 - 1]), + "min_sigma": 0.125, + } + np.savez(tmp_path / "dummy_noise_model.npz", **nm_dict) + + +# TODO: move to conftest.py as pytest.fixture +# it can be split into modules for more clarity (?) +def create_split_lightning_model( + tmp_path: Path, + algorithm: str, + loss_type: str, + multiscale_count: int = 1, + predict_logvar: Optional[Literal["pixelwise"]] = None, + target_ch: int = 1, +) -> VAEModule: + """Instantiate the muSplit lightining model.""" + lvae_config = LVAEModel( + architecture="LVAE", + input_shape=64, + multiscale_count=multiscale_count, + z_dims=[128, 128, 128, 128], + output_channels=target_ch, + predict_logvar=predict_logvar, + ) + + # gaussian likelihood + if loss_type in ["musplit", "denoisplit_musplit"]: + gaussian_lik_config = GaussianLikelihoodConfig( + predict_logvar=predict_logvar, + logvar_lowerbound=0.0, + ) + else: + gaussian_lik_config = None + # noise model likelihood + if loss_type in ["denoisplit", "denoisplit_musplit"]: + create_dummy_noise_model(tmp_path, 3, 3) + gmm = GaussianMixtureNMConfig( + model_type="GaussianMixtureNoiseModel", + path=tmp_path / "dummy_noise_model.npz", + ) + noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * target_ch) + nm = noise_model_factory(noise_model_config) + nm_lik_config = NMLikelihoodConfig(noise_model=nm) + else: + noise_model_config = None + nm_lik_config = None + + vae_config = VAEAlgorithmConfig( + algorithm_type="vae", + algorithm=algorithm, + loss=loss_type, + model=lvae_config, + gaussian_likelihood_model=gaussian_lik_config, + noise_model=noise_model_config, + noise_model_likelihood_model=nm_lik_config, + ) + + return VAEModule( + algorithm_config=vae_config, + ) + + +class DummyDataset(Dataset): + def __init__( + self, + img_size: int = 64, + target_ch: int = 1, + multiscale_count: int = 1, + ): + self.num_samples = 3 + self.img_size = img_size + self.target_ch = target_ch + self.multiscale_count = multiscale_count + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx: int): + input_ = torch.randn(self.multiscale_count, self.img_size, self.img_size) + target = torch.randn(self.target_ch, self.img_size, self.img_size) + return input_, target + + +def create_dummy_dloader( + batch_size: int = 1, + img_size: int = 64, + target_ch: int = 1, + multiscale_count: int = 1, +): + dataset = DummyDataset( + img_size=img_size, + target_ch=target_ch, + multiscale_count=multiscale_count, + ) + return DataLoader(dataset, batch_size=batch_size, shuffle=False) + + +@pytest.mark.parametrize( + "multiscale_count, predict_logvar, loss_type, exp_error", + [ + (1, None, "musplit", does_not_raise()), + (1, "pixelwise", "musplit", does_not_raise()), + (5, None, "musplit", does_not_raise()), + (5, "pixelwise", "musplit", does_not_raise()), + (1, None, "denoisplit", pytest.raises(ValueError)), + ], +) +def test_musplit_lightining_init( + multiscale_count: int, + predict_logvar: str, + loss_type: str, + exp_error: Callable, +): + lvae_config = LVAEModel( + architecture="LVAE", + input_shape=64, + multiscale_count=multiscale_count, + z_dims=[128, 128, 128, 128], + output_channels=3, + predict_logvar=predict_logvar, + ) + + likelihood_config = GaussianLikelihoodConfig( + predict_logvar=predict_logvar, + logvar_lowerbound=0.0, + ) + + with exp_error: + vae_config = VAEAlgorithmConfig( + algorithm_type="vae", + algorithm="musplit", + loss=loss_type, + model=lvae_config, + gaussian_likelihood_model=likelihood_config, + ) + lightning_model = VAEModule( + algorithm_config=vae_config, + ) + assert lightning_model is not None + assert isinstance(lightning_model.model, torch.nn.Module) + assert lightning_model.noise_model is None + assert lightning_model.noise_model_likelihood is None + assert isinstance(lightning_model.gaussian_likelihood, GaussianLikelihood) + assert lightning_model.loss_func == musplit_loss + + +@pytest.mark.parametrize( + "multiscale_count, predict_logvar, target_ch, nm_cnt, loss_type, exp_error", + [ + (1, None, 1, 1, "denoisplit", does_not_raise()), + (5, None, 1, 1, "denoisplit", does_not_raise()), + (1, None, 3, 3, "denoisplit", does_not_raise()), + (5, None, 3, 3, "denoisplit", does_not_raise()), + (1, None, 2, 4, "denoisplit", pytest.raises(ValidationError)), + (1, None, 1, 1, "denoisplit_musplit", does_not_raise()), + (1, None, 3, 3, "denoisplit_musplit", does_not_raise()), + (1, "pixelwise", 1, 1, "denoisplit_musplit", does_not_raise()), + (1, "pixelwise", 3, 3, "denoisplit_musplit", does_not_raise()), + (5, None, 1, 1, "denoisplit_musplit", does_not_raise()), + (5, None, 3, 3, "denoisplit_musplit", does_not_raise()), + (5, "pixelwise", 1, 1, "denoisplit_musplit", does_not_raise()), + (5, "pixelwise", 3, 3, "denoisplit_musplit", does_not_raise()), + (5, "pixelwise", 2, 4, "denoisplit_musplit", pytest.raises(ValidationError)), + (1, "pixelwise", 1, 1, "denoisplit", pytest.raises(ValidationError)), + (1, None, 1, 1, "musplit", pytest.raises(ValueError)), + ], +) +def test_denoisplit_lightining_init( + tmp_path: Path, + multiscale_count: int, + predict_logvar: str, + target_ch: int, + nm_cnt: int, + loss_type: str, + exp_error: Callable, +): + # Create the model config + lvae_config = LVAEModel( + architecture="LVAE", + input_shape=64, + multiscale_count=multiscale_count, + z_dims=[128, 128, 128, 128], + output_channels=target_ch, + predict_logvar=predict_logvar, + ) + + # Create the likelihood config(s) + # gaussian + if loss_type == "denoisplit_musplit": + gaussian_lik_config = GaussianLikelihoodConfig( + predict_logvar=predict_logvar, + logvar_lowerbound=0.0, + ) + else: + gaussian_lik_config = None + # noise model + create_dummy_noise_model(tmp_path, 3, 3) + gmm = GaussianMixtureNMConfig( + model_type="GaussianMixtureNoiseModel", + path=tmp_path / "dummy_noise_model.npz", + ) + noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * nm_cnt) + nm = noise_model_factory(noise_model_config) + nm_lik_config = NMLikelihoodConfig(noise_model=nm) + + with exp_error: + vae_config = VAEAlgorithmConfig( + algorithm_type="vae", + algorithm="denoisplit", + loss=loss_type, + model=lvae_config, + gaussian_likelihood_model=gaussian_lik_config, + noise_model=noise_model_config, + noise_model_likelihood_model=nm_lik_config, + ) + lightning_model = VAEModule( + algorithm_config=vae_config, + ) + assert lightning_model is not None + assert isinstance(lightning_model.model, torch.nn.Module) + assert isinstance(lightning_model.noise_model, MultiChannelNoiseModel) + assert isinstance(lightning_model.noise_model_likelihood, NoiseModelLikelihood) + if loss_type == "denoisplit_musplit": + assert isinstance(lightning_model.gaussian_likelihood, GaussianLikelihood) + assert lightning_model.loss_func == denoisplit_musplit_loss + else: + assert lightning_model.gaussian_likelihood is None + assert lightning_model.loss_func == denoisplit_loss + + +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("multiscale_count", [1, 5]) +@pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) +@pytest.mark.parametrize("target_ch", [1, 3]) +def test_musplit_training_step( + batch_size: int, + multiscale_count: int, + predict_logvar: str, + target_ch: int, +): + lightning_model = create_split_lightning_model( + tmp_path=None, + algorithm="musplit", + loss_type="musplit", + multiscale_count=multiscale_count, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=batch_size, + img_size=64, + multiscale_count=multiscale_count, + target_ch=target_ch, + ) + batch = next(iter(dloader)) + train_loss = lightning_model.training_step(batch=batch, batch_idx=0) + + # check outputs + assert train_loss is not None + assert isinstance(train_loss, dict) + assert "loss" in train_loss + assert "reconstruction_loss" in train_loss + assert "kl_loss" in train_loss + + +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("multiscale_count", [1, 5]) +@pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) +@pytest.mark.parametrize("target_ch", [1, 3]) +def test_musplit_validation_step( + batch_size: int, + multiscale_count: int, + predict_logvar: str, + target_ch: int, +): + lightning_model = create_split_lightning_model( + tmp_path=None, + algorithm="musplit", + loss_type="musplit", + multiscale_count=multiscale_count, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=batch_size, + img_size=64, + multiscale_count=multiscale_count, + target_ch=target_ch, + ) + batch = next(iter(dloader)) + lightning_model.validation_step(batch=batch, batch_idx=0) + # NOTE: `validation_step` does not return anything... + + +@pytest.mark.parametrize( + "multiscale_count, predict_logvar, target_ch, loss_type", + [ + (1, None, 1, "denoisplit"), + (5, None, 1, "denoisplit"), + (1, None, 3, "denoisplit"), + (5, None, 3, "denoisplit"), + (1, None, 1, "denoisplit_musplit"), + (1, None, 3, "denoisplit_musplit"), + (1, "pixelwise", 1, "denoisplit_musplit"), + (1, "pixelwise", 3, "denoisplit_musplit"), + (5, None, 1, "denoisplit_musplit"), + (5, None, 3, "denoisplit_musplit"), + (5, "pixelwise", 1, "denoisplit_musplit"), + (5, "pixelwise", 3, "denoisplit_musplit"), + ], +) +def test_denoisplit_training_step( + tmp_path: Path, + multiscale_count: int, + predict_logvar: str, + target_ch: int, + loss_type: str, +): + lightning_model = create_split_lightning_model( + tmp_path=tmp_path, + algorithm="denoisplit", + loss_type=loss_type, + multiscale_count=multiscale_count, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=8, + img_size=64, + multiscale_count=multiscale_count, + target_ch=target_ch, + ) + batch = next(iter(dloader)) + train_loss = lightning_model.training_step(batch=batch, batch_idx=0) + + # check outputs + assert train_loss is not None + assert isinstance(train_loss, dict) + assert "loss" in train_loss + assert "reconstruction_loss" in train_loss + assert "kl_loss" in train_loss + + +@pytest.mark.parametrize( + "multiscale_count, predict_logvar, target_ch, loss_type", + [ + (1, None, 1, "denoisplit"), + (5, None, 1, "denoisplit"), + (1, None, 3, "denoisplit"), + (5, None, 3, "denoisplit"), + (1, None, 1, "denoisplit_musplit"), + (1, None, 3, "denoisplit_musplit"), + (1, "pixelwise", 1, "denoisplit_musplit"), + (1, "pixelwise", 3, "denoisplit_musplit"), + (5, None, 1, "denoisplit_musplit"), + (5, None, 3, "denoisplit_musplit"), + (5, "pixelwise", 1, "denoisplit_musplit"), + (5, "pixelwise", 3, "denoisplit_musplit"), + ], +) +def test_denoisplit_validation_step( + tmp_path: Path, + multiscale_count: int, + predict_logvar: str, + target_ch: int, + loss_type: str, +): + lightning_model = create_split_lightning_model( + tmp_path=tmp_path, + algorithm="denoisplit", + loss_type=loss_type, + multiscale_count=multiscale_count, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=8, + img_size=64, + multiscale_count=multiscale_count, + target_ch=target_ch, + ) + batch = next(iter(dloader)) + lightning_model.validation_step(batch=batch, batch_idx=0) + + +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("multiscale_count", [1, 5]) +@pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) +@pytest.mark.parametrize("target_ch", [1, 3]) +def test_training_loop_musplit( + batch_size: int, + multiscale_count: int, + predict_logvar: str, + target_ch: int, +): + lightning_model = create_split_lightning_model( + tmp_path=None, + algorithm="musplit", + loss_type="musplit", + multiscale_count=multiscale_count, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=batch_size, + img_size=64, + multiscale_count=multiscale_count, + target_ch=target_ch, + ) + trainer = Trainer(accelerator="cpu", max_epochs=2, logger=False, callbacks=[]) + + try: + trainer.fit( + model=lightning_model, + train_dataloaders=dloader, + val_dataloaders=dloader, + ) + except Exception as e: + pytest.fail(f"Training routine failed with exception: {e}") + + +@pytest.mark.parametrize( + "batch_size, predict_logvar, target_ch, loss_type", + [ + (1, None, 1, "denoisplit"), + (4, None, 1, "denoisplit"), + (1, None, 3, "denoisplit"), + (4, None, 3, "denoisplit"), + (1, None, 1, "denoisplit_musplit"), + (1, None, 3, "denoisplit_musplit"), + (1, "pixelwise", 1, "denoisplit_musplit"), + (1, "pixelwise", 3, "denoisplit_musplit"), + (4, None, 1, "denoisplit_musplit"), + (4, None, 3, "denoisplit_musplit"), + (4, "pixelwise", 1, "denoisplit_musplit"), + (4, "pixelwise", 3, "denoisplit_musplit"), + ], +) +def test_training_loop_denoisplit( + tmp_path: Path, + batch_size: int, + predict_logvar: str, + target_ch: int, + loss_type: str, +): + lightning_model = create_split_lightning_model( + tmp_path=tmp_path, + algorithm="denoisplit", + loss_type=loss_type, + multiscale_count=1, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=batch_size, + img_size=64, + multiscale_count=1, + target_ch=target_ch, + ) + trainer = Trainer(accelerator="cpu", max_epochs=2, logger=False, callbacks=[]) + + try: + trainer.fit( + model=lightning_model, + train_dataloaders=dloader, + val_dataloaders=dloader, + ) + except Exception as e: + pytest.fail(f"Training routine failed with exception: {e}") + + +@pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) +@pytest.mark.parametrize("target_ch", [1, 3]) +def test_get_reconstructed_tensor( + predict_logvar: str, + target_ch: int, +): + lightning_model = create_split_lightning_model( + tmp_path=None, + algorithm="musplit", + loss_type="musplit", + multiscale_count=1, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + dloader = create_dummy_dloader( + batch_size=1, + img_size=64, + multiscale_count=1, + target_ch=target_ch, + ) + input_, target = next(iter(dloader)) + output = lightning_model(input_) + rec_img = lightning_model.get_reconstructed_tensor(output) + assert rec_img.shape == target.shape # same shape as target + + +@pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) +@pytest.mark.parametrize("target_ch", [1, 3]) +def test_val_PSNR_computation( + predict_logvar: str, + target_ch: int, +): + lightning_model = create_split_lightning_model( + tmp_path=None, + algorithm="musplit", + loss_type="musplit", + multiscale_count=1, + predict_logvar=predict_logvar, + target_ch=target_ch, + ) + assert lightning_model.running_psnr is not None + assert len(lightning_model.running_psnr) == target_ch + for item in lightning_model.running_psnr: + assert isinstance(item, RunningPSNR) + + dloader = create_dummy_dloader( + batch_size=1, + img_size=64, + multiscale_count=1, + target_ch=target_ch, + ) + input_, target = next(iter(dloader)) + output = lightning_model(input_) + + curr_psnr = lightning_model.compute_val_psnr(output, target) + assert curr_psnr is not None + assert len(curr_psnr) == target_ch + for i in range(target_ch): + assert lightning_model.running_psnr[i].mse_sum != 0 + assert lightning_model.running_psnr[i].N == 1 + assert lightning_model.running_psnr[i].min is not None + assert lightning_model.running_psnr[i].max is not None + assert lightning_model.running_psnr[i].get() is not None + lightning_model.running_psnr[i].reset() + assert lightning_model.running_psnr[i].mse_sum == 0 + assert lightning_model.running_psnr[i].N == 0 + assert lightning_model.running_psnr[i].min is None + assert lightning_model.running_psnr[i].max is None + assert lightning_model.running_psnr[i].get() is None diff --git a/tests/lightning/test_lightning_module.py b/tests/lightning/test_lightning_module.py index 16158653..1a6bca24 100644 --- a/tests/lightning/test_lightning_module.py +++ b/tests/lightning/test_lightning_module.py @@ -7,6 +7,8 @@ create_careamics_module, ) +# TODO: rename to test_FCN_lightining_module.py + def test_careamics_module(minimum_algorithm_n2v): """Test that the minimum algorithm allows instantiating a the Lightning API diff --git a/tests/models/lvae/test_noise_model.py b/tests/models/lvae/test_noise_model.py index f3deb987..a8ace6f4 100644 --- a/tests/models/lvae/test_noise_model.py +++ b/tests/models/lvae/test_noise_model.py @@ -136,6 +136,6 @@ def test_gm_noise_model_training(tmp_path): noise_model = GaussianMixtureNoiseModel(nm_config) # Test training - output = noise_model.train(x, y, n_epochs=2) + output = noise_model.train_noise_model(x, y, n_epochs=2) assert output is not None # TODO do something with output ? diff --git a/tests/models/test_model_factory.py b/tests/models/test_model_factory.py index e4180ea0..18864a56 100644 --- a/tests/models/test_model_factory.py +++ b/tests/models/test_model_factory.py @@ -1,9 +1,7 @@ -import pytest from torch import nn, ones from careamics.config.architectures import ( CustomModel, - LVAEModel, UNetModel, register_model, ) @@ -52,13 +50,3 @@ def forward(self, input): assert isinstance(model, LinearModel) assert model.in_features == 10 assert model.out_features == 5 - - -def test_lvae(): - """Test that VAE are currently not supported.""" - model_config = { - "architecture": SupportedArchitecture.LVAE.value, - } - - with pytest.raises(NotImplementedError): - model_factory(LVAEModel(**model_config))