From 09d1099dd0d68baa57aa58d7415f06434f1b2ff5 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 20:08:50 +0100 Subject: [PATCH] TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision as well Signed-off-by: Fabian Klopfer --- monai/transforms/utility/array.py | 80 ++++++++++++++++++++++---- monai/transforms/utility/dictionary.py | 60 ++++++++++++++----- 2 files changed, 113 insertions(+), 27 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9d67e69033..28b5f11142 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -105,6 +105,8 @@ "ToDevice", "CuCIM", "RandCuCIM", + "RandTorchIO", + "RandTorchVision", "ToCupy", "ImageFilter", "RandImageFilter", @@ -1139,7 +1141,7 @@ def __call__( class TorchVision(Transform): """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ @@ -1171,9 +1173,43 @@ def __call__(self, img: NdarrayOrTensor): return out -class TorchIO(Transform, RandomizableTrait): +class RandTorchVision(Transform, RandomizableTrait): """ - This is a wrapper transform for TorchIO transforms based on the specified transform name and args. + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + +class TorchIO(Transform): + """ + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. See https://torchio.readthedocs.io/transforms/transforms.html for more details. """ @@ -1185,21 +1221,41 @@ def __init__(self, name: str, *args, **kwargs) -> None: name: The transform name in TorchIO package. args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. - - Note: - The `p=` kwarg of TorchIO transforms control set the probability with which the transform is applied. - You can specify the probability of applying the transform by passing either `prob` ot `p` in kwargs but' - ' not both. """ super().__init__() self.name = name transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) - if "prob" in kwargs: - if "p" in kwargs: - raise ValueError("Cannot specify both 'prob' and 'p' in kwargs.") - kwargs["p"] = kwargs.pop("prob") + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + """ + return self.trans(img) + +class RandTorchIO(Transform, RandomizableTrait): + """ + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + Use this wrapper for all TorchIO transform inheriting from RandomTransform: + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) self.trans = transform(*args, **kwargs) def __call__(self, img: NdarrayOrTensor): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index f29119d348..14409fc0e6 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -134,6 +134,9 @@ "RandCuCIMD", "RandCuCIMDict", "RandImageFilterd", + "RandTorchIOd", + "RandTorchIOD", + "RandTorchIODict", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -1449,10 +1452,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d -class TorchIOd(MapTransform, RandomizableTrait): +class TorchIOd(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for transforms. - All transforms in TorchIO can be applied randomly with probability p by specifying the `p=` argument. + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms. + For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`. """ backend = TorchIO.backend @@ -1461,7 +1464,6 @@ def __init__( self, keys: KeysCollection, name: str, - apply_same_transform: bool = False, allow_missing_keys: bool = False, *args, **kwargs, @@ -1479,21 +1481,49 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.name = name - self.apply_same_transform = apply_same_transform + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + return self.trans(dict(data)) - if self.apply_same_transform: - kwargs["include"] = self.keys + +class RandTorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms. + For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`. + """ + + backend = TorchIO.backend + + def __init__( + self, + keys: KeysCollection, + name: str, + allow_missing_keys: bool = False, + *args, + **kwargs, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + apply_same_transform: whether to apply the same transform for all the items specified by `keys`. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - d = dict(data) - if self.apply_same_transform: - d = self.trans(d) - else: - for key in self.key_iterator(d): - d[key] = self.trans(d[key]) - return d + def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + return self.trans(dict(data)) class MapLabelValued(MapTransform):