Spaces:
Running
on
Zero
Running
on
Zero
'''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.''' | |
from abc import ABC, abstractmethod | |
from functools import partial | |
from torch.nn import functional as F | |
from torchvision import torch | |
from src.flair.utils.blur_util import Blurkernel | |
from src.flair.utils.img_util import fft2d | |
import numpy as np | |
from src.flair.utils.resizer import Resizer | |
from src.flair.utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform | |
from torch.fft import fft2, ifft2 | |
from src.flair.motionblur.motionblur import Kernel | |
# ================= | |
# Operation classes | |
# ================= | |
__OPERATOR__ = {} | |
_GAMMA_FACTOR = 2.2 | |
def register_operator(name: str): | |
def wrapper(cls): | |
if __OPERATOR__.get(name, None): | |
raise NameError(f"Name {name} is already registered!") | |
__OPERATOR__[name] = cls | |
return cls | |
return wrapper | |
def get_operator(name: str, **kwargs): | |
if __OPERATOR__.get(name, None) is None: | |
raise NameError(f"Name {name} is not defined.") | |
return __OPERATOR__[name](**kwargs) | |
class LinearOperator(ABC): | |
def forward(self, data, **kwargs): | |
# calculate A * X | |
pass | |
def noisy_forward(self, data, **kwargs): | |
# calculate A * X + n | |
pass | |
def transpose(self, data, **kwargs): | |
# calculate A^T * X | |
pass | |
def ortho_project(self, data, **kwargs): | |
# calculate (I - A^T * A)X | |
return data - self.transpose(self.forward(data, **kwargs), **kwargs) | |
def project(self, data, measurement, **kwargs): | |
# calculate (I - A^T * A)Y - AX | |
return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs) | |
class DenoiseOperator(LinearOperator): | |
def __init__(self, device): | |
self.device = device | |
def forward(self, data): | |
return data | |
def noisy_forward(self, data): | |
return data | |
def transpose(self, data): | |
return data | |
def ortho_project(self, data): | |
return data | |
def project(self, data): | |
return data | |
class SuperResolutionOperator(LinearOperator): | |
def __init__(self, | |
in_shape, | |
scale_factor, | |
noise, | |
noise_scale, | |
device): | |
self.device = device | |
self.up_sample = partial(F.interpolate, scale_factor=scale_factor) | |
self.down_sample = Resizer(in_shape, 1/scale_factor).to(device) | |
self.noise = get_noise(name=noise, scale=noise_scale) | |
def A(self, data, **kwargs): | |
return self.forward(data, **kwargs) | |
def forward(self, data, **kwargs): | |
return self.down_sample(data) | |
def noisy_forward(self, data, **kwargs): | |
return self.noise.forward(self.down_sample(data)) | |
def transpose(self, data, **kwargs): | |
return self.up_sample(data) | |
def project(self, data, measurement, **kwargs): | |
return data - self.transpose(self.forward(data)) + self.transpose(measurement) | |
class MotionBlurOperator(LinearOperator): | |
def __init__(self, | |
kernel_size, | |
intensity, | |
device): | |
self.device = device | |
self.kernel_size = kernel_size | |
self.conv = Blurkernel(blur_type='motion', | |
kernel_size=kernel_size, | |
std=intensity, | |
device=device).to(device) # should we keep this device term? | |
self.kernel_size =kernel_size | |
self.intensity = intensity | |
self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity) | |
kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32) | |
self.conv.update_weights(kernel) | |
def forward(self, data, **kwargs): | |
# A^T * A | |
return self.conv(data) | |
def noisy_forward(self, data, **kwargs): | |
pass | |
def transpose(self, data, **kwargs): | |
return data | |
def change_kernel(self): | |
self.kernel = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.intensity) | |
kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32) | |
self.conv.update_weights(kernel) | |
def get_kernel(self): | |
kernel = self.kernel.kernelMatrix.type(torch.float32).to(self.device) | |
return kernel.view(1, 1, self.kernel_size, self.kernel_size) | |
def A(self, data): | |
return self.forward(data) | |
def At(self, data): | |
return self.transpose(data) | |
class GaussialBlurOperator(LinearOperator): | |
def __init__(self, | |
kernel_size, | |
intensity, | |
device): | |
self.device = device | |
self.kernel_size = kernel_size | |
self.conv = Blurkernel(blur_type='gaussian', | |
kernel_size=kernel_size, | |
std=intensity, | |
device=device).to(device) | |
self.kernel = self.conv.get_kernel() | |
self.conv.update_weights(self.kernel.type(torch.float32)) | |
def forward(self, data, **kwargs): | |
return self.conv(data) | |
def noisy_forward(self, data, **kwargs): | |
pass | |
def transpose(self, data, **kwargs): | |
return data | |
def get_kernel(self): | |
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size) | |
def apply_kernel(self, data, kernel): | |
self.conv.update_weights(kernel.type(torch.float32)) | |
return self.conv(data) | |
def A(self, data): | |
return self.forward(data) | |
def At(self, data): | |
return self.transpose(data) | |
class InpaintingOperator(LinearOperator): | |
'''This operator get pre-defined mask and return masked image.''' | |
def __init__(self, | |
noise, | |
noise_scale, | |
device): | |
self.device = device | |
self.noise = get_noise(name=noise, scale=noise_scale) | |
def forward(self, data, **kwargs): | |
try: | |
return data * kwargs.get('mask', None).to(self.device) | |
except: | |
raise ValueError("Require mask") | |
def noisy_forward(self, data, **kwargs): | |
return self.noise.forward(self.forward(data, **kwargs)) | |
def transpose(self, data, **kwargs): | |
return data | |
def ortho_project(self, data, **kwargs): | |
return data - self.forward(data, **kwargs) | |
# Operator for BlindDPS. | |
class BlindBlurOperator(LinearOperator): | |
def __init__(self, device, **kwargs) -> None: | |
self.device = device | |
def forward(self, data, kernel, **kwargs): | |
return self.apply_kernel(data, kernel) | |
def transpose(self, data, **kwargs): | |
return data | |
def apply_kernel(self, data, kernel): | |
#TODO: faster way to apply conv?:W | |
b_img = torch.zeros_like(data).to(self.device) | |
for i in range(3): | |
b_img[:, i, :, :] = F.conv2d(data[:, i:i+1, :, :], kernel, padding='same') | |
return b_img | |
class NonLinearOperator(ABC): | |
def forward(self, data, **kwargs): | |
pass | |
def noisy_forward(self, data, **kwargs): | |
pass | |
def project(self, data, measurement, **kwargs): | |
return data + measurement - self.forward(data) | |
class PhaseRetrievalOperator(NonLinearOperator): | |
def __init__(self, | |
oversample, | |
noise, | |
noise_scale, | |
device): | |
self.pad = int((oversample / 8.0) * 256) | |
self.device = device | |
self.noise = get_noise(name=noise, scale=noise_scale) | |
def forward(self, data, **kwargs): | |
padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad)) | |
amplitude = fft2d(padded).abs() | |
return amplitude | |
def noisy_forard(self, data, **kwargs): | |
return self.noise.forward(self.forward(data, **kwargs)) | |
class NonuniformBlurOperator(LinearOperator): | |
def __init__(self, in_shape, kernel_size, device, | |
kernels=None, masks=None): | |
self.device = device | |
self.kernel_size = kernel_size | |
self.in_shape = in_shape | |
# TODO: generalize | |
if kernels is None and masks is None: | |
self.kernels = np.load('./functions/nonuniform/kernels/000001.npy') | |
self.masks = np.load('./functions/nonuniform/masks/000001.npy') | |
self.kernels = torch.tensor(self.kernels).to(device) | |
self.masks = torch.tensor(self.masks).to(device) | |
# approximate in image space | |
def forward_img(self, data): | |
K = self.kernel_size | |
data = F.pad(data, (K//2, K//2, K//2, K//2), mode="reflect") | |
kernels = self.kernels.transpose(0, 1) | |
data_rgb_batch = data.transpose(0, 1) | |
conv_rgb_batch = F.conv2d(data_rgb_batch, kernels) | |
y_rgb_batch = conv_rgb_batch * self.masks | |
y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True) | |
y = y_rgb_batch.transpose(0, 1) | |
return y | |
# NOTE: Only using this operator will make the problem nonlinear (gamma-correction) | |
def forward_nonlinear(self, data, flatten=False, noiseless=False): | |
# 1. Usual nonuniform blurring degradataion pipeline | |
kernels = self.kernels.transpose(0, 1) | |
FK, FKC = pre_calculate_FK(kernels) | |
y = ifft2(FK * fft2(data)).real | |
y = y.transpose(0, 1) | |
y_rgb_batch = self.masks * y | |
y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True) | |
y = y_rgb_batch.transpose(0, 1) | |
F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks) | |
self.pre_calculated = (FK, FKC, F2KM, FKFMy) | |
# 2. Gamma-correction | |
y = (y + 1) / 2 | |
y = y ** (1 / _GAMMA_FACTOR) | |
y = (y - 0.5) / 0.5 | |
return y | |
def noisy_forward(self, data, **kwargs): | |
return self.noise.forward(self.forward(data)) | |
# exact in Fourier | |
def forward(self, data, flatten=False, noiseless=False): | |
# [1, 25, 33, 33] -> [25, 1, 33, 33] | |
kernels = self.kernels.transpose(0, 1) | |
# [25, 1, 512, 512] | |
FK, FKC = pre_calculate_FK(kernels) | |
# [25, 3, 512, 512] | |
y = ifft2(FK * fft2(data)).real | |
# [3, 25, 512, 512] | |
y = y.transpose(0, 1) | |
y_rgb_batch = self.masks * y | |
# [3, 1, 512, 512] | |
y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True) | |
# [1, 3, 512, 512] | |
y = y_rgb_batch.transpose(0, 1) | |
F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks) | |
self.pre_calculated = (FK, FKC, F2KM, FKFMy) | |
return y | |
def transpose(self, y, flatten=False): | |
kernels = self.kernels.transpose(0, 1) | |
FK, FKC = pre_calculate_FK(kernels) | |
# 1. braodcast and multiply | |
# [3, 1, 512, 512] | |
y_rgb_batch = y.transpose(0, 1) | |
# [3, 25, 512, 512] | |
y_rgb_batch = y_rgb_batch.repeat(1, 25, 1, 1) | |
y = self.masks * y_rgb_batch | |
# 2. transpose of convolution in Fourier | |
# [25, 3, 512, 512] | |
y = y.transpose(0, 1) | |
ATy_broadcast = ifft2(FKC * fft2(y)).real | |
# [1, 3, 512, 512] | |
ATy = torch.sum(ATy_broadcast, dim=0, keepdim=True) | |
return ATy | |
def A(self, data): | |
return self.forward(data) | |
def At(self, data): | |
return self.transpose(data) | |
# ============= | |
# Noise classes | |
# ============= | |
__NOISE__ = {} | |
def register_noise(name: str): | |
def wrapper(cls): | |
if __NOISE__.get(name, None): | |
raise NameError(f"Name {name} is already defined!") | |
__NOISE__[name] = cls | |
return cls | |
return wrapper | |
def get_noise(name: str, **kwargs): | |
if __NOISE__.get(name, None) is None: | |
raise NameError(f"Name {name} is not defined.") | |
noiser = __NOISE__[name](**kwargs) | |
noiser.__name__ = name | |
return noiser | |
class Noise(ABC): | |
def __call__(self, data): | |
return self.forward(data) | |
def forward(self, data): | |
pass | |
class Clean(Noise): | |
def __init__(self, **kwargs): | |
pass | |
def forward(self, data): | |
return data | |
class GaussianNoise(Noise): | |
def __init__(self, scale): | |
self.scale = scale | |
def forward(self, data): | |
return data + torch.randn_like(data, device=data.device) * self.scale | |
class PoissonNoise(Noise): | |
def __init__(self, scale): | |
self.scale = scale | |
def forward(self, data): | |
''' | |
Follow skimage.util.random_noise. | |
''' | |
# version 3 (stack-overflow) | |
import numpy as np | |
data = (data + 1.0) / 2.0 | |
data = data.clamp(0, 1) | |
device = data.device | |
data = data.detach().cpu() | |
data = torch.from_numpy(np.random.poisson(data * 255.0 * self.scale) / 255.0 / self.scale) | |
data = data * 2.0 - 1.0 | |
data = data.clamp(-1, 1) | |
return data.to(device) | |