Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
from munch import Munch | |
from src.flair.functions import svd_operators as svd_op | |
from src.flair.functions import measurements | |
from src.flair.utils.inpaint_util import MaskGenerator | |
# Module-level registry: degradation name -> factory callable.
# get_degradation() invokes the stored factory as ``factory(deg_config, device)``.
__DEGRADATION__ = {}


def register_degradation(name: str):
    """Return a decorator that records the decorated factory under ``name``.

    The decorated function itself is returned unchanged, so decorating only
    has the side effect of adding an entry to ``__DEGRADATION__``.

    Raises:
        NameError: if a non-None entry already exists for ``name``.
    """
    def _store(factory):
        existing = __DEGRADATION__.get(name)
        if existing is not None:
            raise NameError(f'DEGRADATION {name} is already registered')
        __DEGRADATION__[name] = factory
        return factory

    return _store
def get_degradation(name: str,
                    deg_config: Munch,
                    device: torch.device):
    """Instantiate the degradation operator registered under ``name``.

    Args:
        name: key previously passed to ``register_degradation``.
        deg_config: configuration object forwarded untouched to the factory.
        device: torch device forwarded untouched to the factory.

    Returns:
        Whatever the registered factory returns for ``(deg_config, device)``.

    Raises:
        NameError: if nothing (or None) is registered under ``name``.
    """
    # NOTE(review): none of the factory functions in this module are visibly
    # decorated with ``register_degradation``; confirm that registration
    # happens elsewhere, otherwise every lookup here raises NameError.
    factory = __DEGRADATION__.get(name)
    if factory is None:
        raise NameError(f'DEGRADATION {name} does not exist.')
    return factory(deg_config, device)
def deg_cs_walshhadamard(deg_config, device):
    """Build a Walsh-Hadamard compressed-sensing operator (``svd_op.WalshHadamardCS``).

    The integer factor handed to the operator is ``round(1 / deg_config.deg_scale)``.
    A fresh ``torch.randperm`` over ``range(image_size ** 2)`` is drawn on every
    call, so the result depends on torch's global RNG state.
    """
    n_positions = deg_config.image_size ** 2
    ratio = round(1 / deg_config.deg_scale)
    return svd_op.WalshHadamardCS(
        deg_config.channels,
        deg_config.image_size,
        ratio,
        torch.randperm(n_positions),
        device,
    )
def deg_cs_blockbased(deg_config, device):
    """Build a block-based compressed-sensing operator (``svd_op.CS``).

    ``deg_config.deg_scale`` is forwarded unchanged as the sampling ratio.
    """
    return svd_op.CS(
        deg_config.channels,
        deg_config.image_size,
        deg_config.deg_scale,
        device,
    )
def deg_inpainting(deg_config, device):
    """Build an ``svd_op.Inpainting`` operator from a mask stored as a ``.npy`` file.

    The mask is read from ``deg_config.mask_path`` when that attribute exists
    and is truthy. Otherwise it is read from the historical default
    ``exp/inp_masks/mask_768_half.npy``. Mask entries equal to 0 are treated
    as missing.

    Missing pixel indices are expanded to every channel using the flat layout
    ``pixel * channels + c`` (``c`` in ``range(channels)``). For 3 channels
    this is exactly the previous ``*3``, ``+1``, ``+2`` scheme, and the index
    order is identical. Other channel counts are no longer mis-indexed.

    Raises:
        ValueError: if the mask does not hold exactly ``image_size ** 2`` entries
            (the missing indices would otherwise address the wrong pixels).
    """
    # TODO: generate mask rather than load
    mask_path = getattr(deg_config, "mask_path", None) or "exp/inp_masks/mask_768_half.npy"
    loaded = np.load(mask_path)
    mask = torch.from_numpy(loaded).to(device).reshape(-1)

    expected = deg_config.image_size ** 2
    if mask.numel() != expected:
        raise ValueError(
            f"inpainting mask {mask_path!r} has {mask.numel()} entries, "
            f"expected image_size**2 = {expected}"
        )

    channels = deg_config.channels
    missing_pixels = torch.nonzero(mask == 0).long().reshape(-1)
    # One block of indices per channel, concatenated channel by channel.
    missing = torch.cat(
        [missing_pixels * channels + c for c in range(channels)], dim=0
    )
    return svd_op.Inpainting(channels,
                             deg_config.image_size,
                             missing,
                             device)
