alexnasa committed on
Commit
c382c9b
·
verified ·
1 Parent(s): 0ced7d0

Upload 22 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figs/bird1.png filter=lfs diff=lfs merge=lfs -text
37
+ figs/building.png filter=lfs diff=lfs merge=lfs -text
38
+ figs/data_real_sup.jpg filter=lfs diff=lfs merge=lfs -text
39
+ figs/data_real_suppl.jpg filter=lfs diff=lfs merge=lfs -text
40
+ figs/data_real_suppl.png filter=lfs diff=lfs merge=lfs -text
41
+ figs/data_real.png filter=lfs diff=lfs merge=lfs -text
42
+ figs/data_syn.png filter=lfs diff=lfs merge=lfs -text
43
+ figs/framework.png filter=lfs diff=lfs merge=lfs -text
44
+ figs/nature.png filter=lfs diff=lfs merge=lfs -text
45
+ figs/person1.png filter=lfs diff=lfs merge=lfs -text
46
+ figs/turbo_steps02_building.png filter=lfs diff=lfs merge=lfs -text
47
+ figs/turbo_steps02_frog.png filter=lfs diff=lfs merge=lfs -text
48
+ figs/turbo_steps04_building.png filter=lfs diff=lfs merge=lfs -text
49
+ figs/turbo_steps04_frog.png filter=lfs diff=lfs merge=lfs -text
dataloaders/paired_dataset.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from PIL import Image
4
+ import random
5
+ import numpy as np
6
+
7
+ from torch import nn
8
+ from torchvision import transforms
9
+ from torch.utils import data as data
10
+ import torch.nn.functional as F
11
+
12
+ from .realesrgan import RealESRGAN_degradation
13
+
14
class PairedCaptionDataset(data.Dataset):
    """Dataset of (low-quality image, ground-truth image, caption) triplets.

    Every root folder must contain three sub-folders:

    * ``sr_bicubic`` -- low-quality inputs (``*.png``)
    * ``gt``         -- ground-truth images (``*.png``)
    * ``tag``        -- caption text files (``*.txt``)

    Samples are paired by position, so each listing is sorted by path to make
    the pairing deterministic (``glob`` itself returns files in arbitrary
    order). Matching files are expected to share the same stem in all three
    sub-folders.

    Args:
        root_folders: Comma-separated string of root folders, or a
            list/tuple of folder paths.
        tokenizer: Callable tokenizer exposing ``model_max_length``; called
            with ``max_length``, ``padding``, ``truncation`` and
            ``return_tensors`` keyword arguments.
        null_text_ratio: Probability of replacing the caption with the empty
            string for a given sample.

    Raises:
        ValueError: If ``root_folders`` is ``None`` or if a root folder does
            not contain the same number of LR, GT and caption files.
    """

    def __init__(
        self,
        root_folders=None,
        tokenizer=None,
        null_text_ratio=0.5,
    ):
        super().__init__()

        if root_folders is None:
            raise ValueError(
                "root_folders must be a comma-separated string or a sequence of folder paths"
            )
        if isinstance(root_folders, str):
            root_folders = root_folders.split(',')

        self.null_text_ratio = null_text_ratio
        self.lr_list = []
        self.gt_list = []
        self.tag_path_list = []

        for root_folder in root_folders:
            lr_files = self._sorted_glob(root_folder + '/sr_bicubic', '*.png')
            gt_files = self._sorted_glob(root_folder + '/gt', '*.png')
            tag_files = self._sorted_glob(root_folder + '/tag', '*.txt')

            # Check per root: equal totals across roots could still hide a
            # shifted (mis-paired) index range.
            if not (len(lr_files) == len(gt_files) == len(tag_files)):
                raise ValueError(
                    f"Mismatched file counts in {root_folder!r}: "
                    f"{len(lr_files)} LR, {len(gt_files)} GT, {len(tag_files)} caption files"
                )

            self.lr_list += lr_files
            self.gt_list += gt_files
            self.tag_path_list += tag_files

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        ram_mean = [0.485, 0.456, 0.406]
        ram_std = [0.229, 0.224, 0.225]
        self.ram_normalize = transforms.Normalize(mean=ram_mean, std=ram_std)

        self.tokenizer = tokenizer

    @staticmethod
    def _sorted_glob(folder, pattern):
        """Return the paths in *folder* matching *pattern*, sorted by path."""
        return sorted(glob.glob(os.path.join(folder, pattern)))

    def _load_image(self, path):
        """Open *path*, convert it to RGB and run ``self.img_preproc`` on it."""
        # The context manager closes the underlying file once the pixel data
        # has been copied by ``convert``.
        with Image.open(path) as img:
            return self.img_preproc(img.convert('RGB'))

    def tokenize_caption(self, caption=""):
        """Tokenize *caption*, padded/truncated to the tokenizer's max length.

        Returns:
            The ``input_ids`` field of the tokenizer output.
        """
        inputs = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return inputs.input_ids

    def __getitem__(self, index):
        """Build the training example for position *index*.

        Returns:
            dict with keys ``conditioning_pixel_values`` (LQ image as produced
            by ``ToTensor``), ``pixel_values`` (GT image rescaled with
            ``x * 2 - 1``), ``input_ids`` (tokenized caption, possibly the
            empty caption) and ``ram_values`` (LQ image resized to 384x384 and
            normalized with ``self.ram_normalize``).
        """
        gt_img = self._load_image(self.gt_list[index])
        lq_img = self._load_image(self.lr_list[index])

        # Randomly drop the caption; the file is only read when it is needed.
        if random.random() < self.null_text_ratio:
            tag = ''
        else:
            with open(self.tag_path_list[index], 'r') as tag_file:
                tag = tag_file.read()

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0)
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0
        example["input_ids"] = self.tokenize_caption(caption=tag).squeeze(0)

        # lq_img is already a C x H x W tensor; a bare ``squeeze()`` here would
        # also drop a height or width of 1 and break the 4-D interpolate call.
        ram_values = F.interpolate(lq_img.unsqueeze(0), size=(384, 384), mode='bicubic')
        ram_values = ram_values.clamp(0.0, 1.0)
        example["ram_values"] = self.ram_normalize(ram_values.squeeze(0))

        return example

    def __len__(self):
        """Return the number of paired samples."""
        return len(self.gt_list)
