File size: 3,910 Bytes
90a9dd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import Optional, Union

import numpy as np
import torch
from torch.fft import fft2, fftshift, ifft2, ifftshift
from torchvision.utils import save_image


def draw_img(img: Union[torch.Tensor, np.ndarray],
             save_path: Optional[str] = 'test.png',
             nrow: Optional[int] = 8,
             normalize: Optional[bool] = True):
    """Write ``img`` to disk as an image grid using torchvision's ``save_image``.

    Args:
        img: torch tensor or numpy array; a numpy array is first converted
            with ``torch.Tensor`` before being saved.
        save_path: destination file path, forwarded to ``save_image`` as ``fp``.
        nrow: forwarded to ``save_image`` (images per row of the grid).
        normalize: forwarded to ``save_image``'s ``normalize`` option.
    """
    tensor = torch.Tensor(img) if isinstance(img, np.ndarray) else img
    save_image(tensor, fp=save_path, nrow=nrow, normalize=normalize)

def normalize(img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
    """Min-max rescale ``img`` so its values span [0, 1].

    Works for torch tensors and numpy arrays alike, since it only uses
    ``min``, ``max``, subtraction and true division.

    Args:
        img: tensor or array to rescale (must be non-empty).
    Returns:
        ``(img - min) / (max - min)``. For a constant input (max == min),
        the all-zero shifted input is returned instead of dividing by zero,
        which previously produced NaNs.
    """
    lo = img.min()
    span = img.max() - lo
    shifted = img - lo
    # Dividing by 1 on a zero range keeps the output dtype identical to the
    # normal path (true division) while yielding all zeros instead of NaN.
    return shifted / (span if span != 0 else 1)
     
def to_np(img: torch.Tensor,
          mode: Optional[str] = 'NCHW') -> np.ndarray:
    """Convert a tensor to a numpy array laid out as NHWC.

    Args:
        img: tensor to convert; for ``mode='NCHW'`` it must be 4-D because
            it is permuted with four axis indices.
        mode: layout of ``img``. ``'NCHW'`` input is permuted to NHWC;
            ``'NHWC'`` input is passed through unchanged.
    Returns:
        ``numpy`` array of the detached tensor moved to CPU.
    Raises:
        ValueError: if ``mode`` is not ``'NCHW'`` or ``'NHWC'``. (An
            ``assert`` was used before, which is stripped under ``python -O``
            and let invalid modes fall through silently.)
    """
    if mode not in ('NCHW', 'NHWC'):
        raise ValueError(f"mode must be 'NCHW' or 'NHWC', got {mode!r}")

    if mode == 'NCHW':
        img = img.permute(0, 2, 3, 1)

    return img.detach().cpu().numpy()

def fft2d(img: torch.Tensor,
          mode: Optional[str] = 'NCHW') -> torch.Tensor:
    """Centered 2-D FFT over the spatial (H, W) axes.

    The zero-frequency term is moved to the centre of each H x W slice.
    Batch and channel axes are left untouched.

    Args:
        img: image tensor laid out as ``mode`` describes. For ``'NHWC'`` it
            must be 4-D, since it is permuted with four axis indices.
        mode: ``'NCHW'`` or ``'NHWC'``. NHWC input is permuted to NCHW
            before the transform.
    Returns:
        Complex spectrum in NCHW layout, for both modes.
    Raises:
        ValueError: if ``mode`` is not ``'NCHW'`` or ``'NHWC'``.
    """
    if mode not in ('NCHW', 'NHWC'):
        raise ValueError(f"mode must be 'NCHW' or 'NHWC', got {mode!r}")

    if mode == 'NHWC':
        img = img.permute(0, 3, 1, 2)

    # Use the fully-qualified torch.fft functions. The bare name ``fft2`` is
    # rebound later in this module to a variant that already shifts. Also, a
    # ``fftshift`` without ``dim`` would roll the batch/channel axes as well.
    return torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))
    

