Skip to content

Commit

Permalink
Add kwarg to not check spatial consistency
Browse files Browse the repository at this point in the history
Related to #734.
  • Loading branch information
fepegar committed Nov 11, 2021
1 parent b28cdc0 commit 16c0722
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
8 changes: 8 additions & 0 deletions tests/transforms/augmentation/test_random_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,11 @@ def test_no_inverse(self):
)
transformed = apply_affine(tensor)
self.assertTensorAlmostEqual(transformed, expected)

def test_different_spaces(self):
t1 = self.sample_subject.t1
label = tio.Resample(2)(self.sample_subject.label)
new_subject = tio.Subject(t1=t1, label=label)
with self.assertRaises(RuntimeError):
tio.RandomAffine()(new_subject)
tio.RandomAffine(check_shape=False)(new_subject)
16 changes: 14 additions & 2 deletions torchio/transforms/augmentation/spatial/random_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class RandomAffine(RandomTransform, SpatialTransform):
`Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
If it is a number, that value will be used.
image_interpolation: See :ref:`Interpolation`.
check_shape: If ``True`` an error will be raised if the images are in
different physical spaces. If ``False``, :attr:`center` should
probably not be ``'image'`` but ``'center'``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
Expand Down Expand Up @@ -112,6 +115,7 @@ def __init__(
center: str = 'image',
default_pad_value: Union[str, float] = 'minimum',
image_interpolation: str = 'linear',
check_shape: bool = True,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -129,6 +133,7 @@ def __init__(
self.center = center
self.default_pad_value = _parse_default_value(default_pad_value)
self.image_interpolation = self.parse_interpolation(image_interpolation)
self.check_shape = check_shape

def get_params(
self,
Expand All @@ -145,7 +150,6 @@ def get_params(
return scaling_params, rotation_params, translation_params

def apply_transform(self, subject: Subject) -> Subject:
subject.check_consistent_spatial_shape()
scaling_params, rotation_params, translation_params = self.get_params(
self.scales,
self.degrees,
Expand All @@ -159,6 +163,7 @@ def apply_transform(self, subject: Subject) -> Subject:
center=self.center,
default_pad_value=self.default_pad_value,
image_interpolation=self.image_interpolation,
check_shape=self.check_shape,
)
transform = Affine(**self.add_include_exclude(arguments))
transformed = transform(subject)
Expand Down Expand Up @@ -187,6 +192,9 @@ class Affine(SpatialTransform):
`Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
If it is a number, that value will be used.
image_interpolation: See :ref:`Interpolation`.
check_shape: If ``True`` an error will be raised if the images are in
different physical spaces. If ``False``, :attr:`center` should
probably not be ``'image'`` but ``'center'``.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
"""
Expand All @@ -198,6 +206,7 @@ def __init__(
center: str = 'image',
default_pad_value: Union[str, float] = 'minimum',
image_interpolation: str = 'linear',
check_shape: bool = True,
**kwargs
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -231,13 +240,15 @@ def __init__(
self.default_pad_value = _parse_default_value(default_pad_value)
self.image_interpolation = self.parse_interpolation(image_interpolation)
self.invert_transform = False
self.check_shape = check_shape
self.args_names = (
'scales',
'degrees',
'translation',
'center',
'default_pad_value',
'image_interpolation',
'check_shape',
)

@staticmethod
Expand Down Expand Up @@ -322,7 +333,8 @@ def get_affine_transform(self, image):
return transform

def apply_transform(self, subject: Subject) -> Subject:
subject.check_consistent_spatial_shape()
if self.check_shape:
subject.check_consistent_spatial_shape()
for image in self.get_images(subject):
transform = self.get_affine_transform(image)
transformed_tensors = []
Expand Down

0 comments on commit 16c0722

Please sign in to comment.