Spaces:
Running
on
Zero
Running
on
Zero
from typing import Optional, Union | |
import numpy as np | |
import torch | |
from torch.fft import fft2, fftshift, ifft2, ifftshift | |
from torchvision.utils import save_image | |
def draw_img(img: Union[torch.Tensor, np.ndarray],
             save_path:Optional[str]='test.png',
             nrow:Optional[int]=8,
             normalize:Optional[bool]=True):
    """Write ``img`` to ``save_path`` using torchvision's ``save_image``.

    Args:
        img: Image data. A NumPy array is first wrapped with ``torch.Tensor``
            (torch's default float dtype); a tensor is passed through as-is.
        save_path: Output file, forwarded as ``fp``.
        nrow: Forwarded to ``save_image`` (images per grid row, per the
            torchvision documentation).
        normalize: Forwarded to ``save_image`` unchanged.
    """
    as_tensor = torch.Tensor(img) if isinstance(img, np.ndarray) else img
    save_image(as_tensor, fp=save_path, nrow=nrow, normalize=normalize)
def normalize(img: Union[torch.Tensor, np.ndarray]) \
        -> Union[torch.Tensor, np.ndarray]:
    """Min-max rescale ``img`` so its values span [0, 1].

    The global minimum and maximum over all elements are used. Works for both
    torch tensors and NumPy arrays and returns the same container type.

    Args:
        img: Tensor or array to rescale.

    Returns:
        ``(img - img.min()) / (img.max() - img.min())``. A constant input
        (``max == min``) yields all zeros instead of the NaNs that ``0 / 0``
        would produce.
    """
    lo = img.min()
    shifted = img - lo
    span = img.max() - lo
    if span == 0:
        # Constant input: every element maps to 0. Dividing by 1 (rather than
        # returning ``shifted`` directly) keeps the same floating result dtype
        # that true division produces on the normal path.
        return shifted / 1
    return shifted / span
def to_np(img: torch.Tensor,
          mode: Optional[str]='NCHW') -> np.ndarray:
    """Detach ``img``, move it to the CPU, and return it as a NumPy array.

    Args:
        img: Batched image tensor. In ``'NCHW'`` mode it must be 4-D, because
            it is reordered with ``permute(0, 2, 3, 1)``.
        mode: Layout of ``img``. ``'NCHW'`` inputs are permuted to
            channels-last (NHWC); ``'NHWC'`` inputs are left as they are.

    Returns:
        A NumPy array with the channels-last axis order.

    Raises:
        AssertionError: if ``mode`` is neither ``'NCHW'`` nor ``'NHWC'``.
    """
    assert mode in ['NCHW', 'NHWC']
    channels_last = img.permute(0, 2, 3, 1) if mode == 'NCHW' else img
    return channels_last.detach().cpu().numpy()
def fft2d(img: torch.Tensor,
          mode: Optional[str]='NCHW') -> torch.Tensor:
    """Centered 2-D FFT over the two spatial axes of a batched image tensor.

    Args:
        img: 4-D image tensor laid out as described by ``mode``.
        mode: ``'NCHW'`` or ``'NHWC'``. NHWC input is permuted to NCHW first,
            so the returned spectrum is always NCHW.

    Returns:
        Complex NCHW spectrum with the zero-frequency term shifted to the
        center of each H x W plane. Batch and channel order are preserved.

    Raises:
        AssertionError: if ``mode`` is neither ``'NCHW'`` nor ``'NHWC'``.
        NameError: same condition when asserts are stripped (``python -O``).
    """
    assert mode in ['NCHW', 'NHWC']
    if mode == 'NHWC':
        img = img.permute(0, 3, 1, 2)
    elif mode != 'NCHW':
        raise NameError
    # Call torch.fft explicitly. The bare names ``fft2``/``ifft2`` are rebound
    # later in this module to helpers that already shift H and W. Pass ``dim``
    # so that only H and W are shifted: an un-dimensioned fftshift also rolls
    # the batch and channel axes.
    return torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))


