Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,910 Bytes
90a9dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from typing import Optional, Union
import numpy as np
import torch
from torch.fft import fft2, fftshift, ifft2, ifftshift
from torchvision.utils import save_image
def draw_img(img: Union[torch.Tensor, np.ndarray],
             save_path: Optional[str] = 'test.png',
             nrow: Optional[int] = 8,
             normalize: Optional[bool] = True):
    """Write *img* to disk as an image grid using torchvision's ``save_image``.

    Args:
        img: Image batch as a torch tensor or a NumPy array. NumPy input is
            first converted with the legacy ``torch.Tensor`` constructor.
        save_path: Output file path, forwarded as ``fp``.
        nrow: Number of images per grid row, forwarded unchanged.
        normalize: Forwarded to ``save_image``'s ``normalize`` flag.
    """
    tensor = torch.Tensor(img) if isinstance(img, np.ndarray) else img
    save_image(tensor, fp=save_path, nrow=nrow, normalize=normalize)
def normalize(
    img: Union[torch.Tensor, np.ndarray],
) -> Union[torch.Tensor, np.ndarray]:
    """Min-max rescale *img* so its values span [0, 1].

    Works on both torch tensors and NumPy arrays (only ``min``/``max`` and
    arithmetic are used).

    Args:
        img: Tensor or array to rescale.

    Returns:
        ``(img - img.min()) / (img.max() - img.min())`` of the same container
        type. A constant input (zero value range) yields all zeros instead of
        the NaNs a 0/0 division would produce.
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # Constant input: avoid 0/0 -> NaN. Dividing by 1 still performs true
        # division, so integer inputs come back as floats just like the
        # non-constant path.
        span = 1
    return (img - lo) / span
def to_np(img: torch.Tensor,
          mode: Optional[str] = 'NCHW') -> np.ndarray:
    """Detach *img*, move it to the CPU and return it as a NumPy array.

    With ``mode='NCHW'`` the 4-D tensor is first permuted to channels-last
    (N, H, W, C). With ``mode='NHWC'`` it is returned in its existing layout.

    Raises:
        AssertionError: if *mode* is neither ``'NCHW'`` nor ``'NHWC'``.
    """
    assert mode in ['NCHW', 'NHWC']
    channels_last = img.permute(0, 2, 3, 1) if mode == 'NCHW' else img
    return channels_last.detach().cpu().numpy()
def fft2d(img: torch.Tensor,
          mode: str = 'NCHW') -> torch.Tensor:
    """Centered 2-D FFT over the spatial (H, W) axes of a 4-D image batch.

    The zero-frequency term is moved to the center of each H x W spectrum.
    Only the last two axes are shifted, so batch and channel order are
    preserved.

    Args:
        img: 4-D image batch laid out as described by *mode*.
        mode: ``'NCHW'`` or ``'NHWC'``. NHWC input is permuted to NCHW before
            the transform, and the spectrum is returned in NCHW layout.

    Returns:
        Complex spectrum of shape (N, C, H, W).

    Raises:
        AssertionError: if *mode* is not ``'NCHW'`` or ``'NHWC'``.
        NameError: same condition when asserts are disabled (``python -O``).
    """
    assert mode in ['NCHW', 'NHWC']
    if mode == 'NHWC':
        img = img.permute(0, 3, 1, 2)
    elif mode != 'NCHW':
        # Only reachable when asserts are stripped (python -O).
        raise NameError(f"unsupported mode: {mode!r}")
    # Fully-qualified torch.fft calls: this module later defines its own
    # ``fft2`` (already centered), which rebinds the bare name imported from
    # torch.fft. An un-dimmed ``fftshift`` would additionally roll the batch
    # and channel axes, so restrict the shift to the spatial axes.
    return torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))
def ifft2d(img: torch.Tensor,
           mode: str = 'NCHW') -> torch.Tensor:
    """Inverse of a centered 2-D FFT over the last two (spatial) axes.

    The zero-frequency term is moved from the center back to the corner on
    the last two axes only, then ``torch.fft.ifft2`` is applied. Batch and
    channel order are preserved.

    Args:
        img: Centered complex spectrum laid out as (N, C, H, W).
        mode: ``'NCHW'`` returns the result as (N, C, H, W); ``'NHWC'``
            permutes it to (N, H, W, C).

    Returns:
        Complex-valued spatial-domain batch (take ``.real`` for real images).

    Raises:
        AssertionError: if *mode* is not ``'NCHW'`` or ``'NHWC'``.
        NameError: same condition when asserts are disabled (``python -O``).
    """
    assert mode in ['NCHW', 'NHWC']
    if mode not in ('NCHW', 'NHWC'):
        # Only reachable when asserts are stripped (python -O).
        raise NameError(f"unsupported mode: {mode!r}")
    # Fully-qualified torch.fft calls: this module later defines its own
    # ``ifft2`` (which already un-shifts), rebinding the bare imported name.
    # An un-dimmed ``ifftshift`` would also roll the batch and channel axes.
    image = torch.fft.ifft2(torch.fft.ifftshift(img, dim=(-2, -1)))
    if mode == 'NHWC':
        image = image.permute(0, 2, 3, 1)
    return image
"""
Helper functions for new types of inverse problems
"""
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply a centered 2-D FFT to data stored with a trailing real/imag axis.

    Args:
        data: Real tensor of layout (..., H, W, 2). The last axis holds the
            real and imaginary parts, axes -3 and -2 are spatial, and any
            leading axes are treated as batch axes.
        norm: Normalization mode, passed through to ``torch.fft.fftn``.
    Returns:
        Tensor in the same (..., H, W, 2) layout holding the centered
        spectrum.
    Raises:
        ValueError: if the last axis does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    # In the real view, axes -3/-2 are the spatial axes; they become -2/-1
    # once the trailing real/imag axis is folded into a complex dtype.
    spatial_axes = [-3, -2]
    uncentered = ifftshift(data, dim=spatial_axes)
    spectrum = torch.fft.fftn(
        torch.view_as_complex(uncentered), dim=(-2, -1), norm=norm
    )
    return fftshift(torch.view_as_real(spectrum), dim=spatial_axes)
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply a centered 2-D inverse FFT to data stored with a trailing real/imag axis.

    Args:
        data: Real tensor of layout (..., H, W, 2). The last axis holds the
            real and imaginary parts, axes -3 and -2 are spatial, and any
            leading axes are treated as batch axes.
        norm: Normalization mode, passed through to ``torch.fft.ifftn``.
    Returns:
        Tensor in the same (..., H, W, 2) layout holding the centered
        inverse transform.
    Raises:
        ValueError: if the last axis does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")
    # Spatial axes in the real view; -2/-1 after viewing as complex.
    spatial_axes = [-3, -2]
    uncentered = ifftshift(data, dim=spatial_axes)
    image = torch.fft.ifftn(
        torch.view_as_complex(uncentered), dim=(-2, -1), norm=norm
    )
    return fftshift(torch.view_as_real(image), dim=spatial_axes)
def fft2(x):
    """Forward 2-D FFT with the zero-frequency term shifted to the center.

    Both the transform and the shift act on the last two axes of *x*.
    """
    spectrum = torch.fft.fft2(x)
    return torch.fft.fftshift(spectrum, dim=[-1, -2])
def ifft2(x):
    """Inverse 2-D FFT of a centered spectrum.

    The zero-frequency term is first moved from the center back to the
    corner, then the inverse transform is applied. Both steps act on the last
    two axes of *x*.
    """
    uncentered = torch.fft.ifftshift(x, dim=[-1, -2])
    return torch.fft.ifft2(uncentered)
def fft2_m(x):
    """Centered 2-D FFT on a complex tensor (helper for multi-coil data).

    Non-complex input is cast to ``torch.complex64`` first. The transform is
    delegated to ``fft2c_new`` through a real/imag view, using that
    function's default normalization, and the result is returned complex.
    """
    z = x if torch.is_complex(x) else x.type(torch.complex64)
    return torch.view_as_complex(fft2c_new(torch.view_as_real(z)))
def ifft2_m(x):
    """Centered 2-D inverse FFT on a complex tensor (helper for multi-coil data).

    Non-complex input is cast to ``torch.complex64`` first. The transform is
    delegated to ``ifft2c_new`` through a real/imag view, using that
    function's default normalization, and the result is returned complex.
    """
    z = x if torch.is_complex(x) else x.type(torch.complex64)
    return torch.view_as_complex(ifft2c_new(torch.view_as_real(z)))