'''
 * SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution
 * Modified from diffusers by Rongyuan Wu
 * 24/12/2023
'''
import os
import cv2
import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything
import argparse
import sys

sys.path.append(os.getcwd())
from basicsr.data.realesrgan_dataset import RealESRGANDataset
from dataloaders.simple_dataset import SimpleDataset
from ram.models import ram
from ram import inference_ram as inference
parser = argparse.ArgumentParser()
parser.add_argument("--gt_path", nargs='+', default=['PATH 1', 'PATH 2'], help='path(s) of the high-resolution (ground-truth) images')
parser.add_argument("--save_dir", type=str, default='preset/datasets/train_datasets/training_for_dape', help='where to save the generated training dataset')
parser.add_argument("--start_gpu", type=int, default=1, help='process index used to offset the random seed and output file names; with e.g. 5 GPUs, launch 5 processes with --start_gpu 1/2/3/4/5 to generate the dataset in parallel and save time')
parser.add_argument("--batch_size", type=int, default=10, help='a smaller batch size takes more time but produces more diverse degradations for the training dataset')
parser.add_argument("--epoch", type=int, default=1, help='how many passes over the source images to make when creating the dataset')
args = parser.parse_args()
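# Hypothetical example invocation (the script name and the data paths below are placeholders,
# not taken from the original repo layout):
#   python make_paired_data.py --gt_path /data/DIV2K/HR /data/Flickr2K/HR \
#       --save_dir preset/datasets/train_datasets/training_for_dape --start_gpu 1 --batch_size 10 --epoch 1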
print(f'====== START GPU: {args.start_gpu} =========')
# Give each parallel process its own seed so they generate different degradations.
seed_everything(24 + args.start_gpu * 1000)
from torchvision.transforms import Normalize, Compose

args_training_dataset = {}
# Please set your gt path here. If you have multiple dirs, you can set it as ['PATH1', 'PATH2', 'PATH3', ...]
args_training_dataset['gt_path'] = args.gt_path

#################### REALESRGAN SETTING ###########################
args_training_dataset['queue_size'] = 160
args_training_dataset['crop_size'] = 512
args_training_dataset['io_backend'] = {}
args_training_dataset['io_backend']['type'] = 'disk'
# first blur kernel
args_training_dataset['blur_kernel_size'] = 21
args_training_dataset['kernel_list'] = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
args_training_dataset['kernel_prob'] = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
args_training_dataset['sinc_prob'] = 0.1
args_training_dataset['blur_sigma'] = [0.2, 3]
args_training_dataset['betag_range'] = [0.5, 4]
args_training_dataset['betap_range'] = [1, 2]
# second blur kernel
args_training_dataset['blur_kernel_size2'] = 11
args_training_dataset['kernel_list2'] = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
args_training_dataset['kernel_prob2'] = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
args_training_dataset['sinc_prob2'] = 0.1
args_training_dataset['blur_sigma2'] = [0.2, 1.5]
args_training_dataset['betag_range2'] = [0.5, 4.0]
args_training_dataset['betap_range2'] = [1, 2]
# final sinc filter
args_training_dataset['final_sinc_prob'] = 0.8
args_training_dataset['use_hflip'] = True
args_training_dataset['use_rot'] = False
train_dataset = SimpleDataset(args_training_dataset, fix_size=512)
batch_size = args.batch_size
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=False,
    batch_size=batch_size,
    num_workers=11,
    drop_last=True,
)
#################### REALESRGAN SETTING ###########################
args_degradation = {}
# the first degradation process
args_degradation['resize_prob'] = [0.2, 0.7, 0.1]  # up, down, keep
args_degradation['resize_range'] = [0.15, 1.5]
args_degradation['gaussian_noise_prob'] = 0.5
args_degradation['noise_range'] = [1, 30]
args_degradation['poisson_scale_range'] = [0.05, 3.0]
args_degradation['gray_noise_prob'] = 0.4
args_degradation['jpeg_range'] = [30, 95]

# the second degradation process
args_degradation['second_blur_prob'] = 0.8
args_degradation['resize_prob2'] = [0.3, 0.4, 0.3]  # up, down, keep
args_degradation['resize_range2'] = [0.3, 1.2]
args_degradation['gaussian_noise_prob2'] = 0.5
args_degradation['noise_range2'] = [1, 25]
args_degradation['poisson_scale_range2'] = [0.05, 2.5]
args_degradation['gray_noise_prob2'] = 0.4
args_degradation['jpeg_range2'] = [30, 95]

args_degradation['gt_size'] = 512
args_degradation['no_degradation_prob'] = 0.01

from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop, triplet_random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
import random
def realesrgan_degradation(batch, args_degradation, use_usm=True, sf=4, resize_lq=True):
    """Apply the two-stage Real-ESRGAN degradation pipeline to a batch and
    return the paired (lq, gt) tensors, both clamped to [0, 1]."""
    jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
    usm_sharpener = USMSharp().cuda()  # do usm sharpening

    im_gt = batch['gt'].cuda()
    if use_usm:
        im_gt = usm_sharpener(im_gt)
    im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
    kernel1 = batch['kernel1'].cuda()
    kernel2 = batch['kernel2'].cuda()
    sinc_kernel = batch['sinc_kernel'].cuda()

    ori_h, ori_w = im_gt.size()[2:4]

    # ----------------------- The first degradation process ----------------------- #
    # blur
    out = filter2D(im_gt, kernel1)
    # random resize
    updown_type = random.choices(
        ['up', 'down', 'keep'],
        args_degradation['resize_prob'],
    )[0]
    if updown_type == 'up':
        scale = random.uniform(1, args_degradation['resize_range'][1])
    elif updown_type == 'down':
        scale = random.uniform(args_degradation['resize_range'][0], 1)
    else:
        scale = 1
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(out, scale_factor=scale, mode=mode)
    # add noise
    gray_noise_prob = args_degradation['gray_noise_prob']
    if random.random() < args_degradation['gaussian_noise_prob']:
        out = random_add_gaussian_noise_pt(
            out,
            sigma_range=args_degradation['noise_range'],
            clip=True,
            rounds=False,
            gray_prob=gray_noise_prob,
        )
    else:
        out = random_add_poisson_noise_pt(
            out,
            scale_range=args_degradation['poisson_scale_range'],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False)
    # JPEG compression
    jpeg_p = out.new_zeros(out.size(0)).uniform_(*args_degradation['jpeg_range'])
    out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
    out = jpeger(out, quality=jpeg_p)

    # ----------------------- The second degradation process ----------------------- #
    # blur
    if random.random() < args_degradation['second_blur_prob']:
        out = filter2D(out, kernel2)
    # random resize
    updown_type = random.choices(
        ['up', 'down', 'keep'],
        args_degradation['resize_prob2'],
    )[0]
    if updown_type == 'up':
        scale = random.uniform(1, args_degradation['resize_range2'][1])
    elif updown_type == 'down':
        scale = random.uniform(args_degradation['resize_range2'][0], 1)
    else:
        scale = 1
    mode = random.choice(['area', 'bilinear', 'bicubic'])
    out = F.interpolate(
        out,
        size=(int(ori_h / sf * scale),
              int(ori_w / sf * scale)),
        mode=mode,
    )
    # add noise
    gray_noise_prob = args_degradation['gray_noise_prob2']
    if random.random() < args_degradation['gaussian_noise_prob2']:
        out = random_add_gaussian_noise_pt(
            out,
            sigma_range=args_degradation['noise_range2'],
            clip=True,
            rounds=False,
            gray_prob=gray_noise_prob,
        )
    else:
        out = random_add_poisson_noise_pt(
            out,
            scale_range=args_degradation['poisson_scale_range2'],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False,
        )

    # JPEG compression + the final sinc filter
    # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
    # as one operation.
    # We consider two orders:
    #   1. [resize back + sinc filter] + JPEG compression
    #   2. JPEG compression + [resize back + sinc filter]
    # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
    if random.random() < 0.5:
        # resize back + the final sinc filter
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out,
            size=(ori_h // sf,
                  ori_w // sf),
            mode=mode,
        )
        out = filter2D(out, sinc_kernel)
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*args_degradation['jpeg_range2'])
        out = torch.clamp(out, 0, 1)
        out = jpeger(out, quality=jpeg_p)
    else:
        # JPEG compression
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*args_degradation['jpeg_range2'])
        out = torch.clamp(out, 0, 1)
        out = jpeger(out, quality=jpeg_p)
        # resize back + the final sinc filter
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out,
            size=(ori_h // sf,
                  ori_w // sf),
            mode=mode,
        )
        out = filter2D(out, sinc_kernel)

    # clamp and round
    im_lq = torch.clamp(out, 0, 1.0)

    # random crop
    gt_size = args_degradation['gt_size']
    im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, sf)
    lq, gt = im_lq, im_gt

    gt = torch.clamp(gt, 0, 1)
    lq = torch.clamp(lq, 0, 1)

    return lq, gt
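# A minimal usage sketch (assumes a batch dict like the one yielded by the dataloader above,
# i.e. 'gt' in [0, 1] of shape (B, 3, 512, 512) plus 'kernel1' / 'kernel2' / 'sinc_kernel'):
#   lq, gt = realesrgan_degradation(batch, args_degradation, sf=4)
#   # gt: (B, 3, gt_size, gt_size); lq: (B, 3, gt_size // sf, gt_size // sf)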
root_path = args.save_dir
gt_path = os.path.join(root_path, 'gt')
lr_path = os.path.join(root_path, 'lr')
sr_bicubic_path = os.path.join(root_path, 'sr_bicubic')
os.makedirs(gt_path, exist_ok=True)
os.makedirs(lr_path, exist_ok=True)
os.makedirs(sr_bicubic_path, exist_ok=True)

epochs = args.epoch
# Offset the running image index so that parallel processes (different --start_gpu values)
# do not overwrite each other's files.
step = len(train_dataset) * epochs * args.start_gpu
with torch.no_grad():
    for epoch in range(epochs):
        for num_batch, batch in enumerate(train_dataloader):
            # Synthesize the LR images, then upsample them back to the GT size with bicubic interpolation.
            lr_batch, gt_batch = realesrgan_degradation(batch, args_degradation=args_degradation)
            sr_bicubic_batch = F.interpolate(lr_batch, size=(gt_batch.size(-2), gt_batch.size(-1)), mode='bicubic')

            for i in range(batch_size):
                step += 1
                print('process {} images...'.format(step))

                lr = lr_batch[i, ...]
                gt = gt_batch[i, ...]
                sr_bicubic = sr_bicubic_batch[i, ...]

                lr_save_path = os.path.join(lr_path, '{}.png'.format(str(step).zfill(7)))
                gt_save_path = os.path.join(gt_path, '{}.png'.format(str(step).zfill(7)))
                sr_bicubic_save_path = os.path.join(sr_bicubic_path, '{}.png'.format(str(step).zfill(7)))

                # Convert CHW RGB tensors in [0, 1] to HWC BGR arrays scaled to [0, 255] for cv2.imwrite.
                cv2.imwrite(lr_save_path, 255 * lr.detach().cpu().squeeze().permute(1, 2, 0).numpy()[..., ::-1])
                cv2.imwrite(gt_save_path, 255 * gt.detach().cpu().squeeze().permute(1, 2, 0).numpy()[..., ::-1])
                cv2.imwrite(sr_bicubic_save_path, 255 * sr_bicubic.detach().cpu().squeeze().permute(1, 2, 0).numpy()[..., ::-1])

            del lr_batch, gt_batch, sr_bicubic_batch
            torch.cuda.empty_cache()
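# Expected output layout under --save_dir (the 7-digit index continues from the --start_gpu offset;
# sizes below assume the defaults gt_size=512 and sf=4 used in this script):
#   <save_dir>/gt/<index>.png          512x512 ground-truth crops
#   <save_dir>/lr/<index>.png          128x128 degraded low-resolution inputs
#   <save_dir>/sr_bicubic/<index>.png  512x512 bicubic upsamplings of the LR inputs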