dataloaders/params_realesrgan.yml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scale: 4
2
+ color_jitter_prob: 0.0
3
+ gray_prob: 0.0
4
+
5
+ # the first degradation process
6
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
7
+ resize_range: [0.3, 1.5]
8
+ gaussian_noise_prob: 0.5
9
+ noise_range: [1, 15]
10
+ poisson_scale_range: [0.05, 2.0]
11
+ gray_noise_prob: 0.4
12
+ jpeg_range: [60, 95]
13
+
14
+ # the second degradation process
15
+ second_blur_prob: 0.5
16
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
17
+ resize_range2: [0.6, 1.2]
18
+ gaussian_noise_prob2: 0.5
19
+ noise_range2: [1, 12]
20
+ poisson_scale_range2: [0.05, 1.0]
21
+ gray_noise_prob2: 0.4
22
+ jpeg_range2: [60, 100]
23
+
24
+ kernel_info:
25
+ blur_kernel_size: 21
26
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
27
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
28
+ sinc_prob: 0.1
29
+ blur_sigma: [0.2, 3]
30
+ betag_range: [0.5, 4]
31
+ betap_range: [1, 2]
32
+
33
+ blur_kernel_size2: 21
34
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
35
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
36
+ sinc_prob2: 0.1
37
+ blur_sigma2: [0.2, 1.5]
38
+ betag_range2: [0.5, 4]
39
+ betap_range2: [1, 2]
40
+
41
+ final_sinc_prob: 0.8
42
+
43
+
dataloaders/realesrgan.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import glob
5
+ import math
6
+ import yaml
7
+ import random
8
+ from collections import OrderedDict
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from basicsr.data.transforms import augment
13
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
14
+ from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
15
+ from basicsr.utils.img_process_util import filter2D
16
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
17
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
18
+ normalize, rgb_to_grayscale)
19
+
20
+ cur_path = os.path.dirname(os.path.abspath(__file__))
21
+
22
+
23
def ordered_yaml():
    """Make the yaml Loader/Dumper round-trip mappings as ``OrderedDict``.

    The C-accelerated classes are used when PyYAML provides them; otherwise
    the pure-Python ones are used. The representer/constructor are registered
    on the returned classes themselves.

    Returns:
        tuple: ``(Loader, Dumper)`` classes with OrderedDict support.
    """
    try:
        from yaml import CDumper as dumper_cls
        from yaml import CLoader as loader_cls
    except ImportError:
        from yaml import Dumper as dumper_cls
        from yaml import Loader as loader_cls

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def represent_ordered(dumper, mapping):
        # Dump an OrderedDict as a plain YAML mapping, keeping key order.
        return dumper.represent_dict(mapping.items())

    def construct_ordered(loader, node):
        # Build every YAML mapping as an OrderedDict.
        pairs = loader.construct_pairs(node)
        return OrderedDict(pairs)

    dumper_cls.add_representer(OrderedDict, represent_ordered)
    loader_cls.add_constructor(mapping_tag, construct_ordered)
    return loader_cls, dumper_cls