def deg_denoise(deg_config, device):
    """Build an ``svd_op.Denoising`` operator.

    Only ``deg_config.channels`` and ``deg_config.image_size`` are read.
    ``deg_scale`` is ignored.
    """
    channels, size = deg_config.channels, deg_config.image_size
    return svd_op.Denoising(channels, size, device)
def deg_colorization(deg_config, device):
    """Build an ``svd_op.Colorization`` operator.

    Only ``deg_config.image_size`` is forwarded. ``deg_config.channels`` is
    not passed to this operator.
    """
    size = deg_config.image_size
    return svd_op.Colorization(size, device)
def deg_sr_avgpool(deg_config, device):
    """Build an ``svd_op.SuperResolution`` operator.

    The scale factor is ``deg_config.deg_scale`` truncated with ``int()``.
    """
    scale = int(deg_config.deg_scale)
    return svd_op.SuperResolution(
        deg_config.channels,
        deg_config.image_size,
        scale,
        device,
    )
def deg_sr_bicubic(deg_config, device):
    """Build an ``svd_op.SRConv`` super-resolution operator with a bicubic kernel.

    With ``factor = int(deg_config.deg_scale)``, a 1-D kernel of ``4 * factor``
    taps is sampled from the Keys cubic-convolution kernel (``a = -0.5``) at
    offsets ``(i - floor(4 * factor / 2) + 0.5) / factor``. The kernel is
    normalised to sum to one, cast to float32, renormalised, and passed to
    ``SRConv`` with ``stride=factor``.
    """
    def cubic_weight(t, a=-0.5):
        # Piecewise cubic convolution kernel; zero for |t| >= 2.
        dist = abs(t)
        if dist <= 1:
            return (a + 2) * dist ** 3 - (a + 3) * dist ** 2 + 1
        if dist < 2:
            return a * dist ** 3 - 5 * a * dist ** 2 + 8 * a * dist - 4 * a
        return 0

    factor = int(deg_config.deg_scale)
    n_taps = factor * 4
    centre = np.floor(n_taps / 2)  # loop-invariant offset of the kernel centre
    taps = np.zeros(n_taps)
    for idx in range(n_taps):
        taps[idx] = cubic_weight((1 / factor) * (idx - centre + 0.5))
    taps = taps / np.sum(taps)

    kernel = torch.from_numpy(taps).float().to(device)
    # Renormalise again after the float32 cast before handing it over.
    return svd_op.SRConv(kernel / kernel.sum(),
                         deg_config.channels,
                         deg_config.image_size,
                         device,
                         stride=factor)
def deg_deblur_uni(deg_config, device):
    """Build an ``svd_op.Deblurring`` operator with a uniform (box) kernel.

    The 1-D kernel has ``size = int(deg_config.deg_scale)`` taps, each equal
    to ``1 / size``. ``deg_scale`` is coerced with ``int()``, as the
    super-resolution factories in this module do. A config that stores the
    size as a float (e.g. ``9.0``) therefore works instead of raising
    ``TypeError`` from ``list * float``.
    """
    size = int(deg_config.deg_scale)
    kernel = torch.tensor([1 / size] * size).to(device)
    return svd_op.Deblurring(kernel,
                             deg_config.channels,
                             deg_config.image_size,
                             device)
