Skip to content

Commit

Permalink
Implemented training_step and validation_step, fixed all the conf…
Browse files Browse the repository at this point in the history
…igs needed to instantiate the lightning module (#218)

### Description

This PR finally implements a working training loop for the lightining
module `VAEModule`.

- **What**: we added `training_step` and `validation_step` plus all the
required methods. We also fixed some bugs in the configs (`pydantic`
model) describing the different model components.
- **Why**: To be able to train a LVAE model.
- **How**: see below.

### Changes Made

- **Added**: 
- `training_step` and `validation_step` in the lightning `VAEModule`
(plus other related methods).
- Some required attributes and validators in the `VAEAlgorithmmConfig`.
- An example notebook to illustrate how to train the LVAE model using
current lightning API.
  - `RunningPSNR` class in `metrics.py`.
  - Tests for the added features.
- **Modified**:
- Typing and validators in different `pydantic` models (e.g.,
`lvae_model.py`, `nm_model.py`)
- **Removed**: Nothing in particular.

### Things that can be improved

- `validation_step` requires some additional methods for computing the
PSNR at each epoch. These methods are currently within the `VAEModule`
since require some of its attributes. If possible, it would be nice to
move this methods out of the lightning module to have something cleaner.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: CatEek <[email protected]>
Co-authored-by: Joran Deschamps <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: melisande-c <[email protected]>
  • Loading branch information
5 people authored Aug 24, 2024
1 parent 692429e commit 427fa26
Show file tree
Hide file tree
Showing 14 changed files with 1,352 additions and 75 deletions.
439 changes: 439 additions & 0 deletions examples/example_training_LVAE_split.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/careamics/config/architectures/lvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .architecture_model import ArchitectureModel


# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
class LVAEModel(ArchitectureModel):
"""LVAE model."""

Expand Down
5 changes: 4 additions & 1 deletion src/careamics/config/nm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ class GaussianMixtureNMConfig(BaseModel):
"""Gaussian mixture noise model."""

model_config = ConfigDict(
validate_assignment=True, arbitrary_types_allowed=True, extra="allow"
protected_namespaces=(),
validate_assignment=True,
arbitrary_types_allowed=True,
extra="allow",
)
# model type
model_type: Literal["GaussianMixtureNoiseModel"]
Expand Down
75 changes: 60 additions & 15 deletions src/careamics/config/vae_algorithm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,18 @@ class VAEAlgorithmConfig(BaseModel):
Attributes
----------
algorithm : Literal["n2v", "custom"]
algorithm : algorithm: Literal["musplit", "denoisplit", "custom"]
Algorithm to use.
loss : Literal["n2v", "mae", "mse"]
loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
Loss function to use.
model : Union[UNetModel, LVAEModel, CustomModel]
model : Union[LVAEModel, CustomModel]
Model architecture to use.
noise_model: Optional[MultiChannelNmModel]
Noise model to use.
noise_model_likelihood_model: Optional[NMLikelihoodModel]
Noise model likelihood model to use.
gaussian_likelihood_model: Optional[GaussianLikelihoodModel]
Gaussian likelihood model to use.
optimizer : OptimizerModel, optional
Optimizer to use.
lr_scheduler : LrSchedulerModel, optional
Expand Down Expand Up @@ -66,9 +72,10 @@ class VAEAlgorithmConfig(BaseModel):
# - values can still be passed as strings and they will be cast to Enum
algorithm_type: Literal["vae"]
algorithm: Literal["musplit", "denoisplit", "custom"]
loss: Literal["musplit", "denoisplit"]
loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")

# TODO: these are configs, change naming of attrs
noise_model: Optional[MultiChannelNMConfig] = None
noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None
Expand All @@ -94,25 +101,63 @@ def algorithm_cross_validation(self: Self) -> Self:
raise ValueError(
f"Algorithm {self.algorithm} only supports loss `musplit`."
)
if self.model.predict_logvar != "pixelwise":
raise ValueError(
"Algorithm `musplit` only supports `predict_logvar` as `pixelwise`."
)
# TODO add more checks

if self.algorithm == SupportedAlgorithm.DENOISPLIT:
if self.loss != SupportedLoss.DENOISPLIT:
if self.loss not in [
SupportedLoss.DENOISPLIT,
SupportedLoss.DENOISPLIT_MUSPLIT,
]:
raise ValueError(
f"Algorithm {self.algorithm} only supports loss `denoisplit`."
f"Algorithm {self.algorithm} only supports loss `denoisplit` "
"or `denoisplit_musplit."
)
if self.model.predict_logvar is not None:
if (
self.loss == SupportedLoss.DENOISPLIT
and self.model.predict_logvar is not None
):
raise ValueError(
"Algorithm `denoisplit` only supports `predict_logvar` as `None`."
"Algorithm `denoisplit` with loss `denoisplit` only supports "
"`predict_logvar` as `None`."
)

if self.noise_model is None:
raise ValueError("Algorithm `denoisplit` requires a noise model.")
# TODO: what if algorithm is not musplit or denoisplit
# TODO: what if algorithm is not musplit or denoisplit (HDN?)
return self

@model_validator(mode="after")
def output_channels_validation(self: Self) -> Self:
"""Validate the consistency between number of out channels and noise models.
Returns
-------
Self
The validated model.
"""
if self.noise_model is not None:
assert self.model.output_channels == len(self.noise_model.noise_models), (
f"Number of output channels ({self.model.output_channels}) must match "
f"the number of noise models ({len(self.noise_model.noise_models)})."
)
return self

@model_validator(mode="after")
def predict_logvar_validation(self: Self) -> Self:
"""Validate the consistency of `predict_logvar` throughout the model.
Returns
-------
Self
The validated model.
"""
if self.gaussian_likelihood_model is not None:
assert (
self.model.predict_logvar
== self.gaussian_likelihood_model.predict_logvar
), (
f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
"Gaussian likelihood model `predict_logvar` "
f"({self.gaussian_likelihood_model.predict_logvar}).",
)
return self

def __str__(self) -> str:
Expand Down
Loading

0 comments on commit 427fa26

Please sign in to comment.