Refactoring: lightning API package and smoke tests (#161)
### Description

I moved the lightning modules into their own package to clarify the
imports for the different APIs, and added smoke tests for the Lightning
API. As a result, I needed to implement a new function to retrieve the
statistics from `CAREamicsTrainData`, which led me to refactor the
datasets a bit and standardize how the statistics are recorded across
the two datasets.

The smoke tests check that the Lightning API works for both tiled and
un-tiled data.

Finally, the lightning module wrappers have disappeared; convenience
functions now make the parameters explicit, which led to renaming the
lightning modules.

- **What**:
    - Moved Lightning modules and callbacks to `careamics.lightning`
    - Added tests for the Lightning API (tiled and un-tiled prediction
      included)
    - Normalized the way the datasets keep the image statistics.
    - Fixed some errors in `CAREamicsTrainData`
- **Why**: Since we are differentiating between the CAREamist and
  Lightning APIs, they should be imported from different packages.
- **How**:
    - Moved Lightning modules and callbacks to `careamics.lightning`
    - Added tests for the Lightning API.
    - Removed `StatsOutput`; both datasets now have `self.image_stats:
      Stats` and `self.target_stats: Stats`
    - Added `get_data_statistics` to the datasets and
      `CAREamicsTrainData`, with a corresponding test
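The standardized statistics container described above can be sketched as follows. This is a minimal illustration only; the actual `Stats` fields and dataset internals in CAREamics may differ:

```python
from dataclasses import dataclass, field


@dataclass
class Stats:
    """Per-channel statistics recorded by a dataset (illustrative)."""

    means: list = field(default_factory=list)
    stds: list = field(default_factory=list)


class InMemoryDatasetSketch:
    """Toy dataset keeping image and target statistics in the same shape."""

    def __init__(self, image_stats: Stats, target_stats: Stats) -> None:
        # both datasets now expose the same two attributes
        self.image_stats = image_stats
        self.target_stats = target_stats

    def get_data_statistics(self):
        """Return the image means and stds as a tuple."""
        return self.image_stats.means, self.image_stats.stds


dataset = InMemoryDatasetSketch(Stats([0.5], [0.1]), Stats())
means, stds = dataset.get_data_statistics()
```

With both datasets sharing this shape, `CAREamicsTrainData` can forward `get_data_statistics` to whichever dataset it holds.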

### Changes Made

- **Added**: `test_lightning_api.py`.
- **Modified**:
    - `lightning_data_module.py`
    - `iterable_dataset.py`
    - `in_memory_dataset.py`
    - `patching.py`
    - imports throughout the CAREamist and other files


### Additional Notes and Examples

To use the Lightning API, users now need the following imports:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics.lightning import (
    CAREamicsModuleWrapper,
    TrainingDataWrapper,
    PredictDataWrapper,
)
from careamics.prediction_utils import convert_outputs # if tiling required
```

(see [full example
here](https://github.com/CAREamics/careamics-examples/blob/jd/lightning_api/applications/lightning_api/2D/BSD68_Noise2Void_lightning_api.ipynb))

Since the `PredictDataWrapper` requires passing the statistics, users
can now simply call:

```python
means, stds = train_data_wrapper.get_data_statistics()
```
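For illustration, these statistics are the usual ingredients for standardizing prediction inputs and mapping model outputs back. A minimal, self-contained sketch of that normalization round-trip (not CAREamics code, function names are illustrative):

```python
def normalize(values, mean, std):
    """Standardize values with the mean/std recorded during training."""
    return [(v - mean) / std for v in values]


def denormalize(values, mean, std):
    """Map model outputs back to the original intensity range."""
    return [v * std + mean for v in values]


# hypothetical per-channel statistics, as returned by get_data_statistics()
means, stds = [10.0], [2.0]

pixels = [8.0, 10.0, 14.0]
normed = normalize(pixels, means[0], stds[0])
restored = denormalize(normed, means[0], stds[0])
```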

While building the notebook, I uncovered the following error:
#162

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
    - CAREamics/careamics-examples#6
    - CAREamics/careamics.github.io#11

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
jdeschamps and pre-commit-ci[bot] authored Jul 4, 2024
1 parent f223aa7 commit ff20596
Showing 29 changed files with 722 additions and 1,024 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -147,6 +147,7 @@ exclude_lines = [
"except ImportError",
"\\.\\.\\.",
"raise NotImplementedError()",
"except PackageNotFoundError:",
]

[tool.coverage.run]
15 changes: 1 addition & 14 deletions src/careamics/__init__.py
@@ -7,20 +7,7 @@
except PackageNotFoundError:
__version__ = "uninstalled"

__all__ = [
"CAREamist",
"CAREamicsModuleWrapper",
"CAREamicsPredictData",
"CAREamicsTrainData",
"Configuration",
"load_configuration",
"save_configuration",
"TrainingDataWrapper",
"PredictDataWrapper",
]
__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]

from .careamist import CAREamist
from .config import Configuration, load_configuration, save_configuration
from .lightning_datamodule import CAREamicsTrainData, TrainingDataWrapper
from .lightning_module import CAREamicsModuleWrapper
from .lightning_prediction_datamodule import CAREamicsPredictData, PredictDataWrapper
112 changes: 76 additions & 36 deletions src/careamics/careamist.py
@@ -13,22 +13,29 @@
)
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from careamics.callbacks import ProgressBarCallback
from careamics.config import (
Configuration,
load_configuration,
)
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
SupportedData,
SupportedLogger,
)
from careamics.dataset.dataset_utils import reshape_array
from careamics.lightning_datamodule import CAREamicsTrainData
from careamics.lightning_module import CAREamicsModule
from careamics.lightning import (
CAREamicsModule,
HyperParametersCallback,
PredictDataModule,
ProgressBarCallback,
TrainDataModule,
create_predict_datamodule,
)
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.prediction_utils import convert_outputs, create_pred_datamodule
from careamics.prediction_utils import convert_outputs
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback
from .lightning_prediction_datamodule import CAREamicsPredictData

logger = get_logger(__name__)

LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
@@ -61,9 +68,9 @@ class CAREamist:
Experiment logger, "wandb" or "tensorboard".
work_dir : pathlib.Path
Working directory.
train_datamodule : CAREamicsTrainData
train_datamodule : TrainDataModule
Training datamodule.
pred_datamodule : CAREamicsPredictData
pred_datamodule : PredictDataModule
Prediction datamodule.
"""

@@ -193,8 +200,8 @@ def __init__(
)

# place holder for the datamodules
self.train_datamodule: Optional[CAREamicsTrainData] = None
self.pred_datamodule: Optional[CAREamicsPredictData] = None
self.train_datamodule: Optional[TrainDataModule] = None
self.pred_datamodule: Optional[PredictDataModule] = None

def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
"""
@@ -246,7 +253,7 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
def train(
self,
*,
datamodule: Optional[CAREamicsTrainData] = None,
datamodule: Optional[TrainDataModule] = None,
train_source: Optional[Union[Path, str, NDArray]] = None,
val_source: Optional[Union[Path, str, NDArray]] = None,
train_target: Optional[Union[Path, str, NDArray]] = None,
@@ -273,7 +280,7 @@ def train(
Parameters
----------
datamodule : CAREamicsTrainData, optional
datamodule : TrainDataModule, optional
Datamodule to train on, by default None.
train_source : pathlib.Path or str or NDArray, optional
Train source, if no datamodule is provided, by default None.
@@ -375,17 +382,17 @@ def train(

else:
raise ValueError(
f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
f"Invalid input, expected a str, Path, array or TrainDataModule "
f"instance (got {type(train_source)})."
)

def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
"""
Train the model on the provided datamodule.
Parameters
----------
datamodule : CAREamicsTrainData
datamodule : TrainDataModule
Datamodule to train on.
"""
# record datamodule
@@ -421,7 +428,7 @@ def _train_on_array(
Minimum number of patches to use for validation, by default 5.
"""
# create datamodule
datamodule = CAREamicsTrainData(
datamodule = TrainDataModule(
data_config=self.cfg.data_config,
train_data=train_data,
val_data=val_data,
@@ -477,7 +484,7 @@ def _train_on_path(
path_to_val_target = check_path_exists(path_to_val_target)

# create datamodule
datamodule = CAREamicsTrainData(
datamodule = TrainDataModule(
data_config=self.cfg.data_config,
train_data=path_to_train_data,
val_data=path_to_val_data,
@@ -493,7 +500,7 @@

@overload
def predict( # numpydoc ignore=GL08
self, source: CAREamicsPredictData
self, source: PredictDataModule
) -> Union[list[NDArray], NDArray]: ...

@overload
@@ -528,7 +535,7 @@ def predict( # numpydoc ignore=GL08

def predict(
self,
source: Union[CAREamicsPredictData, Path, str, NDArray],
source: Union[PredictDataModule, Path, str, NDArray],
*,
batch_size: Optional[int] = None,
tile_size: Optional[tuple[int, ...]] = None,
@@ -591,29 +598,62 @@ def predict(
-------
list of NDArray or NDArray
Predictions made by the model.
"""
# Reuse batch size if not provided explicitly
if batch_size is None:
batch_size = (
self.train_datamodule.batch_size
if self.train_datamodule
else self.cfg.data_config.batch_size
)
self.pred_datamodule = create_pred_datamodule(
source=source,
config=self.cfg,
batch_size=batch_size,
Raises
------
ValueError
If mean and std are not provided in the configuration.
ValueError
If tile size is not divisible by 2**depth for UNet models.
ValueError
If tile overlap is not specified.
"""
if (
self.cfg.data_config.image_means is None
or self.cfg.data_config.image_stds is None
):
raise ValueError("Mean and std must be provided in the configuration.")

# tile size for UNets
if tile_size is not None:
model = self.cfg.algorithm_config.model

if model.architecture == SupportedArchitecture.UNET.value:
# tile size must be equal to k*2^n, where n is the number of pooling
# layers (equal to the depth) and k is an integer
depth = model.depth
tile_increment = 2**depth

for i, t in enumerate(tile_size):
if t % tile_increment != 0:
raise ValueError(
f"Tile size must be divisible by {tile_increment} along "
f"all axes (got {t} for axis {i}). If your image size is "
f"smaller along one axis (e.g. Z), consider padding the "
f"image."
)

# tile overlaps must be specified
if tile_overlap is None:
raise ValueError("Tile overlap must be specified.")

# create the prediction
self.pred_datamodule = create_predict_datamodule(
pred_data=source,
data_type=data_type or self.cfg.data_config.data_type,
axes=axes or self.cfg.data_config.axes,
image_means=self.cfg.data_config.image_means,
image_stds=self.cfg.data_config.image_stds,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes,
data_type=data_type,
batch_size=batch_size or self.cfg.data_config.batch_size,
tta_transforms=tta_transforms,
dataloader_params=dataloader_params,
read_source_func=read_source_func,
extension_filter=extension_filter,
dataloader_params=dataloader_params,
)

# predict
predictions = self.trainer.predict(
model=self.model, datamodule=self.pred_datamodule
)
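The tile-size check added to `predict` above enforces that each tile dimension is a multiple of 2**depth, since UNet halves the spatial size at each of its pooling layers. A self-contained sketch of that rule (the function name is illustrative, not part of the CAREamics API):

```python
def validate_tile_size(tile_size, depth):
    """Raise if any tile dimension is not divisible by 2**depth.

    A UNet with `depth` pooling layers downsamples by 2 at each layer,
    so every tile axis must measure k * 2**depth for some integer k.
    """
    tile_increment = 2**depth
    for i, t in enumerate(tile_size):
        if t % tile_increment != 0:
            raise ValueError(
                f"Tile size must be divisible by {tile_increment} along "
                f"all axes (got {t} for axis {i})."
            )


# 64 is divisible by 2**2 == 4, so this passes silently
validate_tile_size((64, 64), depth=2)
```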
3 changes: 0 additions & 3 deletions src/careamics/config/__init__.py
@@ -14,17 +14,14 @@
"create_care_configuration",
"register_model",
"CustomModel",
"create_inference_configuration",
"clear_custom_models",
"ConfigurationInformation",
]

from .algorithm_model import AlgorithmConfig
from .architectures import CustomModel, clear_custom_models, register_model
from .callback_model import CheckpointModel
from .configuration_factory import (
create_care_configuration,
create_inference_configuration,
create_n2n_configuration,
create_n2v_configuration,
)
86 changes: 0 additions & 86 deletions src/careamics/config/configuration_example.py

This file was deleted.

