diff --git a/.gitignore b/.gitignore index 32720c5..461cc11 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,7 @@ ENV/ # IDE settings .vscode/ +.idea/ libtilt/_version.py src/libtilt/_version.py diff --git a/src/libtilt/ctf/ctf_2d.py b/src/libtilt/ctf/ctf_2d.py index 2bb8507..72fe015 100644 --- a/src/libtilt/ctf/ctf_2d.py +++ b/src/libtilt/ctf/ctf_2d.py @@ -22,6 +22,7 @@ def calculate_ctf( image_shape: Tuple[int, int], rfft: bool, fftshift: bool, + device: torch.device | None = None ): """ @@ -56,20 +57,22 @@ def calculate_ctf( Whether to apply fftshift on the resulting CTF images. """ # to torch.Tensor and unit conversions - defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float)) + if bool(rfft) + bool(fftshift) > 1: + raise ValueError("Only one of `rfft` and `fftshift` may be `True`.") + defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float, device=device)) defocus *= 1e4 # micrometers -> angstroms - astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float)) + astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float, device=device)) astigmatism *= 1e4 # micrometers -> angstroms - astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float)) + astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float, device=device)) astigmatism_angle *= (C.pi / 180) # degrees -> radians - pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size)) - voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float)) + pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size, device=device)) + voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float, device=device)) voltage *= 1e3 # kV -> V spherical_aberration = torch.atleast_1d( - torch.as_tensor(spherical_aberration, dtype=torch.float) + torch.as_tensor(spherical_aberration, dtype=torch.float, device=device) ) spherical_aberration *= 1e7 # mm -> angstroms - image_shape = torch.as_tensor(image_shape) + image_shape = torch.as_tensor(image_shape, device=device) # derived quantities used in CTF calculation defocus_u = defocus + astigmatism @@ -79,10 +82,10 @@ def calculate_ctf( k2 = C.pi / 2 * spherical_aberration * _lambda ** 3 k3 = torch.tensor(np.deg2rad(phase_shift)) k4 = -b_factor / 4 - k5 = np.arctan(amplitude_contrast / np.sqrt(1 - amplitude_contrast ** 2)) + k5 = torch.arctan(amplitude_contrast / torch.sqrt(1 - amplitude_contrast ** 2)) # construct 2D frequency grids and rescale cycles / px -> cycles / Å - fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft) # (h, w, 2) + fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft, device=device) # (h, w, 2) fftfreq_grid = fftfreq_grid / einops.rearrange(pixel_size, 'b -> b 1 1 1') fftfreq_grid_squared = fftfreq_grid ** 2 diff --git a/src/libtilt/fft_utils.py b/src/libtilt/fft_utils.py index d02bfe2..1c4d8e1 100644 --- a/src/libtilt/fft_utils.py +++ b/src/libtilt/fft_utils.py @@ -22,10 +22,11 @@ def dft_center( device: torch.device | None = None, ) -> torch.LongTensor: """Return the position of the DFT center for a given input shape.""" + _rfft_shape = rfft_shape(image_shape) fft_center = torch.zeros(size=(len(image_shape),), device=device) image_shape = torch.as_tensor(image_shape).float() if rfft is True: - image_shape = torch.tensor(rfft_shape(image_shape)) + image_shape = torch.tensor(_rfft_shape, device=device) if fftshifted is True: fft_center = torch.divide(image_shape, 2, rounding_mode='floor') if rfft is True: @@ -438,11 +439,13 @@ def fftfreq_to_dft_coordinates( coordinates: torch.Tensor `(..., d)` array of coordinates into a fftshifted DFT. """ + _image_shape = image_shape image_shape = torch.as_tensor( - image_shape, device=frequencies.device, dtype=frequencies.dtype + _image_shape, device=frequencies.device, dtype=frequencies.dtype ) + _rfft_shape = rfft_shape(_image_shape) _rfft_shape = torch.as_tensor( - rfft_shape(image_shape), device=frequencies.device, dtype=frequencies.dtype + _rfft_shape, device=frequencies.device, dtype=frequencies.dtype ) coordinates = torch.empty_like(frequencies) coordinates[..., :-1] = frequencies[..., :-1] * image_shape[:-1] @@ -450,5 +453,5 @@ def fftfreq_to_dft_coordinates( coordinates[..., -1] = frequencies[..., -1] * 2 * (_rfft_shape[-1] - 1) else: coordinates[..., -1] = frequencies[..., -1] * image_shape[-1] - dc = dft_center(image_shape, rfft=rfft, fftshifted=True, device=frequencies.device) + dc = dft_center(_image_shape, rfft=rfft, fftshifted=True, device=frequencies.device) return coordinates + dc diff --git a/src/libtilt/interpolation/interpolate_dft_3d.py b/src/libtilt/interpolation/interpolate_dft_3d.py index e754d13..ec0a588 100644 --- a/src/libtilt/interpolation/interpolate_dft_3d.py +++ b/src/libtilt/interpolation/interpolate_dft_3d.py @@ -49,7 +49,8 @@ def sample_dft_3d( samples = torch.view_as_complex(samples.contiguous()) # (b, ) # pack data back up and return - [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) + # [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) + samples = samples.reshape(*ps) # replaces commented line above, for performance return samples # (...) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index be14b3c..b72a4bc 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch import torch.nn.functional as F import einops @@ -33,30 +35,12 @@ def project_fourier( projections: torch.Tensor `(..., d, d)` array of projection images. """ - # padding - if pad is True: - pad_length = volume.shape[-1] // 2 - volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0) - - # premultiply by sinc2 - grid = fftfreq_grid( - image_shape=volume.shape, - rfft=False, - fftshift=True, - norm=True, - device=volume.device - ) - volume = volume * torch.sinc(grid) ** 2 - - # calculate DFT - dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin - dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) - dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft + dft, vol_shape, pad_length = _compute_dft(volume, pad) # make projections by taking central slices projections = extract_central_slices_rfft( dft=dft, - image_shape=volume.shape, + image_shape=vol_shape, rotation_matrices=rotation_matrices, rotation_matrix_zyx=rotation_matrix_zyx ) # (..., h, w) rfft @@ -92,7 +76,8 @@ def extract_central_slices_rfft( # flip coordinates in redundant half transform conjugate_mask = grid[..., 2] < 0 - conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') + # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') #This operation does not compile + conjugate_mask = conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3) #This does grid[conjugate_mask] *= -1 conjugate_mask = conjugate_mask[..., 0] # un-repeat @@ -107,3 +92,49 @@ def extract_central_slices_rfft( # take complex conjugate of values from redundant half transform projections[conjugate_mask] = torch.conj(projections[conjugate_mask]) return projections + +def _compute_dft( + volume: torch.Tensor, + pad: bool = True, + pad_length: int | None = None +) -> Tuple[torch.Tensor, Tuple[int,int,int], int]: + """Computes the DFT of a volume. Intended to be used as a preprocessing before using extract_central_slices_rfft. + + Parameters + ---------- + volume: torch.Tensor + `(d, d, d)` volume. + pad: bool + Whether to pad the volume with zeros to increase sampling in the DFT. + pad_length: int | None + The length used for padding each side of each dimension. If pad_length=None, and pad=True then volume.shape[-1] // 2 is used instead + + Returns + ------- + projections: Tuple[torch.Tensor, torch.Tensor, int] + `(..., d, d, d)` dft of the volume. fftshifted rfft + Tuple[int,int,int] the shape of the volume after padding + int with the padding length + """ + # padding + if pad is True: + if pad_length is None: + pad_length = volume.shape[-1] // 2 + volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0) + + # premultiply by sinc2 + grid = fftfreq_grid( + image_shape=volume.shape, + rfft=False, + fftshift=True, + norm=True, + device=volume.device + ) + volume = volume * torch.sinc(grid) ** 2 + + # calculate DFT + dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin + dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) + dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft + + return dft, volume.shape, pad_length diff --git a/src/libtilt/projection/project_real.py b/src/libtilt/projection/project_real.py index 36a0940..63dbc27 100644 --- a/src/libtilt/projection/project_real.py +++ b/src/libtilt/projection/project_real.py @@ -49,7 +49,7 @@ def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch torch_padding = einops.rearrange(torch_padding, 'whd pad -> (whd pad)') volume = F.pad(volume, pad=tuple(torch_padding), mode='constant', value=0) padded_volume_shape = (ps, ps, ps) - volume_coordinates = coordinate_grid(image_shape=padded_volume_shape) + volume_coordinates = coordinate_grid(image_shape=padded_volume_shape, device=volume.device) volume_coordinates -= padded_sidelength // 2 # (d, h, w, zyx) volume_coordinates = torch.flip(volume_coordinates, dims=(-1,)) # (d, h, w, zyx) volume_coordinates = einops.rearrange(volume_coordinates, 'd h w zyx -> d h w zyx 1') @@ -73,5 +73,5 @@ def _project_volume(rotation_matrix) -> torch.Tensor: yl, yh = padding[1, 0], -padding[1, 1] xl, xh = padding[2, 0], -padding[2, 1] - images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] + images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] #TODO: This can probabaly optimized using vmap return torch.stack(images, dim=0)