46
+
47
def opt_parse(opt_path):
    """Parse the YAML options file at *opt_path*.

    Args:
        opt_path: Path of the YAML file to read.

    Returns:
        The parsed options, with mappings built by the loader returned from
        ``ordered_yaml()``.
    """
    with open(opt_path, mode='r') as stream:
        loader_cls = ordered_yaml()[0]
        # NOTE(review): this is PyYAML's (C)Loader, not SafeLoader -- only
        # trusted configuration files should be passed here.
        return yaml.load(stream, Loader=loader_cls)
53
+
54
class RealESRGAN_degradation(object):
    """Random two-stage image degradation driven by a YAML options file.

    ``degrade_process`` blurs, resizes, adds noise to and JPEG-compresses an
    image twice, then resizes it to 1/4 of the original size and applies a
    final sinc (or identity) filter, returning the augmented ground truth
    together with the degraded image.
    """

    def __init__(self, opt_path='', device='cpu'):
        """Load the options and build the JPEG / sharpening modules.

        Args:
            opt_path: Path of the YAML options file, handed to ``opt_parse``.
                The default ``''`` is passed through unchanged, so a real
                path is needed in practice.
            device: Device the helper modules and, later, the image and
                kernels are moved to.
        """
        self.opt = opt_parse(opt_path)
        self.device = device  # e.g. torch.device('cpu')
        optk = self.opt['kernel_info']

        # blur settings for the first degradation
        self.blur_kernel_size = optk['blur_kernel_size']
        self.kernel_list = optk['kernel_list']
        self.kernel_prob = optk['kernel_prob']
        self.blur_sigma = optk['blur_sigma']
        self.betag_range = optk['betag_range']
        self.betap_range = optk['betap_range']
        self.sinc_prob = optk['sinc_prob']

        # blur settings for the second degradation
        self.blur_kernel_size2 = optk['blur_kernel_size2']
        self.kernel_list2 = optk['kernel_list2']
        self.kernel_prob2 = optk['kernel_prob2']
        self.blur_sigma2 = optk['blur_sigma2']
        self.betag_range2 = optk['betag_range2']
        self.betap_range2 = optk['betap_range2']
        self.sinc_prob2 = optk['sinc_prob2']

        # a final sinc filter
        # NOTE(review): read from the 'kernel_info' section like the keys
        # above, so the options file must define final_sinc_prob there --
        # confirm against the YAML layout.
        self.final_sinc_prob = optk['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # odd kernel sizes from 7 to 21
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

        self.jpeger = DiffJPEG(differentiable=False).to(self.device)
        # Only referenced by a commented-out line in degrade_process.
        self.usm_shaper = USMSharp().to(self.device)

    def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
        """Apply brightness/contrast/saturation/hue jitter in random order.

        Args:
            img: Image tensor passed to the torchvision ``adjust_*`` functions.
            brightness, contrast, saturation, hue: ``(low, high)`` pairs the
                factor is sampled uniformly from, or ``None`` to skip that
                adjustment.

        Returns:
            The jittered image.
        """
        # Random permutation of the four adjustments.
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def random_augment(self, img_gt):
        """Randomly flip *img_gt* and convert it to a batched tensor.

        Args:
            img_gt: Image accepted by basicsr ``augment`` / ``img2tensor``
                (presumably an HWC float numpy array -- TODO confirm against
                callers).

        Returns:
            The converted tensor with a leading batch dimension of 1.
        """
        # random horizontal flip (rotation disabled); `status` is not used
        img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)
        # The string literal below is disabled code kept as a no-op
        # expression; it references an undefined name `shift` and never runs.
        """
        # random color jitter
        if np.random.uniform() < self.opt['color_jitter_prob']:
            jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
            img_gt = img_gt + jitter_val
            img_gt = np.clip(img_gt, 0, 1)

        # random grayscale
        if np.random.uniform() < self.opt['gray_prob']:
            #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)
            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
        """
        # numpy to float32 tensor with no channel swap (bgr2rgb=False), then
        # add a batch dimension
        img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)

        return img_gt

    def random_kernels(self):
        """Sample the blur kernels used by ``degrade_process``.

        Returns:
            tuple: ``(kernel, kernel2, sinc_kernel)`` float tensors. The two
            blur kernels are zero-padded to 21x21; the last one is either a
            sinc kernel padded to 21 or the identity pulse tensor.
        """
        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel symmetrically up to 21x21
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob2:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel symmetrically up to 21x21
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- sinc kernel ------------------------------------- #
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            # identity kernel: no final sinc filtering
            sinc_kernel = self.pulse_tensor

        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return kernel, kernel2, sinc_kernel

    @torch.no_grad()
    def degrade_process(self, img_gt, resize_bak=False):
        """Degrade *img_gt* and return it together with the degraded image.

        Args:
            img_gt: Ground-truth image, passed to ``random_augment``.
            resize_bak: If true, the degraded image is resized back to the
                original height and width at the end.

        Returns:
            tuple: ``(img_gt, img_lq)`` where ``img_gt`` is the augmented
            4-D tensor on ``self.device`` and ``img_lq`` is the degraded
            image, clamped to [0, 1] and rounded to multiples of 1/255.
        """
        img_gt = self.random_augment(img_gt)
        kernel1, kernel2, sinc_kernel = self.random_kernels()
        img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)
        #img_gt = self.usm_shaper(img_gt) # shaper gt
        ori_h, ori_w = img_gt.size()[2:4]

        #scale_final = random.randint(4, 16)
        # Hard-coded downscale factor; self.opt is not consulted for it.
        scale_final = 4

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(img_gt, kernel1)
        # random resize
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # noise: Gaussian with probability gaussian_noise_prob, otherwise Poisson
        gray_noise_prob = self.opt['gray_noise_prob']
        if np.random.uniform() < self.opt['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression with a quality sampled uniformly from jpeg_range
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
        out = torch.clamp(out, 0, 1)
        out = self.jpeger(out, quality=jpeg_p)

        # ----------------------- The second degradation process ----------------------- #
        # blur (applied only with probability second_blur_prob)
        if np.random.uniform() < self.opt['second_blur_prob']:
            out = filter2D(out, kernel2)
        # random resize, relative to the final (1 / scale_final) size
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range2'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range2'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
        # noise: Gaussian with probability gaussian_noise_prob2, otherwise Poisson
        gray_noise_prob = self.opt['gray_noise_prob2']
        if np.random.uniform() < self.opt['gaussian_noise_prob2']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range2'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        # 1. [resize back + sinc filter] + JPEG compression
        # 2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
            out = filter2D(out, sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)
            out = filter2D(out, sinc_kernel)

        # NOTE(review): num_output_channels=1 yields a single-channel result,
        # so img_lq would presumably no longer match img_gt's channel count
        # when this branch fires -- confirm this is intended.
        if np.random.uniform() < self.opt['gray_prob']:
            out = rgb_to_grayscale(out, num_output_channels=1)

        if np.random.uniform() < self.opt['color_jitter_prob']:
            # Jitter ranges fall back to these defaults when absent from the options.
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)

        if resize_bak:
            # upsample the degraded image back to the ground-truth size
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)
        # clamp and round
        img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        return img_gt, img_lq
