Skip to content

Commit

Permalink
LVAE dataset: add MultiFileDset (#223)
Browse files Browse the repository at this point in the history
### Description


- **What**: This is a follow-up on updating the datasets from the
Disentangle repo (reference commit `57cad67` from the main branch),
adding the `MultiFileDset` class, `TilingMode` and `EmptyPatchFetcher`
- **Why**: The `MultiFileDset` is necessary for the datasets in the
paper
- **How**: Copy/paste the code from the Disentangle repo with minimal
changes

### Changes Made

- **Added**: 
  - from the Disentangle repo:
    - copied `MultiFileDset` class and related classes
    - added `TilingMode` and updated related code in `MultiChDloader`
    - copied `EmptyPatchFetcher`
    - copied several new parameters for `MultiChDloader`
    - copied code for flips (using _Albumentations_)
  - added a `load_data_fn` parameter to the dataset classes for passing in
    the data loading logic of different experiments
  - added lists of the dataset types needed for the paper in `data_utils.py`
  - added a simple test for `MultiFileDset` object initialization
- **Modified**: 
  - moved dataset configs into `lvae_training/dataset/configs` 
  - moved `GridIndexManager` and `IndexSwitcher` to separate files in
    `lvae_training/dataset/utils`
- **Removed**: 
  - removed code related to the BioSR dataset from `data_utils.py` and
    the tests

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [ ] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <[email protected]>
  • Loading branch information
veegalinova and jdeschamps authored Sep 18, 2024
1 parent 1877efd commit 59fa2dd
Show file tree
Hide file tree
Showing 20 changed files with 1,367 additions and 1,006 deletions.
2 changes: 1 addition & 1 deletion src/careamics/dataset/tiling/lvae_tiled_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from numpy.typing import NDArray

from careamics.config.tile_information import TileInformation
from careamics.lvae_training.dataset.data_utils import GridIndexManager
from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager


def extract_tiles(
Expand Down
15 changes: 15 additions & 0 deletions src/careamics/lvae_training/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .multich_dataset import MultiChDloader
from .lc_dataset import LCMultiChDloader
from .multifile_dataset import MultiFileDset
from .config import DatasetConfig
from .types import DataType, DataSplitType, TilingMode

__all__ = [
"DatasetConfig",
"MultiChDloader",
"LCMultiChDloader",
"MultiFileDset",
"DataType",
"DataSplitType",
"TilingMode",
]
Original file line number Diff line number Diff line change
@@ -1,63 +1,13 @@
from typing import Any, Optional
from enum import Enum

from pydantic import BaseModel, ConfigDict, computed_field


# TODO: get rid of unnecessary enums
class DataType(Enum):
MNIST = 0
Places365 = 1
NotMNIST = 2
OptiMEM100_014 = 3
CustomSinosoid = 4
Prevedel_EMBL = 5
AllenCellMito = 6
SeparateTiffData = 7
CustomSinosoidThreeCurve = 8
SemiSupBloodVesselsEMBL = 9
Pavia2 = 10
Pavia2VanillaSplitting = 11
ExpansionMicroscopyMitoTub = 12
ShroffMitoEr = 13
HTIba1Ki67 = 14
BSD68 = 15
BioSR_MRC = 16
TavernaSox2Golgi = 17
Dao3Channel = 18
ExpMicroscopyV2 = 19
Dao3ChannelWithInput = 20
TavernaSox2GolgiV2 = 21
TwoDset = 22
PredictedTiffData = 23
Pavia3SeqData = 24
# Here, we have 16 splitting tasks.
NicolaData = 25


class DataSplitType(Enum):
All = 0
Train = 1
Val = 2
Test = 3


class GridAlignement(Enum):
"""
A patch is formed by padding the grid with content. If the grids are 'Center' aligned, then padding is to be done equally on all 4 sides.
On the other hand, if grids are 'LeftTop' aligned, padding is to be done on the right and bottom end of the grid.
In the former case, one needs (patch_size - grid_size)//2 amount of content on the right end of the frame.
In the latter case, one needs patch_size - grid_size amount of content on the right end of the frame.
"""

LeftTop = 0
Center = 1


# TODO: for all bool params check if they are taking different values in Disentangle repo

from pydantic import BaseModel, ConfigDict

from .types import DataType, DataSplitType, TilingMode


# TODO: check if any bool logic can be removed
class VaeDatasetConfig(BaseModel):
model_config = ConfigDict(validate_assignment=True)
class DatasetConfig(BaseModel):
model_config = ConfigDict(validate_assignment=True, extra="forbid")

data_type: Optional[DataType]
"""Type of the dataset, should be one of DataType"""
Expand Down Expand Up @@ -132,15 +82,10 @@ class VaeDatasetConfig(BaseModel):
# TODO: why is this not used?
enable_rotation_aug: Optional[bool] = False

grid_alignment: GridAlignement = GridAlignement.LeftTop

max_val: Optional[float] = None
"""Maximum data in the dataset. Is calculated for train split, and should be
externally set for val and test splits."""

trim_boundary: Optional[bool] = True
"""Whether to trim boundary of the image"""

overlapping_padding_kwargs: Any = None
"""Parameters for np.pad method"""

Expand All @@ -157,23 +102,22 @@ class VaeDatasetConfig(BaseModel):
train_aug_rotate: Optional[bool] = False
enable_random_cropping: Optional[bool] = True

# TODO: not used?
multiscale_lowres_count: Optional[int] = None
"""Number of LC scales"""

tiling_mode: Optional[TilingMode] = TilingMode.ShiftBoundary

target_separate_normalization: Optional[bool] = True

mode_3D: Optional[bool] = False
"""If training in 3D mode or not"""

trainig_datausage_fraction: Optional[float] = 1.0

validtarget_random_fraction: Optional[float] = None

validation_datausage_fraction: Optional[float] = 1.0

random_flip_z_3D: Optional[bool] = False

@computed_field
@property
def padding_kwargs(self) -> dict:
kwargs_dict = {}
padding_kwargs = {}
if (
self.multiscale_lowres_count is not None
and self.multiscale_lowres_count is not None
):
# Get padding attributes
if "padding_kwargs" not in kwargs_dict:
padding_kwargs = {}
padding_kwargs["mode"] = "constant"
padding_kwargs["constant_values"] = 0
else:
padding_kwargs = kwargs_dict.pop("padding_kwargs")
return padding_kwargs
padding_kwargs: Optional[dict] = None
Loading

0 comments on commit 59fa2dd

Please sign in to comment.