Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Final transform module #691

Open
wants to merge 47 commits into
base: clinicadl_v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
559040e
first draft fow new extraction objects
thibaultdvx Dec 9, 2024
3e7dc79
unittests
thibaultdvx Dec 10, 2024
016f61c
replace mentions of BaseExtraction
thibaultdvx Dec 10, 2024
6c5c81d
fix typing issue with Self
thibaultdvx Dec 10, 2024
a8e458a
update method calls in other modules
thibaultdvx Dec 10, 2024
810076b
add torchio subject support
thibaultdvx Dec 11, 2024
7e9209d
uniitest for torchio support
thibaultdvx Dec 11, 2024
8f1ba86
remove slice_mode
thibaultdvx Dec 11, 2024
7e711bc
add .clone()
thibaultdvx Dec 11, 2024
8bf0d54
typo in docstring
thibaultdvx Dec 11, 2024
d74a307
transform object
thibaultdvx Dec 12, 2024
4df6b40
change 'elem' to 'sample'
thibaultdvx Dec 12, 2024
1f27840
change eval() and train() so that they don't return anything
thibaultdvx Dec 12, 2024
5cf8933
add label, mask and torchio subject support in dataset
thibaultdvx Dec 12, 2024
5e319cd
prepare-data and check_preprocessing out of caps dataset
thibaultdvx Dec 16, 2024
b579022
preprocessing
thibaultdvx Dec 16, 2024
bb858eb
change mention of BasePreprocessing to Preprocessing
thibaultdvx Dec 16, 2024
6ad8d30
remove roi
thibaultdvx Dec 16, 2024
7b258ca
init in unittests
thibaultdvx Dec 16, 2024
0f358af
prepare-data outside caps_dataset
thibaultdvx Dec 16, 2024
c407d9f
change unittest accordingly
thibaultdvx Dec 16, 2024
a8582d6
first augmentations
thibaultdvx Dec 17, 2024
034e4db
Merge remote-tracking branch 'upstream/clinicadl_v2' into clinicadl_v2
thibaultdvx Dec 17, 2024
1f90670
Merge branch 'clinicadl_v2' into caps_dataset_transforms
thibaultdvx Dec 17, 2024
87059ec
remove use_uncropped_image in DTI
thibaultdvx Dec 17, 2024
714e928
set use_uncropped_image to False
thibaultdvx Dec 17, 2024
30b340b
check if tsv file exists before writting it
thibaultdvx Dec 17, 2024
61458a0
modify transforms
thibaultdvx Dec 18, 2024
e650b57
Merge remote-tracking branch 'upstream/clinicadl_v2' into clinicadl_v2
thibaultdvx Dec 18, 2024
bac2315
Merge branch 'clinicadl_v2' into caps_dataset_transforms
thibaultdvx Dec 18, 2024
f431c7f
complete merge
thibaultdvx Dec 18, 2024
c3b8df9
use dictionnary
thibaultdvx Dec 18, 2024
6566a9c
path issue
thibaultdvx Dec 18, 2024
2561274
path issue
thibaultdvx Dec 18, 2024
3a3e0dd
path issue
thibaultdvx Dec 18, 2024
8eef173
Merge branch 'caps_dataset_transforms' into transforms_factory
thibaultdvx Dec 18, 2024
a016181
first augmentations
thibaultdvx Dec 19, 2024
26545f9
first draft for config classes
thibaultdvx Dec 19, 2024
61369f6
first draft factories
thibaultdvx Dec 19, 2024
d00cf66
unittests
thibaultdvx Dec 23, 2024
d582774
Merge remote-tracking branch 'upstream/clinicadl_v2' into clinicadl_v2
thibaultdvx Dec 23, 2024
4b3fbe8
Merge branch 'clinicadl_v2' into transforms_factory
thibaultdvx Dec 23, 2024
45f17ae
last changes
thibaultdvx Dec 23, 2024
ac5a85d
add new datapoint object
thibaultdvx Dec 27, 2024
902d7dc
add Resample
thibaultdvx Dec 27, 2024
1711484
add ToCanonical
thibaultdvx Dec 27, 2024
7be5872
solve axes issue
thibaultdvx Dec 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions clinicadl/data/datasets/caps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

from clinicadl.data.preprocessing import Preprocessing, PreprocessingT1
from clinicadl.data.readers.caps_reader import CapsReader
from clinicadl.data.structures import DataPoint
from clinicadl.data.utils import (
check_df,
get_infos_from_json,
tsv_to_df,
)
from clinicadl.transforms.extraction import Sample
from clinicadl.transforms.transforms import Transforms
from clinicadl.transforms.utils import get_tio_image
from clinicadl.utils.exceptions import ClinicaDLCAPSError, ClinicaDLTSVError
from clinicadl.utils.iotools.clinica_utils import create_subs_sess_list
from clinicadl.utils.loading import nifti_to_tensor, pt_to_tensor
Expand Down Expand Up @@ -511,22 +511,22 @@ def __getitem__(self, idx: NonNegativeInt) -> Sample:
label = self._get_label(img_index)
masks = self._get_masks(img_index)

tio_image = get_tio_image(image, label, **masks)
data_point = DataPoint(image, label, **masks)

tio_image = self.image_transform(tio_image)
data_point = self.image_transform(data_point)
if not self.eval_mode:
tio_image = self.image_augmentation(tio_image)
data_point = self.image_augmentation(data_point)

