Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,430 Bytes
90a9dd3 a7169e0 90a9dd3 a7169e0 90a9dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
import numpy as np
import torch
from munch import Munch
from src.flair.functions import svd_operators as svd_op
from src.flair.functions import measurements
from src.flair.utils.inpaint_util import MaskGenerator
# Registry mapping a degradation name to the factory that builds its operator.
# Factories are registered via ``register_degradation`` and looked up via
# ``get_degradation``.
__DEGRADATION__ = {}


def register_degradation(name: str):
    """Return a decorator that registers a degradation factory under ``name``.

    The decorated function is stored in ``__DEGRADATION__`` and returned
    unchanged.

    Raises:
        NameError: if a factory is already registered under ``name``.
    """
    def wrapper(fn):
        if name in __DEGRADATION__:
            raise NameError(f'DEGRADATION {name} is already registered')
        __DEGRADATION__[name] = fn
        return fn
    return wrapper


def get_degradation(name: str,
                    deg_config: "Munch",
                    device: torch.device,
                    **kwargs):
    """Build the degradation operator registered under ``name``.

    Args:
        name: registry key used with ``register_degradation``.
        deg_config: configuration object forwarded to the factory.
        device: device forwarded to the factory.
        **kwargs: extra keyword arguments forwarded to the factory (e.g.
            factories that accept optional tensors). Only pass keywords the
            chosen factory accepts.

    Returns:
        Whatever the registered factory returns.

    Raises:
        NameError: if no factory is registered under ``name``.
    """
    try:
        factory = __DEGRADATION__[name]
    except KeyError:
        raise NameError(f'DEGRADATION {name} does not exist.') from None
    return factory(deg_config, device, **kwargs)
@register_degradation(name='cs_walshhadamard')
def deg_cs_walshhadamard(deg_config, device):
    """Build a Walsh-Hadamard compressed-sensing operator.

    The compression argument is ``round(1 / deg_config.deg_scale)``. A fresh
    random permutation of ``range(image_size ** 2)`` is drawn on every call
    and passed to ``svd_op.WalshHadamardCS`` (its exact use there is defined
    by that class).
    """
    ratio = round(1 / deg_config.deg_scale)
    side = deg_config.image_size
    perm = torch.randperm(side ** 2)
    return svd_op.WalshHadamardCS(deg_config.channels, side, ratio, perm, device)
@register_degradation(name='cs_blockbased')
def deg_cs_blockbased(deg_config, device):
    """Build a block-based compressed-sensing operator.

    ``deg_config.deg_scale`` is forwarded unchanged as the CS ratio.
    """
    return svd_op.CS(
        deg_config.channels,
        deg_config.image_size,
        deg_config.deg_scale,
        device,
    )
@register_degradation(name='inpainting')
def deg_inpainting(deg_config, device):
    """Build an inpainting operator from a pre-computed pixel mask file.

    The mask is loaded with ``np.load`` from ``deg_config.mask_path`` when that
    attribute is present and non-empty, otherwise from the default
    ``exp/inp_masks/mask_768_half.npy``. The mask is flattened, and every
    position where it equals 0 is marked missing in all
    ``deg_config.channels`` channels.

    Missing indices use an interleaved layout (``pixel * channels + c``),
    grouped by channel: all channel-0 indices first, then channel 1, and so on.
    For 3 channels this is exactly the R, G, B ordering used previously.
    """
    # TODO: generate the mask rather than load it.
    mask_path = (getattr(deg_config, 'mask_path', None)
                 or "exp/inp_masks/mask_768_half.npy")
    loaded = np.load(mask_path)
    mask = torch.from_numpy(loaded).to(device).reshape(-1)
    channels = int(deg_config.channels)
    # torch.nonzero already yields int64 indices.
    missing_pixels = torch.nonzero(mask == 0).reshape(-1)
    missing = torch.cat(
        [missing_pixels * channels + c for c in range(channels)], dim=0
    )
    return svd_op.Inpainting(deg_config.channels,
                             deg_config.image_size,
                             missing,
                             device)
@register_degradation(name='denoising')
def deg_denoise(deg_config, device):
    """Build the denoising operator for the configured channels and image size."""
    n_channels, side = deg_config.channels, deg_config.image_size
    return svd_op.Denoising(n_channels, side, device)
