FLAIR / src /flair /utils /img_util.py
juliuse's picture
Initial commit: track binaries with LFS
90a9dd3
from typing import Optional, Union
import numpy as np
import torch
from torch.fft import fft2, fftshift, ifft2, ifftshift
from torchvision.utils import save_image
def draw_img(img: Union[torch.Tensor, np.ndarray],
             save_path:Optional[str]='test.png',
             nrow:Optional[int]=8,
             normalize:Optional[bool]=True):
    """Write an image (or batch of images) to disk with ``save_image``.

    Args:
        img: Image data. A NumPy array is first wrapped with
            ``torch.Tensor(...)``; a tensor is passed through untouched.
        save_path: Destination handed to ``save_image`` as ``fp``.
        nrow: Forwarded unchanged to ``save_image``.
        normalize: Forwarded unchanged to ``save_image``. Inside this
            function the name refers to this flag, not to the module-level
            ``normalize`` function.
    """
    tensor = torch.Tensor(img) if isinstance(img, np.ndarray) else img
    save_image(tensor, fp=save_path, nrow=nrow, normalize=normalize)
def normalize(img: Union[torch.Tensor, np.ndarray]) \
        -> Union[torch.Tensor, np.ndarray]:
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over the whole input (``img.min()`` /
    ``img.max()`` with no axis argument).

    Note:
        When every element is equal the denominator ``max - min`` is zero;
        no guard is applied, so the result is whatever the backend yields
        for a division by zero.
    """
    lowest = img.min()
    span = img.max() - lowest
    return (img - lowest) / span
def to_np(img: torch.Tensor,
          mode: Optional[str]='NCHW') -> np.ndarray:
    """Convert a tensor to a NumPy array in channels-last (NHWC) order.

    Args:
        img: Input tensor. With ``mode='NCHW'`` it is permuted with four
            axes, so it must be 4-D in that case.
        mode: Layout of ``img``. ``'NCHW'`` input is permuted to NHWC;
            ``'NHWC'`` input keeps its axis order.

    Returns:
        The detached data, moved to the CPU, as a NumPy array.

    Raises:
        AssertionError: If ``mode`` is neither ``'NCHW'`` nor ``'NHWC'``.
    """
    assert mode in ['NCHW', 'NHWC']
    channels_last = img.permute(0, 2, 3, 1) if mode == 'NCHW' else img
    return channels_last.detach().cpu().numpy()
def fft2d(img: torch.Tensor,
          mode: Optional[str]='NCHW') -> torch.Tensor:
    """Compute the centered 2-D FFT over the spatial axes of an image batch.

    The transform runs over the last two axes of the channels-first layout
    and the zero-frequency bin is moved to the center of those two axes
    only; batch and channel axes are never reordered.

    Args:
        img: Input tensor laid out according to ``mode``. With
            ``mode='NHWC'`` it is permuted with four axes, so it must be
            4-D in that case.
        mode: ``'NCHW'`` if ``img`` is already channels-first, ``'NHWC'`` if
            it is channels-last (it is permuted to channels-first before
            the transform).

    Returns:
        The shifted spectrum. For both modes the result is in NCHW order;
        ``'NHWC'`` input is not permuted back.

    Raises:
        AssertionError: If ``mode`` is neither ``'NCHW'`` nor ``'NHWC'``.
        NameError: For the same condition when assertions are disabled.
    """
    assert mode in ['NCHW', 'NHWC']
    if mode == 'NHWC':
        img = img.permute(0, 3, 1, 2)
    elif mode != 'NCHW':
        # Only reachable when the assert above is stripped (python -O).
        raise NameError
    # Use torch.fft explicitly: the bare names fft2/ifft2 are rebound later
    # in this module to helpers that already shift, which would shift the
    # spectrum twice. Restrict the shift to the spatial axes so that the
    # batch and channel axes are not rolled.
    spectrum = torch.fft.fft2(img)
    return torch.fft.fftshift(spectrum, dim=(-2, -1))


def ifft2d(img: torch.Tensor,
           mode: Optional[str]='NCHW') -> torch.Tensor:
    """Invert ``fft2d``: centered 2-D inverse FFT over the last two axes.

    The centering shift is undone on the last two axes only, then the
    inverse transform is taken over those axes.

    Args:
        img: Centered spectrum. It is always treated as channels-first
            (spatial frequencies on the last two axes), whatever ``mode``
            is.
        mode: Layout of the *returned* tensor. ``'NCHW'`` returns the
            result as is; ``'NHWC'`` permutes it to channels-last, which
            requires a 4-D input.

    Returns:
        The complex-valued inverse transform in the layout given by
        ``mode``.

    Raises:
        AssertionError: If ``mode`` is neither ``'NCHW'`` nor ``'NHWC'``.
        NameError: For the same condition when assertions are disabled.
    """
    assert mode in ['NCHW', 'NHWC']
    if mode not in ('NCHW', 'NHWC'):
        # Only reachable when the assert above is stripped (python -O).
        raise NameError
    # Explicit torch.fft calls and spatial-only shift, mirroring fft2d so
    # the two functions stay exact inverses of each other.
    unshifted = torch.fft.ifftshift(img, dim=(-2, -1))
    image = torch.fft.ifft2(unshifted)
    if mode == 'NHWC':
        image = image.permute(0, 2, 3, 1)
    return image
"""
Helper functions for new types of inverse problems
"""
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Centered 2-D FFT of complex data stored as real/imaginary pairs.

    The input is un-centered with ``ifftshift`` on dimensions -3 and -2,
    viewed as a complex tensor, transformed over the complex view's last
    two dimensions, converted back to a real view and re-centered with
    ``fftshift`` on dimensions -3 and -2.

    Args:
        data: Tensor whose last dimension has size 2 (the separate complex
            dimension); dimensions -3 and -2 are the transformed axes.
        norm: Normalization mode passed to ``torch.fft.fftn``.

    Returns:
        The transformed data, again with a trailing dimension of size 2.

    Raises:
        ValueError: If the last dimension of ``data`` is not of size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    spatial_axes = [-3, -2]
    uncentered = ifftshift(data, dim=spatial_axes)
    spectrum = torch.fft.fftn(  # type: ignore
        torch.view_as_complex(uncentered), dim=(-2, -1), norm=norm
    )
    return fftshift(torch.view_as_real(spectrum), dim=spatial_axes)
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Centered 2-D inverse FFT of complex data stored as real/imaginary pairs.

    The input is un-centered with ``ifftshift`` on dimensions -3 and -2,
    viewed as a complex tensor, inverse-transformed over the complex view's
    last two dimensions, converted back to a real view and re-centered with
    ``fftshift`` on dimensions -3 and -2.

    Args:
        data: Tensor whose last dimension has size 2 (the separate complex
            dimension); dimensions -3 and -2 are the transformed axes.
        norm: Normalization mode passed to ``torch.fft.ifftn``.

    Returns:
        The inverse-transformed data, again with a trailing dimension of
        size 2.

    Raises:
        ValueError: If the last dimension of ``data`` is not of size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    spatial_axes = [-3, -2]
    uncentered = ifftshift(data, dim=spatial_axes)
    image = torch.fft.ifftn(  # type: ignore
        torch.view_as_complex(uncentered), dim=(-2, -1), norm=norm
    )
    return fftshift(torch.view_as_real(image), dim=spatial_axes)
def fft2(x):
    """FFT over the last two axes with DC shifted to the center of the image.

    Note: this definition rebinds the module-level name ``fft2`` that is
    also imported from ``torch.fft`` at the top of the file.
    """
    spectrum = torch.fft.fft2(x)
    return torch.fft.fftshift(spectrum, dim=[-1, -2])
def ifft2(x):
    """Inverse FFT over the last two axes of a spectrum whose DC is centered.

    DC is shifted back to the corner of the image before the transform.

    Note: this definition rebinds the module-level name ``ifft2`` that is
    also imported from ``torch.fft`` at the top of the file.
    """
    corner_dc = torch.fft.ifftshift(x, dim=[-1, -2])
    return torch.fft.ifft2(corner_dc)
def fft2_m(x):
    """Centered FFT for multi-coil data, taking and returning complex tensors.

    Non-complex input is cast to ``torch.complex64`` first. The data is
    then handed to ``fft2c_new`` as a real view (trailing dimension of
    size 2) and the result is converted back to a complex tensor.
    """
    coils = x if torch.is_complex(x) else x.type(torch.complex64)
    real_pairs = torch.view_as_real(coils)
    return torch.view_as_complex(fft2c_new(real_pairs))
def ifft2_m(x):
    """Centered inverse FFT for multi-coil data, taking and returning complex tensors.

    Non-complex input is cast to ``torch.complex64`` first. The data is
    then handed to ``ifft2c_new`` as a real view (trailing dimension of
    size 2) and the result is converted back to a complex tensor.
    """
    coils = x if torch.is_complex(x) else x.type(torch.complex64)
    real_pairs = torch.view_as_real(coils)
    return torch.view_as_complex(ifft2c_new(real_pairs))