tio_sample, sample_description = self.extraction.extract_tio_sample(
tio_image, sample_index
sample, sample_description = self.extraction.extract_sample(
data_point, sample_index
)

tio_sample = self.sample_transform(tio_sample)
sample = self.sample_transform(sample)
if not self.eval_mode:
tio_sample = self.sample_augmentation(tio_sample)
sample = self.sample_augmentation(sample)

return self.extraction.format_output(
tio_sample,
sample,
participant_id=participant,
session_id=session,
image_path=image_path,
Expand Down
179 changes: 179 additions & 0 deletions clinicadl/data/structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import copy
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import torch
import torchio as tio

from clinicadl.dictionary.words import LABEL


class DataPoint(tio.Subject):
"""
Object that gathers an image, the associated label, and any
mask associated to the image.

Parameters
----------
image : Union[torch.Tensor, tio.ScalarImage]
the image, as a Pytorch tensor or a TorchIO ScalarImage.
label : Optional[Union[float, int, torch.Tensor, tio.LabelMap]]
the label associated to the image. Can be a float (e.g. regression),
an int (e.g. classification), a mask (passed as torch Tensor or a
TorchIO LabelMap; e.g. segmentation) or None if no label (e.g. reconstruction).
**masks : Union[torch.Tensor, tio.LabelMap]
any mask related to the image and useful to compute transforms.

Raises
------
AssertionError
If all the images/masks passed don't have the same shape.
"""

image: tio.ScalarImage
label: Optional[Union[float, int, tio.LabelMap]]

def __init__(
self,
image: Union[torch.Tensor, tio.ScalarImage],
label: Optional[Union[float, int, torch.Tensor, tio.LabelMap]],
**masks: Union[torch.Tensor, tio.LabelMap],
) -> None:
if not isinstance(image, tio.ScalarImage):
image = tio.ScalarImage(tensor=image)
image_shape = image.tensor.shape

if isinstance(label, torch.Tensor):
label = tio.LabelMap(tensor=label)

for name, mask in masks.items():
if not isinstance(mask, tio.LabelMap):
masks[name] = tio.LabelMap(tensor=mask)
masks[LABEL] = label

for name, mask in masks.items():
if isinstance(mask, tio.LabelMap):
assert mask.tensor.shape == image_shape, (
f"Masks must be the same shape as the image, but got "
f"{image_shape} for the image and {mask.tensor.shape} "
f"for '{name}')"
)

super().__init__(image=image, **masks)

def __copy__(self):
return _subject_copy_helper(self, type(self))


def _subject_copy_helper(
old_obj: tio.Subject,
new_subj_cls: Callable[[Dict[str, Any]], tio.Subject],
):
"""
Adapted torchio.data.subject._subject_copy_helper to work
with DataPoint.
"""
result_dict = {}
for key, value in old_obj.items():
if isinstance(value, tio.Image):
value = copy.copy(value)
else:
value = copy.deepcopy(value)
result_dict[key] = value

new = new_subj_cls(**result_dict)
new.applied_transforms = old_obj.applied_transforms[:]
return new


class Mask:
"""To handle masks in ClinicaDL. More precisely, it makes the difference
between a mask passed as a file name, that corresponds to a common mask,
and a mask passed as a suffix (a simple string), that corresponds to a mask
specific to each subject.

For example, `Mask("masks/mask.nii.gz")` will be understood has a common
mask, where as `Mask("mask")` will be understood has a specific mask.

In the latter case, it is expected that all the (subject, session) studied
have the associated mask in their CAPS folders. It will look for files with
the suffix `mask` in these folders.

Parameters
----------
filename : Union[str, Path]
the mask, passed as a path or a suffix.
"""

def __init__(self, mask: Union[str, Path]) -> None:
if isinstance(mask, Path):
if not self._check_path(mask):
raise ValueError(
f"The mask has been passed as a Path object (got {mask}), but no such file exists."
)
self.common_mask = True
self.mask = Path(mask)

elif isinstance(mask, str):
if self._check_path(mask):
self.common_mask = True
self.mask = Path(mask)
else:
self.common_mask = False
self.mask = mask

@staticmethod
def _check_path(mask_path: Union[str, Path]) -> bool:
"""Checks if the mask file exists."""
mask_path = Path(mask_path)
return mask_path.is_file()

def get_associated_mask(self, filename: Union[str, Path]) -> Path:
"""
Returns the mask associated to an image.

If the mask is common to all subjects and sessions, the method will
simply return it. On the other hand, if the mask is specific to each
(subject, session), the method will use the input `filename` to get
the associated mask.

Parameters
----------
filename : Union[str, Path]
the image whose associated mask is to be found.

Returns
-------
Path :
the path to the mask associated to the image.

Raises
------
ValueError
if the associated mask doesn't exist.

Examples
--------
>>> mask=Mask("seg")
>>> mask.get_associated_mask("sub-001_ses-M000_T1w.nii.gz")
PosixPath('sub-001_ses-M000_seg.nii.gz')

>>> mask=Mask("masks/leftHippocampus.nii.gz")
>>> mask.get_associated_mask("sub-001_ses-M000_T1w.nii.gz")
PosixPath('masks/leftHippocampus.nii.gz')
"""

if self.common_mask:
return self.mask
else:
filename = Path(filename)
without_extension = str(filename).rstrip("".join(filename.suffixes))
suffix = without_extension.rsplit("_", maxsplit=1)[-1]
mask_file = str(filename).replace(f"_{suffix}.", f"_{self.mask}.")
if not self._check_path(mask_file):
raise ValueError(
f"A mask associated to {str(filename)} was expected "
f"to be found in {mask_file}, but there is no such file."
)

return Path(mask_file)
20 changes: 0 additions & 20 deletions clinicadl/networks/old_network/__init__.py

This file was deleted.

124 changes: 0 additions & 124 deletions clinicadl/networks/old_network/autoencoder/cnn_transformer.py

This file was deleted.

Loading
Loading