@register_degradation(name='colorization')
def deg_colorization(deg_config, device):
    """Build the colorization operator; only ``image_size`` is read from the config."""
    side = deg_config.image_size
    return svd_op.Colorization(side, device)
@register_degradation(name='sr_avgpool')
def deg_sr_avgpool(deg_config, device):
    """Build an average-pooling super-resolution operator.

    The downscaling factor is ``int(deg_config.deg_scale)``.
    """
    scale = int(deg_config.deg_scale)
    return svd_op.SuperResolution(
        deg_config.channels,
        deg_config.image_size,
        scale,
        device,
    )
@register_degradation(name='sr_bicubic')
def deg_sr_bicubic(deg_config, device):
    """Build a bicubic-kernel super-resolution operator (strided convolution).

    For ``factor = int(deg_config.deg_scale)``, a 1-D kernel of
    ``4 * factor`` taps is sampled from the cubic convolution kernel with
    ``a = -0.5``. It is normalized in float64, converted to float32 on
    ``device``, renormalized, and passed to ``svd_op.SRConv`` with
    ``stride=factor``.
    """
    def cubic(x, a=-0.5):
        # Piecewise cubic: inner lobe for |x| <= 1, outer lobe for 1 < |x| < 2.
        t = abs(x)
        if t <= 1:
            return (a + 2) * t ** 3 - (a + 3) * t ** 2 + 1
        if t < 2:
            return a * t ** 3 - 5 * a * t ** 2 + 8 * a * t - 4 * a
        return 0

    factor = int(deg_config.deg_scale)
    n_taps = factor * 4
    half = np.floor(n_taps / 2)
    # Tap i sits at (i - half + 0.5) / factor, i.e. half-pixel-centred samples.
    taps = np.array(
        [cubic((1 / factor) * (i - half + 0.5)) for i in range(n_taps)],
        dtype=np.float64,
    )
    taps = taps / np.sum(taps)
    kernel = torch.from_numpy(taps).float().to(device)
    # Renormalize after the float32 cast so the float32 taps sum to ~1.
    return svd_op.SRConv(
        kernel / kernel.sum(),
        deg_config.channels,
        deg_config.image_size,
        device,
        stride=factor,
    )
@register_degradation(name='deblur_uni')
def deg_deblur_uni(deg_config, device):
    """Build a uniform (box) 1-D deblurring operator.

    The kernel has ``int(deg_config.deg_scale)`` taps, each equal to
    ``1 / size``. The ``int()`` cast matches the other factories and lets a
    float config value (e.g. ``9.0``) work instead of raising TypeError.

    Raises:
        ValueError: if the kernel size is smaller than 1.
    """
    size = int(deg_config.deg_scale)
    if size < 1:
        raise ValueError(f'deblur_uni kernel size must be >= 1, got {size}')
    kernel = torch.tensor([1 / size] * size).to(device)
    return svd_op.Deblurring(kernel,
                             deg_config.channels,
                             deg_config.image_size,
                             device)
