Spaces:
Running
on
Zero
Running
on
Zero
Upload 22 files
Browse files- .gitattributes +14 -0
- dataloaders/paired_dataset.py +95 -0
- dataloaders/params_realesrgan.yml +43 -0
- dataloaders/realesrgan.py +303 -0
- dataloaders/simple_dataset.py +156 -0
- figs/bird1.png +3 -0
- figs/building.png +3 -0
- figs/data_real.png +3 -0
- figs/data_real_sup.jpg +3 -0
- figs/data_real_suppl.jpg +3 -0
- figs/data_real_suppl.png +3 -0
- figs/data_syn.png +3 -0
- figs/figs.md +1 -0
- figs/framework.png +3 -0
- figs/gradio.png +0 -0
- figs/ground.jpg +0 -0
- figs/logo1.png +0 -0
- figs/nature.png +3 -0
- figs/person1.png +3 -0
- figs/turbo_steps02_building.png +3 -0
- figs/turbo_steps02_frog.png +3 -0
- figs/turbo_steps04_building.png +3 -0
- figs/turbo_steps04_frog.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
figs/bird1.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
figs/building.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
figs/data_real_sup.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
figs/data_real_suppl.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
figs/data_real_suppl.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
figs/data_real.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
figs/data_syn.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
figs/framework.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
figs/nature.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
figs/person1.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
figs/turbo_steps02_building.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
figs/turbo_steps02_frog.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
figs/turbo_steps04_building.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
figs/turbo_steps04_frog.png filter=lfs diff=lfs merge=lfs -text
|
dataloaders/paired_dataset.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from torch import nn
|
8 |
+
from torchvision import transforms
|
9 |
+
from torch.utils import data as data
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from .realesrgan import RealESRGAN_degradation
|
13 |
+
|
14 |
+
class PairedCaptionDataset(data.Dataset):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
root_folders=None,
|
18 |
+
tokenizer=None,
|
19 |
+
null_text_ratio=0.5,
|
20 |
+
# use_ram_encoder=False,
|
21 |
+
# use_gt_caption=False,
|
22 |
+
# caption_type = 'gt_caption',
|
23 |
+
):
|
24 |
+
super(PairedCaptionDataset, self).__init__()
|
25 |
+
|
26 |
+
self.null_text_ratio = null_text_ratio
|
27 |
+
self.lr_list = []
|
28 |
+
self.gt_list = []
|
29 |
+
self.tag_path_list = []
|
30 |
+
|
31 |
+
root_folders = root_folders.split(',')
|
32 |
+
for root_folder in root_folders:
|
33 |
+
lr_path = root_folder +'/sr_bicubic'
|
34 |
+
tag_path = root_folder +'/tag'
|
35 |
+
gt_path = root_folder +'/gt'
|
36 |
+
|
37 |
+
self.lr_list += glob.glob(os.path.join(lr_path, '*.png'))
|
38 |
+
self.gt_list += glob.glob(os.path.join(gt_path, '*.png'))
|
39 |
+
self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
|
40 |
+
|
41 |
+
|
42 |
+
assert len(self.lr_list) == len(self.gt_list)
|
43 |
+
assert len(self.lr_list) == len(self.tag_path_list)
|
44 |
+
|
45 |
+
self.img_preproc = transforms.Compose([
|
46 |
+
transforms.ToTensor(),
|
47 |
+
])
|
48 |
+
|
49 |
+
ram_mean = [0.485, 0.456, 0.406]
|
50 |
+
ram_std = [0.229, 0.224, 0.225]
|
51 |
+
self.ram_normalize = transforms.Normalize(mean=ram_mean, std=ram_std)
|
52 |
+
|
53 |
+
self.tokenizer = tokenizer
|
54 |
+
|
55 |
+
def tokenize_caption(self, caption=""):
|
56 |
+
inputs = self.tokenizer(
|
57 |
+
caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
58 |
+
)
|
59 |
+
|
60 |
+
return inputs.input_ids
|
61 |
+
|
62 |
+
def __getitem__(self, index):
|
63 |
+
|
64 |
+
|
65 |
+
gt_path = self.gt_list[index]
|
66 |
+
gt_img = Image.open(gt_path).convert('RGB')
|
67 |
+
gt_img = self.img_preproc(gt_img)
|
68 |
+
|
69 |
+
lq_path = self.lr_list[index]
|
70 |
+
lq_img = Image.open(lq_path).convert('RGB')
|
71 |
+
lq_img = self.img_preproc(lq_img)
|
72 |
+
|
73 |
+
if random.random() < self.null_text_ratio:
|
74 |
+
tag = ''
|
75 |
+
else:
|
76 |
+
tag_path = self.tag_path_list[index]
|
77 |
+
file = open(tag_path, 'r')
|
78 |
+
tag = file.read()
|
79 |
+
file.close()
|
80 |
+
|
81 |
+
example = dict()
|
82 |
+
example["conditioning_pixel_values"] = lq_img.squeeze(0)
|
83 |
+
example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0
|
84 |
+
example["input_ids"] = self.tokenize_caption(caption=tag).squeeze(0)
|
85 |
+
|
86 |
+
lq_img = lq_img.squeeze()
|
87 |
+
|
88 |
+
ram_values = F.interpolate(lq_img.unsqueeze(0), size=(384, 384), mode='bicubic')
|
89 |
+
ram_values = ram_values.clamp(0.0, 1.0)
|
90 |
+
example["ram_values"] = self.ram_normalize(ram_values.squeeze(0))
|
91 |
+
|
92 |
+
return example
|
93 |
+
|
94 |
+
def __len__(self):
|
95 |
+
return len(self.gt_list)
|
dataloaders/params_realesrgan.yml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scale: 4
|
2 |
+
color_jitter_prob: 0.0
|
3 |
+
gray_prob: 0.0
|
4 |
+
|
5 |
+
# the first degradation process
|
6 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
7 |
+
resize_range: [0.3, 1.5]
|
8 |
+
gaussian_noise_prob: 0.5
|
9 |
+
noise_range: [1, 15]
|
10 |
+
poisson_scale_range: [0.05, 2.0]
|
11 |
+
gray_noise_prob: 0.4
|
12 |
+
jpeg_range: [60, 95]
|
13 |
+
|
14 |
+
# the second degradation process
|
15 |
+
second_blur_prob: 0.5
|
16 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
17 |
+
resize_range2: [0.6, 1.2]
|
18 |
+
gaussian_noise_prob2: 0.5
|
19 |
+
noise_range2: [1, 12]
|
20 |
+
poisson_scale_range2: [0.05, 1.0]
|
21 |
+
gray_noise_prob2: 0.4
|
22 |
+
jpeg_range2: [60, 100]
|
23 |
+
|
24 |
+
kernel_info:
|
25 |
+
blur_kernel_size: 21
|
26 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
27 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
28 |
+
sinc_prob: 0.1
|
29 |
+
blur_sigma: [0.2, 3]
|
30 |
+
betag_range: [0.5, 4]
|
31 |
+
betap_range: [1, 2]
|
32 |
+
|
33 |
+
blur_kernel_size2: 21
|
34 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
35 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
36 |
+
sinc_prob2: 0.1
|
37 |
+
blur_sigma2: [0.2, 1.5]
|
38 |
+
betag_range2: [0.5, 4]
|
39 |
+
betap_range2: [1, 2]
|
40 |
+
|
41 |
+
final_sinc_prob: 0.8
|
42 |
+
|
43 |
+
|
dataloaders/realesrgan.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import glob
|
5 |
+
import math
|
6 |
+
import yaml
|
7 |
+
import random
|
8 |
+
from collections import OrderedDict
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from basicsr.data.transforms import augment
|
13 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
14 |
+
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
|
15 |
+
from basicsr.utils.img_process_util import filter2D
|
16 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
17 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
|
18 |
+
normalize, rgb_to_grayscale)
|
19 |
+
|
20 |
+
cur_path = os.path.dirname(os.path.abspath(__file__))
|
21 |
+
|
22 |
+
|
23 |
+
def ordered_yaml():
|
24 |
+
"""Support OrderedDict for yaml.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
yaml Loader and Dumper.
|
28 |
+
"""
|
29 |
+
try:
|
30 |
+
from yaml import CDumper as Dumper
|
31 |
+
from yaml import CLoader as Loader
|
32 |
+
except ImportError:
|
33 |
+
from yaml import Dumper, Loader
|
34 |
+
|
35 |
+
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
36 |
+
|
37 |
+
def dict_representer(dumper, data):
|
38 |
+
return dumper.represent_dict(data.items())
|
39 |
+
|
40 |
+
def dict_constructor(loader, node):
|
41 |
+
return OrderedDict(loader.construct_pairs(node))
|
42 |
+
|
43 |
+
Dumper.add_representer(OrderedDict, dict_representer)
|
44 |
+
Loader.add_constructor(_mapping_tag, dict_constructor)
|
45 |
+
return Loader, Dumper
|
46 |
+
|
47 |
+
def opt_parse(opt_path):
|
48 |
+
with open(opt_path, mode='r') as f:
|
49 |
+
Loader, _ = ordered_yaml()
|
50 |
+
opt = yaml.load(f, Loader=Loader)
|
51 |
+
|
52 |
+
return opt
|
53 |
+
|
54 |
+
class RealESRGAN_degradation(object):
|
55 |
+
def __init__(self, opt_path='', device='cpu'):
|
56 |
+
self.opt = opt_parse(opt_path)
|
57 |
+
self.device = device #torch.device('cpu')
|
58 |
+
optk = self.opt['kernel_info']
|
59 |
+
|
60 |
+
# blur settings for the first degradation
|
61 |
+
self.blur_kernel_size = optk['blur_kernel_size']
|
62 |
+
self.kernel_list = optk['kernel_list']
|
63 |
+
self.kernel_prob = optk['kernel_prob']
|
64 |
+
self.blur_sigma = optk['blur_sigma']
|
65 |
+
self.betag_range = optk['betag_range']
|
66 |
+
self.betap_range = optk['betap_range']
|
67 |
+
self.sinc_prob = optk['sinc_prob']
|
68 |
+
|
69 |
+
# blur settings for the second degradation
|
70 |
+
self.blur_kernel_size2 = optk['blur_kernel_size2']
|
71 |
+
self.kernel_list2 = optk['kernel_list2']
|
72 |
+
self.kernel_prob2 = optk['kernel_prob2']
|
73 |
+
self.blur_sigma2 = optk['blur_sigma2']
|
74 |
+
self.betag_range2 = optk['betag_range2']
|
75 |
+
self.betap_range2 = optk['betap_range2']
|
76 |
+
self.sinc_prob2 = optk['sinc_prob2']
|
77 |
+
|
78 |
+
# a final sinc filter
|
79 |
+
self.final_sinc_prob = optk['final_sinc_prob']
|
80 |
+
|
81 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
82 |
+
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
83 |
+
self.pulse_tensor[10, 10] = 1
|
84 |
+
|
85 |
+
self.jpeger = DiffJPEG(differentiable=False).to(self.device)
|
86 |
+
self.usm_shaper = USMSharp().to(self.device)
|
87 |
+
|
88 |
+
def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
|
89 |
+
fn_idx = torch.randperm(4)
|
90 |
+
for fn_id in fn_idx:
|
91 |
+
if fn_id == 0 and brightness is not None:
|
92 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
93 |
+
img = adjust_brightness(img, brightness_factor)
|
94 |
+
|
95 |
+
if fn_id == 1 and contrast is not None:
|
96 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
97 |
+
img = adjust_contrast(img, contrast_factor)
|
98 |
+
|
99 |
+
if fn_id == 2 and saturation is not None:
|
100 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
101 |
+
img = adjust_saturation(img, saturation_factor)
|
102 |
+
|
103 |
+
if fn_id == 3 and hue is not None:
|
104 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
105 |
+
img = adjust_hue(img, hue_factor)
|
106 |
+
return img
|
107 |
+
|
108 |
+
def random_augment(self, img_gt):
|
109 |
+
# random horizontal flip
|
110 |
+
img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)
|
111 |
+
"""
|
112 |
+
# random color jitter
|
113 |
+
if np.random.uniform() < self.opt['color_jitter_prob']:
|
114 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
115 |
+
img_gt = img_gt + jitter_val
|
116 |
+
img_gt = np.clip(img_gt, 0, 1)
|
117 |
+
|
118 |
+
# random grayscale
|
119 |
+
if np.random.uniform() < self.opt['gray_prob']:
|
120 |
+
#img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
121 |
+
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)
|
122 |
+
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
|
123 |
+
"""
|
124 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
125 |
+
img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)
|
126 |
+
|
127 |
+
return img_gt
|
128 |
+
|
129 |
+
def random_kernels(self):
|
130 |
+
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
131 |
+
kernel_size = random.choice(self.kernel_range)
|
132 |
+
if np.random.uniform() < self.sinc_prob:
|
133 |
+
# this sinc filter setting is for kernels ranging from [7, 21]
|
134 |
+
if kernel_size < 13:
|
135 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
136 |
+
else:
|
137 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
138 |
+
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
139 |
+
else:
|
140 |
+
kernel = random_mixed_kernels(
|
141 |
+
self.kernel_list,
|
142 |
+
self.kernel_prob,
|
143 |
+
kernel_size,
|
144 |
+
self.blur_sigma,
|
145 |
+
self.blur_sigma, [-math.pi, math.pi],
|
146 |
+
self.betag_range,
|
147 |
+
self.betap_range,
|
148 |
+
noise_range=None)
|
149 |
+
# pad kernel
|
150 |
+
pad_size = (21 - kernel_size) // 2
|
151 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
152 |
+
|
153 |
+
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
154 |
+
kernel_size = random.choice(self.kernel_range)
|
155 |
+
if np.random.uniform() < self.sinc_prob2:
|
156 |
+
if kernel_size < 13:
|
157 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
158 |
+
else:
|
159 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
160 |
+
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
161 |
+
else:
|
162 |
+
kernel2 = random_mixed_kernels(
|
163 |
+
self.kernel_list2,
|
164 |
+
self.kernel_prob2,
|
165 |
+
kernel_size,
|
166 |
+
self.blur_sigma2,
|
167 |
+
self.blur_sigma2, [-math.pi, math.pi],
|
168 |
+
self.betag_range2,
|
169 |
+
self.betap_range2,
|
170 |
+
noise_range=None)
|
171 |
+
|
172 |
+
# pad kernel
|
173 |
+
pad_size = (21 - kernel_size) // 2
|
174 |
+
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
175 |
+
|
176 |
+
# ------------------------------------- sinc kernel ------------------------------------- #
|
177 |
+
if np.random.uniform() < self.final_sinc_prob:
|
178 |
+
kernel_size = random.choice(self.kernel_range)
|
179 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
180 |
+
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
181 |
+
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
182 |
+
else:
|
183 |
+
sinc_kernel = self.pulse_tensor
|
184 |
+
|
185 |
+
kernel = torch.FloatTensor(kernel)
|
186 |
+
kernel2 = torch.FloatTensor(kernel2)
|
187 |
+
|
188 |
+
return kernel, kernel2, sinc_kernel
|
189 |
+
|
190 |
+
@torch.no_grad()
|
191 |
+
def degrade_process(self, img_gt, resize_bak=False):
|
192 |
+
img_gt = self.random_augment(img_gt)
|
193 |
+
kernel1, kernel2, sinc_kernel = self.random_kernels()
|
194 |
+
img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)
|
195 |
+
#img_gt = self.usm_shaper(img_gt) # shaper gt
|
196 |
+
ori_h, ori_w = img_gt.size()[2:4]
|
197 |
+
|
198 |
+
#scale_final = random.randint(4, 16)
|
199 |
+
scale_final = 4
|
200 |
+
|
201 |
+
# ----------------------- The first degradation process ----------------------- #
|
202 |
+
# blur
|
203 |
+
out = filter2D(img_gt, kernel1)
|
204 |
+
# random resize
|
205 |
+
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
206 |
+
if updown_type == 'up':
|
207 |
+
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
208 |
+
elif updown_type == 'down':
|
209 |
+
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
210 |
+
else:
|
211 |
+
scale = 1
|
212 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
213 |
+
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
214 |
+
# noise
|
215 |
+
gray_noise_prob = self.opt['gray_noise_prob']
|
216 |
+
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
217 |
+
out = random_add_gaussian_noise_pt(
|
218 |
+
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
219 |
+
else:
|
220 |
+
out = random_add_poisson_noise_pt(
|
221 |
+
out,
|
222 |
+
scale_range=self.opt['poisson_scale_range'],
|
223 |
+
gray_prob=gray_noise_prob,
|
224 |
+
clip=True,
|
225 |
+
rounds=False)
|
226 |
+
# JPEG compression
|
227 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
228 |
+
out = torch.clamp(out, 0, 1)
|
229 |
+
out = self.jpeger(out, quality=jpeg_p)
|
230 |
+
|
231 |
+
# ----------------------- The second degradation process ----------------------- #
|
232 |
+
# blur
|
233 |
+
if np.random.uniform() < self.opt['second_blur_prob']:
|
234 |
+
out = filter2D(out, kernel2)
|
235 |
+
# random resize
|
236 |
+
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
237 |
+
if updown_type == 'up':
|
238 |
+
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
239 |
+
elif updown_type == 'down':
|
240 |
+
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
241 |
+
else:
|
242 |
+
scale = 1
|
243 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
244 |
+
out = F.interpolate(
|
245 |
+
out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
|
246 |
+
# noise
|
247 |
+
gray_noise_prob = self.opt['gray_noise_prob2']
|
248 |
+
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
249 |
+
out = random_add_gaussian_noise_pt(
|
250 |
+
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
251 |
+
else:
|
252 |
+
out = random_add_poisson_noise_pt(
|
253 |
+
out,
|
254 |
+
scale_range=self.opt['poisson_scale_range2'],
|
255 |
+
gray_prob=gray_noise_prob,
|
256 |
+
clip=True,
|
257 |
+
rounds=False)
|
258 |
+
|
259 |
+
# JPEG compression + the final sinc filter
|
260 |
+
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
261 |
+
# as one operation.
|
262 |
+
# We consider two orders:
|
263 |
+
# 1. [resize back + sinc filter] + JPEG compression
|
264 |
+
# 2. JPEG compression + [resize back + sinc filter]
|
265 |
+
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
266 |
+
if np.random.uniform() < 0.5:
|
267 |
+
# resize back + the final sinc filter
|
268 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
269 |
+
out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
|
270 |
+
out = filter2D(out, sinc_kernel)
|
271 |
+
# JPEG compression
|
272 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
273 |
+
out = torch.clamp(out, 0, 1)
|
274 |
+
out = self.jpeger(out, quality=jpeg_p)
|
275 |
+
else:
|
276 |
+
# JPEG compression
|
277 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
278 |
+
out = torch.clamp(out, 0, 1)
|
279 |
+
out = self.jpeger(out, quality=jpeg_p)
|
280 |
+
# resize back + the final sinc filter
|
281 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
282 |
+
out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
|
283 |
+
out = filter2D(out, sinc_kernel)
|
284 |
+
|
285 |
+
if np.random.uniform() < self.opt['gray_prob']:
|
286 |
+
out = rgb_to_grayscale(out, num_output_channels=1)
|
287 |
+
|
288 |
+
if np.random.uniform() < self.opt['color_jitter_prob']:
|
289 |
+
brightness = self.opt.get('brightness', (0.5, 1.5))
|
290 |
+
contrast = self.opt.get('contrast', (0.5, 1.5))
|
291 |
+
saturation = self.opt.get('saturation', (0, 1.5))
|
292 |
+
hue = self.opt.get('hue', (-0.1, 0.1))
|
293 |
+
out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)
|
294 |
+
|
295 |
+
if resize_bak:
|
296 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
297 |
+
out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)
|
298 |
+
# clamp and round
|
299 |
+
img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
300 |
+
|
301 |
+
return img_gt, img_lq
|
302 |
+
|
303 |
+
|
dataloaders/simple_dataset.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision import transforms
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
import math
|
10 |
+
|
11 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
12 |
+
from basicsr.data.transforms import augment
|
13 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
14 |
+
|
15 |
+
from PIL import Image
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class SimpleDataset(Dataset):
|
20 |
+
def __init__(self, opt, fix_size=512):
|
21 |
+
|
22 |
+
self.opt = opt
|
23 |
+
self.image_root = opt['gt_path']
|
24 |
+
self.fix_size = fix_size
|
25 |
+
exts = ['*.jpg', '*.png']
|
26 |
+
self.image_list = []
|
27 |
+
for image_root in self.image_root:
|
28 |
+
for ext in exts:
|
29 |
+
image_list = glob.glob(os.path.join(image_root, ext))
|
30 |
+
self.image_list += image_list
|
31 |
+
# if add lsdir dataset
|
32 |
+
image_list = glob.glob(os.path.join(image_root, '00*', ext))
|
33 |
+
self.image_list += image_list
|
34 |
+
|
35 |
+
self.crop_preproc = transforms.Compose([
|
36 |
+
# transforms.CenterCrop(fix_size),
|
37 |
+
transforms.Resize(fix_size)
|
38 |
+
# transforms.RandomHorizontalFlip(),
|
39 |
+
])
|
40 |
+
|
41 |
+
self.img_preproc = transforms.Compose([
|
42 |
+
transforms.ToTensor(),
|
43 |
+
])
|
44 |
+
|
45 |
+
# blur settings for the first degradation
|
46 |
+
self.blur_kernel_size = opt['blur_kernel_size']
|
47 |
+
self.kernel_list = opt['kernel_list']
|
48 |
+
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
49 |
+
self.blur_sigma = opt['blur_sigma']
|
50 |
+
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
51 |
+
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
52 |
+
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
53 |
+
|
54 |
+
# blur settings for the second degradation
|
55 |
+
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
56 |
+
self.kernel_list2 = opt['kernel_list2']
|
57 |
+
self.kernel_prob2 = opt['kernel_prob2']
|
58 |
+
self.blur_sigma2 = opt['blur_sigma2']
|
59 |
+
self.betag_range2 = opt['betag_range2']
|
60 |
+
self.betap_range2 = opt['betap_range2']
|
61 |
+
self.sinc_prob2 = opt['sinc_prob2']
|
62 |
+
|
63 |
+
# a final sinc filter
|
64 |
+
self.final_sinc_prob = opt['final_sinc_prob']
|
65 |
+
|
66 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
67 |
+
# TODO: kernel range is now hard-coded, should be in the configure file
|
68 |
+
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
69 |
+
self.pulse_tensor[10, 10] = 1
|
70 |
+
|
71 |
+
print(f'The dataset length: {len(self.image_list)}')
|
72 |
+
|
73 |
+
|
74 |
+
def __getitem__(self, index):
|
75 |
+
image = Image.open(self.image_list[index]).convert('RGB')
|
76 |
+
# width, height = image.size
|
77 |
+
# if width > height:
|
78 |
+
# width_after = self.fix_size
|
79 |
+
# height_after = int(height*width_after/width)
|
80 |
+
# elif height > width:
|
81 |
+
# height_after = self.fix_size
|
82 |
+
# width_after = int(width*height_after/height)
|
83 |
+
# elif height == width:
|
84 |
+
# height_after = self.fix_size
|
85 |
+
# width_after = self.fix_size
|
86 |
+
image = image.resize((self.fix_size, self.fix_size),Image.LANCZOS)
|
87 |
+
# image = self.crop_preproc(image)
|
88 |
+
image = self.img_preproc(image)
|
89 |
+
|
90 |
+
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
91 |
+
kernel_size = random.choice(self.kernel_range)
|
92 |
+
if np.random.uniform() < self.opt['sinc_prob']:
|
93 |
+
# this sinc filter setting is for kernels ranging from [7, 21]
|
94 |
+
if kernel_size < 13:
|
95 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
96 |
+
else:
|
97 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
98 |
+
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
99 |
+
else:
|
100 |
+
kernel = random_mixed_kernels(
|
101 |
+
self.kernel_list,
|
102 |
+
self.kernel_prob,
|
103 |
+
kernel_size,
|
104 |
+
self.blur_sigma,
|
105 |
+
self.blur_sigma, [-math.pi, math.pi],
|
106 |
+
self.betag_range,
|
107 |
+
self.betap_range,
|
108 |
+
noise_range=None)
|
109 |
+
# pad kernel
|
110 |
+
pad_size = (21 - kernel_size) // 2
|
111 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
112 |
+
|
113 |
+
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
114 |
+
kernel_size = random.choice(self.kernel_range)
|
115 |
+
if np.random.uniform() < self.opt['sinc_prob2']:
|
116 |
+
if kernel_size < 13:
|
117 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
118 |
+
else:
|
119 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
120 |
+
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
121 |
+
else:
|
122 |
+
kernel2 = random_mixed_kernels(
|
123 |
+
self.kernel_list2,
|
124 |
+
self.kernel_prob2,
|
125 |
+
kernel_size,
|
126 |
+
self.blur_sigma2,
|
127 |
+
self.blur_sigma2, [-math.pi, math.pi],
|
128 |
+
self.betag_range2,
|
129 |
+
self.betap_range2,
|
130 |
+
noise_range=None)
|
131 |
+
|
132 |
+
# pad kernel
|
133 |
+
pad_size = (21 - kernel_size) // 2
|
134 |
+
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
135 |
+
|
136 |
+
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
137 |
+
if np.random.uniform() < self.opt['final_sinc_prob']:
|
138 |
+
kernel_size = random.choice(self.kernel_range)
|
139 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
140 |
+
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
141 |
+
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
142 |
+
else:
|
143 |
+
sinc_kernel = self.pulse_tensor
|
144 |
+
|
145 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
146 |
+
# img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
|
147 |
+
kernel = torch.FloatTensor(kernel)
|
148 |
+
kernel2 = torch.FloatTensor(kernel2)
|
149 |
+
|
150 |
+
return_d = {'gt': image, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'lq_path': self.image_list[index]}
|
151 |
+
return return_d
|
152 |
+
|
153 |
+
|
154 |
+
def __len__(self):
|
155 |
+
return len(self.image_list)
|
156 |
+
|
figs/bird1.png
ADDED
![]() |
Git LFS Details
|
figs/building.png
ADDED
![]() |
Git LFS Details
|
figs/data_real.png
ADDED
![]() |
Git LFS Details
|
figs/data_real_sup.jpg
ADDED
![]() |
Git LFS Details
|
figs/data_real_suppl.jpg
ADDED
![]() |
Git LFS Details
|
figs/data_real_suppl.png
ADDED
![]() |
Git LFS Details
|
figs/data_syn.png
ADDED
![]() |
Git LFS Details
|
figs/figs.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
figs/framework.png
ADDED
![]() |
Git LFS Details
|
figs/gradio.png
ADDED
![]() |
figs/ground.jpg
ADDED
![]() |
figs/logo1.png
ADDED
![]() |
figs/nature.png
ADDED
![]() |
Git LFS Details
|
figs/person1.png
ADDED
![]() |
Git LFS Details
|
figs/turbo_steps02_building.png
ADDED
![]() |
Git LFS Details
|
figs/turbo_steps02_frog.png
ADDED
![]() |
Git LFS Details
|
figs/turbo_steps04_building.png
ADDED
![]() |
Git LFS Details
|
figs/turbo_steps04_frog.png
ADDED
![]() |
Git LFS Details
|