def deg_deblur_gauss(deg_config, device):
    """Build an ``svd_op.Deblurring`` operator with a sampled Gaussian kernel.

    Uses sigma = 3.0 and takes one tap per integer offset in
    ``range(-size // 2, size // 2)`` with ``size = deg_config.deg_scale``.
    Taps are unnormalised ``exp(-0.5 * (offset / sigma) ** 2)`` values,
    divided by their sum before being passed on.
    """
    sigma = 3.0
    size = deg_config.deg_scale
    # NOTE(review): ``-size // 2`` parses as ``(-size) // 2``, so for odd sizes
    # the offsets are not symmetric about 0 (size 9 -> offsets -5..3). Confirm
    # whether this off-centre kernel is intended before changing it.
    taps = [torch.exp(torch.Tensor([-0.5 * (offset / sigma) ** 2]))
            for offset in range(-size // 2, size // 2)]
    kernel = torch.Tensor(taps).to(device)
    return svd_op.Deblurring(kernel / kernel.sum(),
                             deg_config.channels,
                             deg_config.image_size,
                             device)
def deg_deblur_aniso(deg_config, device):
    """Build an anisotropic ``svd_op.Deblurring2D`` operator.

    Two 9-tap Gaussian kernels are sampled at integer offsets -4..4: a narrow
    one with sigma 1 (first kernel argument) and a wide one with sigma 20
    (second kernel argument). Each is divided by its sum before use.
    """
    def sampled_gaussian(sigma):
        values = [torch.exp(torch.Tensor([-0.5 * (offset / sigma) ** 2]))
                  for offset in range(-4, 5)]
        return torch.Tensor(values).to(device)

    wide = sampled_gaussian(20)
    narrow = sampled_gaussian(1)
    return svd_op.Deblurring2D(narrow / narrow.sum(),
                               wide / wide.sum(),
                               deg_config.channels,
                               deg_config.image_size,
                               device)
def deg_deblur_motion(deg_config, device):
    """Build a ``measurements.MotionBlurOperator`` with intensity 0.5.

    ``deg_config.deg_scale`` is used as the blur kernel size.

    NOTE(review): the module-level name ``deg_deblur_motion`` is rebound
    further down this module (to the non-uniform blur factory). Once the
    module has loaded, this MotionBlurOperator factory is therefore not
    reachable under this name.
    """
    kernel_size = deg_config.deg_scale
    return measurements.MotionBlurOperator(
        kernel_size=kernel_size,
        intensity=0.5,
        device=device,
    )
def deg_deblur_nonuniform(deg_config, device, kernels=None, masks=None):
    """Build a ``measurements.NonuniformBlurOperator``.

    Args:
        deg_config: reads ``image_size`` and ``deg_scale``, which are passed
            positionally (in that order) to the operator.
        device: forwarded to the operator.
        kernels: optional, forwarded as ``kernels=``.
        masks: optional, forwarded as ``masks=``.
    """
    return measurements.NonuniformBlurOperator(
        deg_config.image_size,
        deg_config.deg_scale,
        device,
        kernels=kernels,
        masks=masks,
    )


# Backward-compatible alias. This factory used to be defined a second time as
# ``deg_deblur_motion``, silently replacing the earlier MotionBlurOperator
# factory of the same name. The rebinding is kept so existing callers get the
# same function, but it is now explicit. Prefer ``deg_deblur_nonuniform``.
deg_deblur_motion = deg_deblur_nonuniform
# ======= For arbitrary image size =======
def deg_sr_avgpool_general(deg_config, device):
    """Build an ``svd_op.SuperResolutionGeneral`` operator.

    Height and width are taken separately from ``deg_config.imgH`` and
    ``deg_config.imgW``. The scale factor is ``deg_config.deg_scale``
    truncated with ``int()``.
    """
    scale = int(deg_config.deg_scale)
    height, width = deg_config.imgH, deg_config.imgW
    return svd_op.SuperResolutionGeneral(
        deg_config.channels,
        height,
        width,
        scale,
        device,
    )
def deg_deblur_guass_general(deg_config, device):
    """Build a ``measurements.GaussialBlurOperator`` with intensity 3.0.

    ``deg_config.deg_scale`` is used as the kernel size. The function name
    (spelling included) is kept as-is because callers may reference it.
    """
    kernel_size = deg_config.deg_scale
    return measurements.GaussialBlurOperator(
        kernel_size=kernel_size,
        intensity=3.0,
        device=device,
    )
from src.flair.functions.jpeg import jpeg_encode, jpeg_decode | |
class JPEGOperator:
    """JPEG compression wrapped as a forward/adjoint-style operator pair.

    ``A`` runs ``jpeg_encode`` at quality factor ``qf``. ``At`` runs
    ``jpeg_decode`` at the same quality factor.
    """

    def __init__(self, qf: int, device):
        # Quality factor forwarded to jpeg_encode / jpeg_decode.
        self.qf = qf
        # Stored on the instance; not read by A or At in this class.
        self.device = device

    def A(self, img):
        """Encode ``img`` and return the two outputs of ``jpeg_encode`` as a tuple.

        The unpacking requires ``jpeg_encode`` to yield exactly two values
        (named luma and chroma here).
        """
        luma, chroma = jpeg_encode(img, self.qf)
        return luma, chroma

    def At(self, encoded):
        """Decode ``encoded`` (presumably the output of ``A``) with ``jpeg_decode``."""
        return jpeg_decode(encoded, self.qf)
def deg_jpeg(deg_config, device):
    """Build a ``JPEGOperator`` whose quality factor is ``deg_config.deg_scale``."""
    quality = deg_config.deg_scale
    return JPEGOperator(qf=quality, device=device)