302
+
303
+
dataloaders/simple_dataset.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import glob
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ import random
8
+ import numpy as np
9
+ import math
10
+
11
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
12
+ from basicsr.data.transforms import augment
13
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
14
+
15
+ from PIL import Image
16
+
17
+
18
+
19
class SimpleDataset(Dataset):
    """Ground-truth image dataset that also samples random blur kernels.

    Each item is a dict holding the image (resized to a ``fix_size`` square)
    plus the two blur kernels and the final sinc kernel sampled for it.

    Args:
        opt: Options mapping. Reads ``gt_path`` (one folder or a sequence of
            folders) and the kernel settings ``blur_kernel_size``,
            ``kernel_list``, ``kernel_prob``, ``blur_sigma``, ``betag_range``,
            ``betap_range``, ``sinc_prob``, their ``*2`` counterparts and
            ``final_sinc_prob``.
        fix_size: Side length the images are resized to.
    """

    # Blur kernels are zero-padded to this side length so every sample
    # returns kernels of the same shape.
    _PADDED_KERNEL_SIZE = 21

    def __init__(self, opt, fix_size=512):
        super().__init__()

        self.opt = opt
        self.image_root = opt['gt_path']
        self.fix_size = fix_size
        self.image_list = self._collect_images(self.image_root)

        self.crop_preproc = transforms.Compose([
            transforms.Resize(fix_size)
        ])

        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
        ])

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

        print(f'The dataset length: {len(self.image_list)}')

    @staticmethod
    def _collect_images(image_roots, exts=('*.jpg', '*.png')):
        """List image files directly under each root and under its ``00*`` sub-folders.

        A single path is accepted as well as a sequence of paths; iterating a
        bare string would otherwise walk it character by character.
        """
        if isinstance(image_roots, (str, os.PathLike)):
            image_roots = [image_roots]

        image_list = []
        for image_root in image_roots:
            for ext in exts:
                image_list += glob.glob(os.path.join(image_root, ext))
                # if add lsdir dataset (images stored in sub-folders named 00*)
                image_list += glob.glob(os.path.join(image_root, '00*', ext))
        return image_list

    def _sample_blur_kernel(self, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range, sinc_prob):
        """Sample one blur kernel and return it zero-padded as a float tensor.

        The random draws happen in a fixed order (size, sinc-vs-mixed choice,
        kernel parameters) so seeded runs stay reproducible.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            omega_low = np.pi / 3 if kernel_size < 13 else np.pi / 5
            omega_c = np.random.uniform(omega_low, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)

        # pad kernel
        pad_size = (self._PADDED_KERNEL_SIZE - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel)

    def __getitem__(self, index):
        """Load image *index* and sample its degradation kernels.

        Returns:
            dict with keys ``gt`` (image tensor from ``ToTensor``),
            ``kernel1``, ``kernel2``, ``sinc_kernel`` (float tensors) and
            ``lq_path`` (path of the loaded image).
        """
        image_path = self.image_list[index]
        # Resize to a fix_size square (aspect ratio is not preserved).
        with Image.open(image_path) as img:
            image = img.convert('RGB').resize((self.fix_size, self.fix_size), Image.LANCZOS)
        image = self.img_preproc(image)

        # kernel used in the first degradation
        kernel = self._sample_blur_kernel(
            self.kernel_list, self.kernel_prob, self.blur_sigma,
            self.betag_range, self.betap_range, self.sinc_prob)

        # kernel used in the second degradation
        kernel2 = self._sample_blur_kernel(
            self.kernel_list2, self.kernel_prob2, self.blur_sigma2,
            self.betag_range2, self.betap_range2, self.sinc_prob2)

        # the final sinc kernel (identity pulse when not sampled)
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        return_d = {
            'gt': image,
            'kernel1': kernel,
            'kernel2': kernel2,
            'sinc_kernel': sinc_kernel,
            'lq_path': image_path,
        }
        return return_d

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_list)
156
+
figs/bird1.png ADDED

Git LFS Details

  • SHA256: af7fffbfca6e4ef77670e62434c40175cd031e349d6c577282aa5e88c90310fc
  • Pointer size: 131 Bytes
  • Size of remote file: 930 kB
figs/building.png ADDED

Git LFS Details

  • SHA256: dca6bc44e7926326e9866bc3f299f74f9e191142b5ec46a6723ca1ee1124fb47
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
figs/data_real.png ADDED

Git LFS Details

  • SHA256: a463f1158e4a44d060477628d1f6316e6dca23fec28cb6c2ad2c1abc5b049eb4
  • Pointer size: 133 Bytes
  • Size of remote file: 18.9 MB
figs/data_real_sup.jpg ADDED

Git LFS Details

  • SHA256: dfd6572b111764d91ca440f18558200a7d1b75a659c9a18cea6dc13d3521e993
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
figs/data_real_suppl.jpg ADDED

Git LFS Details

  • SHA256: dfd6572b111764d91ca440f18558200a7d1b75a659c9a18cea6dc13d3521e993
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
figs/data_real_suppl.png ADDED

Git LFS Details

  • SHA256: dfd6572b111764d91ca440f18558200a7d1b75a659c9a18cea6dc13d3521e993
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
figs/data_syn.png ADDED

Git LFS Details

  • SHA256: 897be62b9a785823e105fb3895cd8218eabb61c519eee1606759c447d557f1e3
  • Pointer size: 133 Bytes
  • Size of remote file: 21.6 MB
figs/figs.md ADDED
@@ -0,0 +1 @@
 
 
1
+
figs/framework.png ADDED

Git LFS Details

  • SHA256: eca9b5ce802ecdcc71e4dbcec84e5aba9efb6139a4b906346521d9ca584a4a22
  • Pointer size: 131 Bytes
  • Size of remote file: 731 kB
figs/gradio.png ADDED
figs/ground.jpg ADDED
figs/logo1.png ADDED
figs/nature.png ADDED

Git LFS Details

  • SHA256: 05fb047d11194e4e3c5c18f07cf693b3dac6f2220b93d3df0cb46623c7c6cae2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
figs/person1.png ADDED

Git LFS Details

  • SHA256: 092e57500438225f3b8665c7206fa500ad95151b20b7f3ff9679c375f089d87a
  • Pointer size: 131 Bytes
  • Size of remote file: 368 kB
figs/turbo_steps02_building.png ADDED

Git LFS Details

  • SHA256: cc6d5ddeb55e4ea8086eff9c14684fb5c4b8368e21c2b46f6b2ee7c4f98a81fe
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
figs/turbo_steps02_frog.png ADDED

Git LFS Details

  • SHA256: b70fb23130b31e5bfdcbed85601eceb8c606d9e31e13fd61e6846387464043db
  • Pointer size: 131 Bytes
  • Size of remote file: 389 kB
figs/turbo_steps04_building.png ADDED

Git LFS Details

  • SHA256: e3a6843d5cad8e5bac989af608eae630537ed2615fe3d55e89ece2b708e03c24
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB
figs/turbo_steps04_frog.png ADDED

Git LFS Details

  • SHA256: 2935058d927441a28c28d7b8f20a827721aed07d0eeedadd926aa7ba1bb1a0a9
  • Pointer size: 131 Bytes
  • Size of remote file: 308 kB