def ifft2d(img: torch.Tensor,
           mode: Optional[str] = 'NCHW') -> torch.Tensor:
    """Inverse 2-D FFT of a centered spectrum over the spatial (H, W) axes.

    The centring is undone on the last two axes only, then the inverse FFT
    is applied. Batch and channel axes are left untouched.

    Args:
        img: complex spectrum in NCHW layout with the zero-frequency term at
            the centre of each H x W slice.
        mode: layout of the returned image. ``'NCHW'`` returns it as-is;
            ``'NHWC'`` permutes the (4-D) result to NHWC.
    Returns:
        Complex image tensor in the requested layout.
    Raises:
        ValueError: if ``mode`` is not ``'NCHW'`` or ``'NHWC'``.
    """
    if mode not in ('NCHW', 'NHWC'):
        raise ValueError(f"mode must be 'NCHW' or 'NHWC', got {mode!r}")

    # Use the fully-qualified torch.fft functions. The bare name ``ifft2`` is
    # rebound later in this module to a variant that already un-shifts. Also,
    # an ``ifftshift`` without ``dim`` would roll the batch/channel axes too.
    out = torch.fft.ifft2(torch.fft.ifftshift(img, dim=(-2, -1)))

    if mode == 'NHWC':
        out = out.permute(0, 2, 3, 1)
    return out


"""
Helper functions for new types of inverse problems
"""

def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Centered 2-D FFT of complex data stored as a trailing (real, imag) pair.

    Args:
        data: real tensor of shape (..., H, W, 2). Axes -3 and -2 are the
            spatial axes, the last axis holds real/imaginary parts, and any
            leading axes are treated as batch axes.
        norm: normalization mode passed to ``torch.fft.fftn``.
    Returns:
        Tensor in the same (..., H, W, 2) layout holding the centered FFT.
    Raises:
        ValueError: if the last axis of ``data`` does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    spatial_axes = [-3, -2]
    # Move the centre back to the corner before transforming.
    corner = ifftshift(data, dim=spatial_axes)
    spectrum = torch.fft.fftn(torch.view_as_complex(corner), dim=(-2, -1), norm=norm)
    # Re-centre the spectrum on the real/imag view.
    return fftshift(torch.view_as_real(spectrum), dim=spatial_axes)


def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Centered 2-D inverse FFT of complex data stored as a trailing (real, imag) pair.

    Args:
        data: real tensor of shape (..., H, W, 2). Axes -3 and -2 are the
            spatial axes, the last axis holds real/imaginary parts, and any
            leading axes are treated as batch axes.
        norm: normalization mode passed to ``torch.fft.ifftn``.
    Returns:
        Tensor in the same (..., H, W, 2) layout holding the centered IFFT.
    Raises:
        ValueError: if the last axis of ``data`` does not have size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    spatial_axes = [-3, -2]
    # Move the centre back to the corner before transforming.
    corner = ifftshift(data, dim=spatial_axes)
    image = torch.fft.ifftn(torch.view_as_complex(corner), dim=(-2, -1), norm=norm)
    # Re-centre the result on the real/imag view.
    return fftshift(torch.view_as_real(image), dim=spatial_axes)

def fft2(x):
    """Return the 2-D FFT of ``x`` over its last two axes, DC moved to the centre.

    Note: this module-level definition rebinds the name ``fft2`` that is
    imported from ``torch.fft`` at the top of the file.
    """
    spectrum = torch.fft.fft2(x)
    return torch.fft.fftshift(spectrum, dim=[-1, -2])


def ifft2(x):
    """Inverse 2-D FFT of a centered spectrum over its last two axes.

    The DC term is moved back to the corner before the inverse transform.
    Note: this module-level definition rebinds the name ``ifft2`` that is
    imported from ``torch.fft`` at the top of the file.
    """
    corner = torch.fft.ifftshift(x, dim=[-1, -2])
    return torch.fft.ifft2(corner)


def fft2_m(x):
    """Centered 2-D FFT for multi-coil data, via ``fft2c_new`` with its default norm.

    Non-complex input is cast to ``torch.complex64`` first. The transform
    runs over the last two axes of the complex tensor, and a complex tensor
    is returned.
    """
    z = x if torch.is_complex(x) else x.type(torch.complex64)
    return torch.view_as_complex(fft2c_new(torch.view_as_real(z)))


def ifft2_m(x):
    """Centered 2-D inverse FFT for multi-coil data, via ``ifft2c_new`` with its default norm.

    Non-complex input is cast to ``torch.complex64`` first. The transform
    runs over the last two axes of the complex tensor, and a complex tensor
    is returned.
    """
    z = x if torch.is_complex(x) else x.type(torch.complex64)
    return torch.view_as_complex(ifft2c_new(torch.view_as_real(z)))