""" | |
Copyright (c) Facebook, Inc. and its affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import random | |
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import fastmri | |
from .subsample import MaskFunc | |


def to_tensor(data: np.ndarray) -> torch.Tensor:
    """
    Convert numpy array to PyTorch tensor.

    For complex arrays, the real and imaginary parts are stacked along the
    last dimension.

    Args:
        data: Input numpy array.

    Returns:
        PyTorch version of data.
    """
    if np.iscomplexobj(data):
        data = np.stack((data.real, data.imag), axis=-1)

    return torch.from_numpy(data)


def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
    """
    Convert a complex PyTorch tensor to a complex numpy array.

    The tensor must have a final dimension of size 2 (real and imaginary
    parts).

    Args:
        data: Input data to be converted to numpy.

    Returns:
        Complex numpy version of data.
    """
    return torch.view_as_complex(data).numpy()
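

# Usage sketch: round-tripping a complex numpy array through to_tensor and
# tensor_to_complex_np. The array values and shape are illustrative only.
def _demo_to_tensor_roundtrip() -> None:
    arr = (np.arange(6) + 1j * np.arange(6)).astype(np.complex64).reshape(2, 3)
    t = to_tensor(arr)  # shape (2, 3, 2): real/imag stacked on the last dim
    back = tensor_to_complex_np(t)  # back to complex numpy, shape (2, 3)
    assert np.allclose(arr, back)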


def apply_mask(
    data: torch.Tensor,
    mask_func: MaskFunc,
    offset: Optional[int] = None,
    seed: Optional[Union[int, Tuple[int, ...]]] = None,
    padding: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Subsample given k-space by multiplying with a mask.

    Args:
        data: The input k-space data. This should have at least 3 dimensions,
            where dimensions -3 and -2 are the spatial dimensions, and the
            final dimension has size 2 (for complex values).
        mask_func: A function that takes a shape (tuple of ints) and a random
            number seed and returns a mask.
        offset: Offset from 0 at which to begin the mask (for equispaced
            masks).
        seed: Seed for the random number generator.
        padding: Padding value to apply for mask.

    Returns:
        tuple containing:
            masked data: Subsampled k-space data.
            mask: The generated mask.
            num_low_frequencies: The number of low-frequency samples in the
                mask.
    """
    shape = (1,) * len(data.shape[:-3]) + tuple(data.shape[-3:])
    mask, num_low_frequencies = mask_func(shape, offset, seed)
    if padding is not None:
        mask[..., : padding[0], :] = 0
        mask[..., padding[1] :, :] = 0  # padding value inclusive on right of zeros

    masked_data = data * mask + 0.0  # the + 0.0 removes the sign of the zeros

    return masked_data, mask, num_low_frequencies
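

# Usage sketch: subsampling multi-coil k-space at 4x acceleration with an 8%
# fully-sampled center. RandomMaskFunc is assumed to be available from the
# sibling subsample module, as in the public fastMRI package.
def _demo_apply_mask() -> None:
    from .subsample import RandomMaskFunc

    kspace = torch.randn(15, 640, 372, 2)  # (coils, rows, cols, real/imag)
    mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
    masked_kspace, mask, num_low_frequencies = apply_mask(kspace, mask_func, seed=42)
    # columns are zeroed rather than removed, so the shape is unchanged
    assert masked_kspace.shape == kspace.shape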


def mask_center(x: torch.Tensor, mask_from: int, mask_to: int) -> torch.Tensor:
    """
    Initializes a mask with the center filled in.

    Args:
        x: The input real image or batch of real images.
        mask_from: Part of center to start filling.
        mask_to: Part of center to end filling.

    Returns:
        A mask with the center filled.
    """
    mask = torch.zeros_like(x)
    mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to]

    return mask


def batched_mask_center(
    x: torch.Tensor, mask_from: torch.Tensor, mask_to: torch.Tensor
) -> torch.Tensor:
    """
    Initializes a mask with the center filled in.

    Can operate with different masks for each batch element.

    Args:
        x: The input real image or batch of real images.
        mask_from: Part of center to start filling.
        mask_to: Part of center to end filling.

    Returns:
        A mask with the center filled.
    """
    if not mask_from.shape == mask_to.shape:
        raise ValueError("mask_from and mask_to must match shapes.")
    if not mask_from.ndim == 1:
        raise ValueError("mask_from and mask_to must have 1 dimension.")
    if not mask_from.shape[0] == 1:
        if (not x.shape[0] == mask_from.shape[0]) or (
            not x.shape[0] == mask_to.shape[0]
        ):
            raise ValueError("mask_from and mask_to must have batch_size length.")

    if mask_from.shape[0] == 1:
        mask = mask_center(x, int(mask_from), int(mask_to))
    else:
        mask = torch.zeros_like(x)
        for i, (start, end) in enumerate(zip(mask_from, mask_to)):
            mask[i, :, :, start:end] = x[i, :, :, start:end]

    return mask
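

# Usage sketch: per-batch-element center masks. Element 0 keeps columns 2:6,
# element 1 keeps columns 3:5; everything else is zeroed.
def _demo_batched_mask_center() -> None:
    x = torch.ones(2, 1, 8, 8)  # (batch, coils, rows, cols)
    mask = batched_mask_center(x, torch.tensor([2, 3]), torch.tensor([6, 5]))
    assert mask[0, 0, 0].tolist() == [0, 0, 1, 1, 1, 1, 0, 0]
    assert mask[1, 0, 0].tolist() == [0, 0, 0, 1, 1, 0, 0, 0]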


def center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input real image or batch of real images.

    Args:
        data: The input tensor to be center cropped. It should have at least
            2 dimensions and the cropping is applied along the last two
            dimensions.
        shape: The output shape. The shape should be smaller than the
            corresponding dimensions of data.

    Returns:
        The center cropped image.
    """
    if not (0 < shape[0] <= data.shape[-2] and 0 < shape[1] <= data.shape[-1]):
        raise ValueError("Invalid shapes.")

    w_from = (data.shape[-2] - shape[0]) // 2
    h_from = (data.shape[-1] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]

    return data[..., w_from:w_to, h_from:h_to]


def complex_center_crop(data: torch.Tensor, shape: Tuple[int, int]) -> torch.Tensor:
    """
    Apply a center crop to the input image or batch of complex images.

    Args:
        data: The complex input tensor to be center cropped. It should have
            at least 3 dimensions; the cropping is applied along dimensions
            -3 and -2, and the last dimension should have size 2.
        shape: The output shape. The shape should be smaller than the
            corresponding dimensions of data.

    Returns:
        The center cropped image.
    """
    if not (0 < shape[0] <= data.shape[-3] and 0 < shape[1] <= data.shape[-2]):
        raise ValueError("Invalid shapes.")

    w_from = (data.shape[-3] - shape[0]) // 2
    h_from = (data.shape[-2] - shape[1]) // 2
    w_to = w_from + shape[0]
    h_to = h_from + shape[1]

    return data[..., w_from:w_to, h_from:h_to, :]


def center_crop_to_smallest(
    x: torch.Tensor, y: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply a center crop on the larger image to the size of the smaller.

    The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at
    dim=-1 and y is smaller than x at dim=-2, then the returned dimension will
    be a mixture of the two.

    Args:
        x: The first image.
        y: The second image.

    Returns:
        tuple of tensors x and y, each cropped to the minimum size.
    """
    smallest_width = min(x.shape[-1], y.shape[-1])
    smallest_height = min(x.shape[-2], y.shape[-2])
    x = center_crop(x, (smallest_height, smallest_width))
    y = center_crop(y, (smallest_height, smallest_width))

    return x, y
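

# Usage sketch: the minimum height and minimum width are taken independently,
# so the result can mix dimensions from both inputs.
def _demo_center_crop_to_smallest() -> None:
    x = torch.randn(1, 10, 12)
    y = torch.randn(1, 8, 16)
    x_c, y_c = center_crop_to_smallest(x, y)
    assert x_c.shape == y_c.shape == (1, 8, 12)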


def normalize(
    data: torch.Tensor,
    mean: Union[float, torch.Tensor],
    stddev: Union[float, torch.Tensor],
    eps: Union[float, torch.Tensor] = 0.0,
) -> torch.Tensor:
    """
    Normalize the given tensor.

    Applies the formula (data - mean) / (stddev + eps).

    Args:
        data: Input data to be normalized.
        mean: Mean value.
        stddev: Standard deviation.
        eps: Added to stddev to prevent dividing by zero.

    Returns:
        Normalized tensor.
    """
    return (data - mean) / (stddev + eps)


def normalize_instance(
    data: torch.Tensor, eps: Union[float, torch.Tensor] = 0.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize the given tensor with instance norm.

    Applies the formula (data - mean) / (stddev + eps), where mean and stddev
    are computed from the data itself.

    Args:
        data: Input data to be normalized.
        eps: Added to stddev to prevent dividing by zero.

    Returns:
        tuple containing: the normalized tensor, the mean, and the standard
        deviation used for normalization.
    """
    mean = data.mean()
    std = data.std()

    return normalize(data, mean, std, eps), mean, std
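

# Usage sketch: instance normalization is invertible given the returned
# statistics, which is how UnetDataTransform below normalizes the input image
# and the target with the same mean/std.
def _demo_normalize_instance() -> None:
    image = torch.randn(320, 320)
    normed, mean, std = normalize_instance(image, eps=1e-11)
    restored = normed * (std + 1e-11) + mean
    assert torch.allclose(restored, image, atol=1e-5)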


class UnetSample(NamedTuple):
    """
    A subsampled image for U-Net reconstruction.

    Args:
        image: Subsampled image after inverse FFT.
        target: The target image (if applicable).
        mean: Per-channel mean values used for normalization.
        std: Per-channel standard deviations used for normalization.
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
    """

    image: torch.Tensor
    target: torch.Tensor
    mean: torch.Tensor
    std: torch.Tensor
    fname: str
    slice_num: int
    max_value: float


class UnetDataTransform:
    """
    Data Transformer for training U-Net models.
    """

    def __init__(
        self,
        which_challenge: str,
        mask_func: Optional[MaskFunc] = None,
        use_seed: bool = True,
    ):
        """
        Args:
            which_challenge: Challenge from ("singlecoil", "multicoil").
            mask_func: Optional; A function that can create a mask of
                appropriate shape.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
        """
        if which_challenge not in ("singlecoil", "multicoil"):
            raise ValueError("Challenge should either be 'singlecoil' or 'multicoil'")

        self.mask_func = mask_func
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> UnetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data or (rows, cols) for single coil data.
            mask: Mask from the test dataset.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A UnetSample containing the zero-filled input image, the
            reconstruction target, the mean and standard deviation used for
            normalization, the filename, the slice number, and the maximum
            image value.
        """
        kspace_torch = to_tensor(kspace)

        # check for max value
        max_value = attrs["max"] if "max" in attrs else 0.0

        # apply mask
        if self.mask_func:
            seed = None if not self.use_seed else tuple(map(ord, fname))
            # we only need the first element, which is k-space after masking
            masked_kspace = apply_mask(kspace_torch, self.mask_func, seed=seed)[0]
        else:
            masked_kspace = kspace_torch

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to correct size
        if target is not None:
            crop_size = (target.shape[-2], target.shape[-1])
        else:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # check for FLAIR 203
        if image.shape[-2] < crop_size[1]:
            crop_size = (image.shape[-2], image.shape[-2])

        image = complex_center_crop(image, crop_size)

        # absolute value
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target
        if target is not None:
            target_torch = to_tensor(target)
            target_torch = center_crop(target_torch, crop_size)
            target_torch = normalize(target_torch, mean, std, eps=1e-11)
            target_torch = target_torch.clamp(-6, 6)
        else:
            target_torch = torch.Tensor([0])

        return UnetSample(
            image=image,
            target=target_torch,
            mean=mean,
            std=std,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
        )
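

# Usage sketch with synthetic data; real inputs come from fastMRI HDF5
# slices, and RandomMaskFunc is assumed from the sibling subsample module.
def _demo_unet_transform() -> None:
    from .subsample import RandomMaskFunc

    transform = UnetDataTransform(
        which_challenge="singlecoil",
        mask_func=RandomMaskFunc(center_fractions=[0.08], accelerations=[4]),
    )
    kspace = np.random.randn(640, 372).astype(np.complex64)
    target = np.random.rand(320, 320).astype(np.float32)
    attrs = {"max": float(target.max()), "recon_size": (320, 320, 1)}
    sample = transform(kspace, None, target, attrs, "file0.h5", 0)
    assert sample.image.shape == (320, 320)  # cropped to the target size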


class VarNetSample(NamedTuple):
    """
    A sample of masked k-space for variational network reconstruction.

    Args:
        masked_kspace: k-space after applying sampling mask.
        mask: The applied sampling mask.
        num_low_frequencies: The number of samples for the densely-sampled
            center.
        target: The target image (if applicable).
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
        crop_size: The size to crop the final image.
    """

    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: Optional[int]
    target: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
    crop_size: Tuple[int, int]


class VarNetDataTransform:
    """
    Data Transformer for training VarNet models.
    """

    def __init__(self, mask_func: Optional[MaskFunc] = None, use_seed: bool = True):
        """
        Args:
            mask_func: Optional; A function that can create a mask of
                appropriate shape. Defaults to None.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> VarNetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A VarNetSample with the masked k-space, sampling mask, target
            image, the filename, the slice number, the maximum image value
            (from target), the target crop size, and the number of low
            frequency lines sampled.
        """
        if target is not None:
            target_torch = to_tensor(target)
            max_value = attrs["max"]
        else:
            target_torch = torch.tensor(0)
            max_value = 0.0

        kspace_torch = to_tensor(kspace)
        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        if self.mask_func is not None:
            masked_kspace, mask_torch, num_low_frequencies = apply_mask(
                kspace_torch,
                self.mask_func,
                seed=seed,
                padding=(acq_start, acq_end),
            )

            sample = VarNetSample(
                masked_kspace=masked_kspace,
                mask=mask_torch.to(torch.bool),
                num_low_frequencies=num_low_frequencies,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )
        else:
            masked_kspace = kspace_torch

            # build a broadcastable mask of shape (1, ..., num_cols, 1) from
            # the mask stored with the test data, zeroing the padded columns
            num_cols = kspace_torch.shape[-2]
            mask_shape = [1] * kspace_torch.ndim
            mask_shape[-2] = num_cols
            mask_torch = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
            mask_torch[:, :, :acq_start] = 0
            mask_torch[:, :, acq_end:] = 0

            sample = VarNetSample(
                masked_kspace=masked_kspace,
                mask=mask_torch.to(torch.bool),
                num_low_frequencies=0,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )

        # whether to crop samples to a fixed size for batch processing
        batch_crop = False

        def save_img(x, fname):
            # debugging helper: reconstruct and save a coil-combined image
            slice_image = fastmri.ifft2c(x)  # inverse FFT to get the complex image
            slice_image_abs = fastmri.complex_abs(slice_image)  # real-valued image
            slice_image_rss = fastmri.rss(slice_image_abs, dim=0)
            plt.imsave(f"{fname}.png", torch.abs(slice_image_rss), cmap="gray")

        def save_raw_img(x, fname):
            # debugging helper: save the raw k-space magnitude
            x = fastmri.rss(x, dim=0)[:, :, 0]
            plt.imsave(f"{fname}.png", torch.abs(x))

        if batch_crop:
            # crop k-space to the reconstruction size in image space, then
            # to 320x320
            square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
            cropped_kspace = fastmri.fft2c(
                complex_center_crop(fastmri.ifft2c(sample.masked_kspace), square_crop)
            )
            cropped_kspace = complex_center_crop(cropped_kspace, (320, 320))
            # for debugging, e.g.: save_img(cropped_kspace, "cropped")

            # crop the mask to match
            h_from = (mask_torch.shape[-2] - 320) // 2
            h_to = h_from + 320
            cropped_mask = mask_torch[..., :, h_from:h_to, :]

            sample = VarNetSample(
                masked_kspace=cropped_kspace,
                mask=cropped_mask.to(torch.bool),
                num_low_frequencies=0,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )

        return sample
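

# Usage sketch with synthetic multi-coil data; RandomMaskFunc is assumed from
# the sibling subsample module. The returned mask broadcasts against the
# masked k-space.
def _demo_varnet_transform() -> None:
    from .subsample import RandomMaskFunc

    transform = VarNetDataTransform(
        mask_func=RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
    )
    kspace = np.random.randn(15, 640, 372).astype(np.complex64)
    attrs = {
        "max": 1.0,
        "padding_left": 0,
        "padding_right": 372,
        "recon_size": (320, 320, 1),
    }
    sample = transform(kspace, None, None, attrs, "file0.h5", 0)
    assert sample.masked_kspace.shape == (15, 640, 372, 2)
    assert sample.mask.dtype == torch.bool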


class EnhancedVarNetDataTransform(VarNetDataTransform):
    """
    Enhanced data transformer for training VarNet models.

    Supports training with multiple sampling patterns: one of the provided
    mask functions is chosen at random for each sample.
    """

    def __init__(
        self, mask_funcs: Optional[List[MaskFunc]] = None, use_seed: bool = True
    ):
        """
        Args:
            mask_funcs: List of functions that can create masks of
                appropriate shape; one is chosen at random per sample.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
        """
        self.mask_funcs = mask_funcs
        self.use_seed = use_seed

    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: Optional[np.ndarray],
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> VarNetSample:
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset; see
                VarNetDataTransform.__call__ for how test masks are handled.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            A VarNetSample with the masked k-space, sampling mask, target
            image, the filename, the slice number, the maximum image value
            (from target), the target crop size, and the number of low
            frequency lines sampled.
        """
        if target is not None:
            target_torch = to_tensor(target)
            max_value = attrs["max"]
        else:
            target_torch = torch.tensor(0)
            max_value = 0.0

        kspace_torch = to_tensor(kspace)
        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])

        # randomly choose one of the provided masking functions
        mask_func = random.choice(self.mask_funcs)
        masked_kspace, mask_torch, num_low_frequencies = apply_mask(
            kspace_torch,
            mask_func,
            seed=seed,
            padding=(acq_start, acq_end),
        )

        sample = VarNetSample(
            masked_kspace=masked_kspace,
            mask=mask_torch.to(torch.bool),
            num_low_frequencies=num_low_frequencies,
            target=target_torch,
            fname=fname,
            slice_num=slice_num,
            max_value=max_value,
            crop_size=crop_size,
        )

        # whether to crop samples to a fixed size for batch processing
        batch_crop = False

        if batch_crop:
            # crop k-space to the reconstruction size in image space
            square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
            cropped_kspace = fastmri.fft2c(
                complex_center_crop(fastmri.ifft2c(sample.masked_kspace), square_crop)
            )
            # cropped_kspace = complex_center_crop(cropped_kspace, (640, 320))

            # crop the mask to match
            h_from = (mask_torch.shape[-2] - 320) // 2
            h_to = h_from + 320
            cropped_mask = mask_torch[..., :, h_from:h_to, :]

            sample = VarNetSample(
                masked_kspace=cropped_kspace,
                mask=cropped_mask.to(torch.bool),
                num_low_frequencies=0,
                target=target_torch,
                fname=fname,
                slice_num=slice_num,
                max_value=max_value,
                crop_size=crop_size,
            )

        return sample
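

# Usage sketch: training with two sampling patterns, one chosen at random per
# sample. Both mask functions are assumed from the sibling subsample module,
# and the call signature matches VarNetDataTransform above.
def _demo_enhanced_varnet_transform() -> None:
    from .subsample import RandomMaskFunc

    transform = EnhancedVarNetDataTransform(
        mask_funcs=[
            RandomMaskFunc(center_fractions=[0.08], accelerations=[4]),
            RandomMaskFunc(center_fractions=[0.04], accelerations=[8]),
        ]
    )
    kspace = np.random.randn(15, 640, 372).astype(np.complex64)
    attrs = {
        "max": 1.0,
        "padding_left": 0,
        "padding_right": 372,
        "recon_size": (320, 320, 1),
    }
    sample = transform(kspace, None, None, attrs, "file0.h5", 0)
    assert sample.masked_kspace.shape == (15, 640, 372, 2)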


class MiniCoilSample(NamedTuple):
    """
    A sample of masked coil-compressed k-space for reconstruction.

    Args:
        kspace: the original k-space before masking.
        masked_kspace: k-space after applying sampling mask.
        mask: The applied sampling mask.
        target: The target image (if applicable).
        fname: File name.
        slice_num: The slice index.
        max_value: Maximum image value.
        crop_size: The size to crop the final image.
    """

    kspace: torch.Tensor
    masked_kspace: torch.Tensor
    mask: torch.Tensor
    target: torch.Tensor
    fname: str
    slice_num: int
    max_value: float
    crop_size: Tuple[int, int]


class MiniCoilTransform:
    """
    Multi-coil compressed transform, for faster prototyping.
    """

    def __init__(
        self,
        mask_func: Optional[MaskFunc] = None,
        use_seed: Optional[bool] = True,
        crop_size: Optional[tuple] = None,
        num_compressed_coils: Optional[int] = None,
    ):
        """
        Args:
            mask_func: Optional; A function that can create a mask of
                appropriate shape. Defaults to None.
            use_seed: If True, this class computes a pseudo random number
                generator seed from the filename. This ensures that the same
                mask is used for all the slices of a given volume every time.
            crop_size: Image dimensions for mini MR images.
            num_compressed_coils: Number of coils to output from coil
                compression.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.crop_size = crop_size
        self.num_compressed_coils = num_compressed_coils

    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Args:
            kspace: Input k-space of shape (num_coils, rows, cols) for
                multi-coil data.
            mask: Mask from the test dataset. Not used if mask_func is
                defined.
            target: Target image.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                kspace: original kspace (used for active acquisition only).
                masked_kspace: k-space after applying sampling mask. If there
                    is no mask or mask_func, returns same as kspace.
                mask: The applied sampling mask.
                target: The target image (if applicable). The target is built
                    from the RSS of all coils pre-compression.
                fname: File name.
                slice_num: The slice index.
                max_value: Maximum image value.
                crop_size: The size to crop the final image.
        """
        if target is not None:
            target = to_tensor(target)
            max_value = attrs["max"]
        else:
            target = torch.tensor(0)
            max_value = 0.0

        if self.crop_size is None:
            crop_size = torch.tensor([attrs["recon_size"][0], attrs["recon_size"][1]])
        else:
            if isinstance(self.crop_size, (tuple, list)):
                assert len(self.crop_size) == 2
                if self.crop_size[0] is None or self.crop_size[1] is None:
                    crop_size = torch.tensor(
                        [attrs["recon_size"][0], attrs["recon_size"][1]]
                    )
                else:
                    crop_size = torch.tensor(self.crop_size)
            elif isinstance(self.crop_size, int):
                crop_size = torch.tensor((self.crop_size, self.crop_size))
            else:
                raise ValueError(
                    "`crop_size` should be None, tuple, list, or int, not:"
                    f" {type(self.crop_size)}"
                )

        if self.num_compressed_coils is None:
            num_compressed_coils = kspace.shape[0]
        else:
            num_compressed_coils = self.num_compressed_coils

        seed = None if not self.use_seed else tuple(map(ord, fname))
        acq_start = 0
        acq_end = crop_size[1]

        # crop to the reconstruction size in image space, then to crop_size
        square_crop = (attrs["recon_size"][0], attrs["recon_size"][1])
        kspace = fastmri.fft2c(
            complex_center_crop(fastmri.ifft2c(to_tensor(kspace)), square_crop)
        ).numpy()
        kspace = complex_center_crop(kspace, crop_size)

        # we calculate the target before coil compression. This causes the
        # mini simulation to be one where we have a 15-coil, low-resolution
        # image and our reconstructor has an SVD coil approximation. This is
        # a little bit more realistic than doing the target after SVD
        # compression
        target = fastmri.rss_complex(fastmri.ifft2c(to_tensor(kspace)))
        max_value = target.max()

        # apply coil compression: project the coil dimension onto the top
        # left singular vectors
        new_shape = (num_compressed_coils,) + kspace.shape[1:]
        kspace = np.reshape(kspace, (kspace.shape[0], -1))
        left_vec, _, _ = np.linalg.svd(kspace, compute_uv=True, full_matrices=False)
        kspace = np.reshape(
            left_vec[:, :num_compressed_coils].conj().T @ kspace, new_shape
        )
        kspace = to_tensor(kspace)

        # mask k-space
        if self.mask_func:
            masked_kspace, mask, _ = apply_mask(
                kspace, self.mask_func, seed=seed, padding=(acq_start, acq_end)
            )
            mask = mask.byte()
        elif mask is not None:
            masked_kspace = kspace
            mask_shape = [1] * kspace.ndim
            mask_shape[-2] = kspace.shape[-2]
            mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
            mask = mask.byte()
        else:
            masked_kspace = kspace

        return MiniCoilSample(
            kspace, masked_kspace, mask, target, fname, slice_num, max_value, crop_size
        )
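

# Standalone sketch of the SVD coil-compression step above: stack C coil
# vectors into a C x N matrix and project onto the top-k left singular
# vectors. Shapes are illustrative only.
def _demo_coil_compression(num_compressed_coils: int = 4) -> None:
    coil_kspace = np.random.randn(15, 64 * 64 * 2).astype(np.float32)
    left_vec, _, _ = np.linalg.svd(coil_kspace, full_matrices=False)
    compressed = left_vec[:, :num_compressed_coils].conj().T @ coil_kspace
    assert compressed.shape == (num_compressed_coils, 64 * 64 * 2)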
""" | |
sens maps & feature transformations | |
- expand | |
- reduce | |
- batch -> chan | |
- chan -> batch | |
""" | |


def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F(x * sens_maps), i.e. expands a single-channel image to
    per-coil k-space.

    Parameters
    ----------
    x : torch.Tensor
        Single-channel image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F(x * sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))


def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps), summed over coils, where conj()
    is the element-wise complex conjugate.

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Single-channel image of shape (B, 1, H, W, 2)
    """
    return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
        dim=1, keepdim=True
    )
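

# Usage sketch: sens_expand maps a single-channel image estimate to per-coil
# k-space; sens_reduce combines per-coil k-space back to a single channel.
def _demo_sens_ops() -> None:
    x = torch.randn(1, 1, 64, 64, 2)  # (batch, 1, H, W, real/imag)
    sens_maps = torch.randn(1, 8, 64, 64, 2)  # 8 coils
    coil_kspace = sens_expand(x, sens_maps)  # (1, 8, 64, 64, 2)
    combined = sens_reduce(coil_kspace, sens_maps)  # (1, 1, 64, 64, 2)
    assert combined.shape == x.shape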


def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """
    Reshapes batched multi-channel samples into multiple single-channel
    samples.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        Tensor of shape (b * c, 1, h, w, 2), and the original batch size b
    """
    b, c, h, w, comp = x.shape

    return x.view(b * c, 1, h, w, comp), b


def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """
    Reshapes batched independent samples into the original multi-channel
    samples.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        Batch size b

    Returns
    -------
    torch.Tensor
        Original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size

    return x.view(batch_size, c, h, w, comp)
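

# Usage sketch: the two reshapes are exact inverses, used to run per-coil
# data through models that expect single-channel inputs.
def _demo_chans_batch_roundtrip() -> None:
    x = torch.randn(2, 8, 64, 64, 2)
    flat, b = chans_to_batch_dim(x)  # (16, 1, 64, 64, 2), b == 2
    assert torch.equal(batch_chans_to_chan_dim(flat, b), x)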