diff --git a/pycpd/rigid_registration.py b/pycpd/rigid_registration.py index fa43e6d..272ebb6 100644 --- a/pycpd/rigid_registration.py +++ b/pycpd/rigid_registration.py @@ -25,6 +25,10 @@ class RigidRegistration(EMRegistration): Utility array used to calculate the rotation matrix. Defined in Fig. 2 of https://arxiv.org/pdf/0905.2635.pdf. + rotationFilter: single character + 'x','y' or 'z' + If one of those values is set, only rotations areound this axis are allowed + """ # Additional parameters used in this class, but not inputs. # YPY: float @@ -36,7 +40,7 @@ class RigidRegistration(EMRegistration): # Defined in Fig. 2 of https://arxiv.org/pdf/0905.2635.pdf. - def __init__(self, R=None, t=None, s=None, scale=True, *args, **kwargs): + def __init__(self, R=None, t=None, s=None, rotationFilter=None, scale=True, *args, **kwargs): super().__init__(*args, **kwargs) if self.D != 2 and self.D != 3: raise ValueError( @@ -54,10 +58,15 @@ def __init__(self, R=None, t=None, s=None, scale=True, *args, **kwargs): raise ValueError( 'The scale factor must be a positive number. Instead got: {}.'.format(s)) + if rotationFilter is not None and ((rotationFilter != 'x') and (rotationFilter != 'y') and (rotationFilter != 'z')): + raise ValueError( + 'Valid rotation filters are x,y or z. If one of those values is set, only rotations areound this axis are allowed. Instead got: {}.'.format(rotationFilter)) + self.R = np.eye(self.D) if R is None else R self.t = np.atleast_2d(np.zeros((1, self.D))) if t is None else t self.s = 1 if s is None else s self.scale = scale + self.rotationFilter = rotationFilter def update_transform(self): """ @@ -88,6 +97,26 @@ def update_transform(self): # Calculate the rotation matrix using Eq. 9 of https://arxiv.org/pdf/0905.2635.pdf. self.R = np.transpose(np.dot(np.dot(U, np.diag(C)), V)) + + if self.rotationFilter is not None and self.rotationFilter == 'z': + self.R[0,2] = 0 + self.R[1,2] = 0 + self.R[2,0] = 0 + self.R[2,1] = 0 + self.R[2,2] = 1 + elif self.rotationFilter is not None and self.rotationFilter == 'x': + self.R[0,0] = 1 + self.R[0,1] = 0 + self.R[0,2] = 0 + self.R[1,0] = 0 + self.R[2,0] = 0 + elif self.rotationFilter is not None and self.rotationFilter == 'y': + self.R[0,1] = 0 + self.R[1,0] = 0 + self.R[2,1] = 0 + self.R[1,2] = 0 + self.R[1,1] = 1 + # Update scale and translation using Fig. 2 of https://arxiv.org/pdf/0905.2635.pdf. if self.scale is True: self.s = np.trace(np.dot(np.transpose(self.A), np.transpose(self.R))) / self.YPY