Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix data-consistency module #9

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Conversation

bilalkabas
Copy link

@bilalkabas bilalkabas commented Aug 3, 2024

Summary

Fixes #10

This PR addresses the problem with data-consistency module and 2D Fourier transform functions fft2, and ifft2. The data-consistency module has been updated, fft2c and ifft2c functions are added to transforms.py.

Problem definition

The below data-consistency module does not work:

class DataConsistencyInKspace(nn.Module):
""" Create data consistency operator
Warning: note that FFT2 (by the default of torch.fft) is applied to the last 2 axes of the input.
This method detects if the input tensor is 4-dim (2D data) or 5-dim (3D data)
and applies FFT2 to the (nx, ny) axis.
"""
def __init__(self):
super(DataConsistencyInKspace, self).__init__()
def forward(self, *input, **kwargs):
return self.perform(*input)
def data_consistency(self,k, k0, mask):
"""
k - input in k-space
k0 - initially sampled elements in k-space
mask - corresponding nonzero location
"""
out = (1 - mask) * k + mask * k0
return out
def perform(self, x, k0, mask):
"""
x - input in image domain, of shape (n, 2, nx, ny[, nt])
k0 - initially sampled elements in k-space
mask - corresponding nonzero location
"""
x = x.permute(0, 2, 3, 1)
k0 = k0.permute(0, 2, 3, 1)
mask = mask.permute(0, 2, 3, 1)
k = transforms.fft2(x)
out = self.data_consistency(k, k0, mask)
x_res = transforms.ifft2(out)
x_res = x_res.permute(0, 3, 1, 2)
return x_res

This is due to some errors in fft2 and ifft2 functions in transforms.py:

def fft2(data, normalized=True):
"""
Apply centered 2 dimensional Fast Fourier Transform.
Args:
data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
-3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
assumed to be batch dimensions.
Returns:
torch.Tensor: The FFT of the input.
"""
assert data.size(-1) == 2
data = ifftshift(data, dim=(-3, -2))
data = torch.fft(data, 2, normalized=normalized)
data = fftshift(data, dim=(-3, -2))
return data
def rfft2(data):
"""
Apply centered 2 dimensional Fast Fourier Transform.
Args:
data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
-3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
assumed to be batch dimensions.
Returns:
torch.Tensor: The FFT of the input.
"""
data = ifftshift(data, dim=(-2, -1))
data = torch.rfft(data, 2, normalized=True, onesided=False)
data = fftshift(data, dim=(-3, -2))
return data

To Reproduce

import torch
from backbones.reconformer.reconformer import DataConsistencyInKspace

resolution = 320
device = 'cuda:0'

x = torch.randn((1, 2, resolution, resolution)).to(device)
k0 = torch.randn((1, 2, resolution, resolution)).to(device)
mask = torch.randn((1, 1, resolution, resolution)).to(device)

dc = DataConsistencyInKspace()
out = dc(x, k0, mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
     46 k0 = k0.permute(0, 2, 3, 1)
     47 mask = mask.permute(0, 2, 3, 1)
...
--> 122 data = torch.fft.fft(data, 2, normalized=normalized)
    123 data = fftshift(data, dim=(-3, -2))
    124 return data

TypeError: fft_fft() got an unexpected keyword argument 'normalized'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Problem in data-consistency module
1 participant