def ifft2d(img: torch.Tensor,
           mode: Optional[str]='NCHW') -> torch.Tensor:
    """Inverse of ``fft2d``: un-center the spectrum, then apply the 2-D IFFT.

    Args:
        img: Centered NCHW spectrum, as produced by ``fft2d``.
        mode: Layout of the returned image, ``'NCHW'`` or ``'NHWC'``.
            ``'NHWC'`` permutes the result with ``permute(0, 2, 3, 1)``.

    Returns:
        Complex image-domain tensor in the requested layout.

    Raises:
        AssertionError: if ``mode`` is neither ``'NCHW'`` nor ``'NHWC'``.
        NameError: same condition when asserts are stripped (``python -O``).
    """
    assert mode in ['NCHW', 'NHWC']
    if mode not in ('NCHW', 'NHWC'):
        raise NameError
    # Undo exactly the spatial shift applied by fft2d (H and W only).
    out = torch.fft.ifft2(torch.fft.ifftshift(img, dim=(-2, -1)))
    if mode == 'NHWC':
        out = out.permute(0, 2, 3, 1)
    return out
"""
Helper functions for new types of inverse problems: centered 2-D forward and
inverse FFTs for real-view and complex tensors, including multi-coil data.
"""
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply a centered 2-dimensional Fast Fourier Transform.

    Args:
        data: Complex data in real-view form with at least 3 dimensions.
            Dimensions -3 and -2 are spatial, and dimension -1 (size 2) holds
            the real and imaginary parts. All other dimensions are treated as
            batch dimensions.
        norm: Normalization mode. See ``torch.fft.fft``.

    Returns:
        The centered FFT of the input, in the same real-view layout.

    Raises:
        ValueError: if the last dimension does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    spatial = [-3, -2]
    # Move the center of each spatial plane to index 0 before transforming.
    uncentered = ifftshift(data, dim=spatial)
    # On the complex view the trailing pair axis disappears, so the spatial
    # axes become (-2, -1).
    spectrum = torch.fft.fftn(
        torch.view_as_complex(uncentered), dim=(-2, -1), norm=norm
    )
    # Back to real view, then put the zero frequency at the plane center.
    return fftshift(torch.view_as_real(spectrum), dim=spatial)
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply a centered 2-dimensional Inverse Fast Fourier Transform.

    Args:
        data: Complex data in real-view form with at least 3 dimensions.
            Dimensions -3 and -2 are spatial, and dimension -1 (size 2) holds
            the real and imaginary parts. All other dimensions are treated as
            batch dimensions.
        norm: Normalization mode. See ``torch.fft.ifft``.

    Returns:
        The centered IFFT of the input, in the same real-view layout.

    Raises:
        ValueError: if the last dimension does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    spatial = [-3, -2]
    # Move the zero frequency from the plane center back to index 0.
    uncentered = ifftshift(data, dim=spatial)
    # Complex view: the spatial axes are (-2, -1) here.
    image = torch.fft.ifftn(
        torch.view_as_complex(uncentered), dim=(-2, -1), norm=norm
    )
    # Back to real view, with the image re-centered on the spatial axes.
    return fftshift(torch.view_as_real(image), dim=spatial)
def fft2(x):
    """2-D FFT over the last two axes of ``x``, with the DC (zero-frequency)
    term shifted to the center of the image; no other axes are shifted."""
    spectrum = torch.fft.fft2(x)
    return torch.fft.fftshift(spectrum, dim=(-2, -1))
def ifft2(x):
    """2-D inverse FFT over the last two axes of ``x``, after moving the DC
    term from the image center back to the corner (index 0)."""
    uncentered = torch.fft.ifftshift(x, dim=(-2, -1))
    return torch.fft.ifft2(uncentered)
def fft2_m(x):
    """FFT for multi-coil data.

    A non-complex ``x`` is cast to ``complex64`` first. The tensor is then
    viewed as real/imaginary pairs and passed to ``fft2c_new`` (default
    ``norm``), so the centered transform runs over the last two axes of ``x``.
    The result is returned as a complex tensor.
    """
    x_complex = x if torch.is_complex(x) else x.type(torch.complex64)
    spectrum_pairs = fft2c_new(torch.view_as_real(x_complex))
    return torch.view_as_complex(spectrum_pairs)
def ifft2_m(x):
    """IFFT for multi-coil data.

    A non-complex ``x`` is cast to ``complex64`` first. The tensor is then
    viewed as real/imaginary pairs and passed to ``ifft2c_new`` (default
    ``norm``), so the centered inverse transform runs over the last two axes
    of ``x``. The result is returned as a complex tensor.
    """
    x_complex = x if torch.is_complex(x) else x.type(torch.complex64)
    image_pairs = ifft2c_new(torch.view_as_real(x_complex))
    return torch.view_as_complex(image_pairs)