Skip to content

Commit

Permalink
TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision …
Browse files Browse the repository at this point in the history
…as well

Signed-off-by: Fabian Klopfer <[email protected]>
  • Loading branch information
SomeUserName1 committed Nov 18, 2024
1 parent 472c747 commit 09d1099
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 27 deletions.
80 changes: 68 additions & 12 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
"ToDevice",
"CuCIM",
"RandCuCIM",
"RandTorchIO",
"RandTorchVision",
"ToCupy",
"ImageFilter",
"RandImageFilter",
Expand Down Expand Up @@ -1139,7 +1141,7 @@ def __call__(

class TorchVision(Transform):
"""
This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args.
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

Expand Down Expand Up @@ -1171,9 +1173,43 @@ def __call__(self, img: NdarrayOrTensor):
return out


class TorchIO(Transform, RandomizableTrait):
class RandTorchVision(Transform, RandomizableTrait):
"""
This is a wrapper transform for TorchIO transforms based on the specified transform name and args.
This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args.
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchVision package.
args: parameters for the TorchVision transform.
kwargs: parameters for the TorchVision transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: PyTorch Tensor data for the TorchVision transform.
"""
img_t, *_ = convert_data_type(img, torch.Tensor)

out = self.trans(img_t)
out, *_ = convert_to_dst_type(src=out, dst=img)
return out


class TorchIO(Transform):
"""
This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
"""

Expand All @@ -1185,21 +1221,41 @@ def __init__(self, name: str, *args, **kwargs) -> None:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
Note:
The `p=` kwarg of TorchIO transforms control set the probability with which the transform is applied.
You can specify the probability of applying the transform by passing either `prob` ot `p` in kwargs but'
' not both.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

if "prob" in kwargs:
if "p" in kwargs:
raise ValueError("Cannot specify both 'prob' and 'p' in kwargs.")
kwargs["p"] = kwargs.pop("prob")
def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
or dict containing 4D tensors as values
"""
return self.trans(img)

class RandTorchIO(Transform, RandomizableTrait):
"""
This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
Use this wrapper for all TorchIO transform inheriting from RandomTransform:
https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: NdarrayOrTensor):
Expand Down
60 changes: 45 additions & 15 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@
"RandCuCIMD",
"RandCuCIMDict",
"RandImageFilterd",
"RandTorchIOd",
"RandTorchIOD",
"RandTorchIODict",
"RandLambdaD",
"RandLambdaDict",
"RandLambdad",
Expand Down Expand Up @@ -1449,10 +1452,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class TorchIOd(MapTransform, RandomizableTrait):
class TorchIOd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for transforms.
All transforms in TorchIO can be applied randomly with probability p by specifying the `p=` argument.
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.
For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.
"""

backend = TorchIO.backend
Expand All @@ -1461,7 +1464,6 @@ def __init__(
self,
keys: KeysCollection,
name: str,
apply_same_transform: bool = False,
allow_missing_keys: bool = False,
*args,
**kwargs,
Expand All @@ -1479,21 +1481,49 @@ def __init__(
"""
super().__init__(keys, allow_missing_keys)
self.name = name
self.apply_same_transform = apply_same_transform
kwargs["include"] = self.keys

self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]:
return self.trans(dict(data))

if self.apply_same_transform:
kwargs["include"] = self.keys

class RandTorchIOd(MapTransform, RandomizableTrait):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.
For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.
"""

backend = TorchIO.backend

def __init__(
self,
keys: KeysCollection,
name: str,
allow_missing_keys: bool = False,
*args,
**kwargs,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
name: The transform name in TorchIO package.
apply_same_transform: whether to apply the same transform for all the items specified by `keys`.
allow_missing_keys: don't raise exception if key is missing.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__(keys, allow_missing_keys)
self.name = name
kwargs["include"] = self.keys

self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
if self.apply_same_transform:
d = self.trans(d)
else:
for key in self.key_iterator(d):
d[key] = self.trans(d[key])
return d
def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]:
return self.trans(dict(data))


class MapLabelValued(MapTransform):
Expand Down

0 comments on commit 09d1099

Please sign in to comment.