@register_degradation(name='deblur_gauss')
def deg_deblur_gauss(deg_config, device):
    """Build a 1-D Gaussian deblurring operator (sigma = 3.0).

    The kernel has ``int(deg_config.deg_scale)`` taps sampled at integer
    offsets and is normalized to sum to 1 before being passed to
    ``svd_op.Deblurring``.
    """
    sigma = 3.0

    def pdf(x):
        # Unnormalized Gaussian sample as a 1-element tensor.
        return torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))

    size = int(deg_config.deg_scale)
    # Offsets -(size // 2) .. size - size // 2 - 1. This is symmetric around 0
    # for odd sizes (e.g. -4..4 for 9 taps). The previous
    # range(-size//2, size//2) parsed as (-size)//2 and gave -5..3.
    # Even sizes are unchanged (e.g. -4..3 for 8 taps).
    ker = [pdf(k) for k in range(-(size // 2), size - size // 2)]
    kernel = torch.Tensor(ker).to(device)
    return svd_op.Deblurring(kernel / kernel.sum(),
                             deg_config.channels,
                             deg_config.image_size,
                             device)
@register_degradation(name='deblur_aniso')
def deg_deblur_aniso(deg_config, device):
    """Build a separable anisotropic Gaussian deblurring operator.

    Two 9-tap kernels are sampled at offsets -4..4: a narrow one (sigma = 1)
    passed first and a wide one (sigma = 20) passed second to
    ``svd_op.Deblurring2D``. Each kernel is normalized to sum to 1.
    """
    def gaussian_taps(s):
        # Unnormalized Gaussian samples at integer offsets -4..4.
        return torch.Tensor(
            [torch.exp(torch.Tensor([-0.5 * (x / s) ** 2])) for x in range(-4, 5)]
        ).to(device)

    wide = gaussian_taps(20)
    narrow = gaussian_taps(1)
    return svd_op.Deblurring2D(
        narrow / narrow.sum(),
        wide / wide.sum(),
        deg_config.channels,
        deg_config.image_size,
        device,
    )
@register_degradation(name='deblur_motion')
def deg_deblur_motion(deg_config, device):
    """Build a motion-blur operator.

    ``deg_config.deg_scale`` is used as the kernel size; intensity is fixed at 0.5.
    """
    blur_kwargs = dict(kernel_size=deg_config.deg_scale, intensity=0.5, device=device)
    return measurements.MotionBlurOperator(**blur_kwargs)
@register_degradation(name='deblur_nonuniform')
def deg_deblur_nonuniform(deg_config, device, kernels=None, masks=None):
    """Build a spatially non-uniform blur operator.

    Registered as ``deblur_nonuniform``. The function name is distinct so it
    no longer redefines (and shadows) ``deg_deblur_motion`` at module level.

    Args:
        deg_config: config providing ``image_size`` and ``deg_scale``.
        device: device forwarded to the operator.
        kernels: optional kernels forwarded to ``NonuniformBlurOperator``.
        masks: optional masks forwarded to ``NonuniformBlurOperator``.
    """
    return measurements.NonuniformBlurOperator(
        deg_config.image_size,
        deg_config.deg_scale,
        device,
        kernels=kernels,
        masks=masks,
    )
# ======= For arbitrary image sizes =======
@register_degradation(name='sr_avgpool_gen')
def deg_sr_avgpool_general(deg_config, device):
    """Build an average-pooling super-resolution operator for H x W images.

    Height and width come from ``deg_config.imgH`` / ``deg_config.imgW``; the
    downscaling factor is ``int(deg_config.deg_scale)``.
    """
    scale = int(deg_config.deg_scale)
    height, width = deg_config.imgH, deg_config.imgW
    return svd_op.SuperResolutionGeneral(deg_config.channels, height, width, scale, device)
@register_degradation(name='deblur_gauss_gen')
def deg_deblur_guass_general(deg_config, device):
    """Build a Gaussian blur operator via ``measurements.GaussialBlurOperator``.

    ``deg_config.deg_scale`` is used as the kernel size; intensity is fixed
    at 3.0. (The function name's spelling is kept as-is for existing importers.)
    """
    size = deg_config.deg_scale
    return measurements.GaussialBlurOperator(kernel_size=size, intensity=3.0, device=device)
from src.flair.functions.jpeg import jpeg_encode, jpeg_decode
class JPEGOperator:
    """Pair of JPEG encode/decode functions at a fixed quality factor.

    ``A`` encodes an image with ``jpeg_encode``; ``At`` decodes with
    ``jpeg_decode``. Both use the stored quality factor ``qf``.
    """

    def __init__(self, qf: int, device):
        self.qf = qf
        # Stored on the instance; not used by A/At in this class.
        self.device = device

    def A(self, img):
        """Encode ``img``; return the two-part ``(luma, chroma)`` result as a tuple."""
        luma, chroma = jpeg_encode(img, self.qf)
        return luma, chroma

    def At(self, encoded):
        """Decode ``encoded`` with the stored quality factor."""
        return jpeg_decode(encoded, self.qf)
@register_degradation(name='jpeg')
def deg_jpeg(deg_config, device):
    """Build a ``JPEGOperator`` whose quality factor is ``deg_config.deg_scale``."""
    quality = deg_config.deg_scale
    return JPEGOperator(qf=quality, device=device)
|