mikitona committed on
Commit
8d9c2f3
·
verified ·
1 Parent(s): 396fde9

Upload testing_utils.py

Browse files
Files changed (1) hide show
  1. testing_utils.py +211 -0
testing_utils.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+ from glob import glob
7
+
8
+ import cv2
9
+ import math
10
+ import numpy as np
11
+ import os
12
+ import os.path as osp
13
+ import random
14
+ import time
15
+ import torch
16
+ from pathlib import Path
17
+ from torch.utils import data as data
18
+
19
+ from basicsr.utils import DiffJPEG, USMSharp
20
+ from basicsr.utils.img_process_util import filter2D
21
+ #from basicsr.data.transforms import paired_random_crop, triplet_random_crop
22
+ from basicsr.data.transforms import paired_random_crop
23
+ #from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
24
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, bivariate_Gaussian
25
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
26
+ from basicsr.data.transforms import augment
27
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
28
+ from basicsr.utils.registry import DATASET_REGISTRY
29
+
30
+
31
def parse_args_paired_testing(input_args=None):
    """Parse the command-line options for a paired testing session (pix2pix-Turbo).

    Args:
        input_args (list[str] | None): Explicit argument list to parse. When
            ``None``, argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_path", type=str, default=None,)
    parser.add_argument("--base_config", default="./configs/sr_test.yaml", type=str)
    parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")

    # details about the model architecture
    parser.add_argument("--sd_path")
    parser.add_argument("--de_net_path")
    parser.add_argument("--pretrained_path", type=str, default=None,)
    parser.add_argument("--revision", type=str, default=None,)
    parser.add_argument("--variant", type=str, default=None,)
    parser.add_argument("--tokenizer_name", type=str, default=None)
    parser.add_argument("--lora_rank_unet", default=32, type=int)
    parser.add_argument("--lora_rank_vae", default=16, type=int)

    # super-resolution scale and chopped (patch-wise) forward settings
    parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
    parser.add_argument("--chop_size", type=int, default=128, choices=[512, 256, 128], help="Chopping forward.")
    parser.add_argument("--chop_stride", type=int, default=96, help="Chopping stride.")
    parser.add_argument("--padding_offset", type=int, default=32, help="padding offset.")

    # tiling sizes for the VAE and the latent space
    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
    parser.add_argument("--latent_tiled_size", type=int, default=96)
    parser.add_argument("--latent_tiled_overlap", type=int, default=32)

    parser.add_argument("--align_method", type=str, default="wavelet")

    parser.add_argument("--pos_prompt", type=str, default="A high-resolution, 8K, ultra-realistic image with sharp focus, vibrant colors, and natural lighting.")
    parser.add_argument("--neg_prompt", type=str, default="oil painting, cartoon, blur, dirty, messy, low quality, deformation, low resolution, oversmooth")

    # training details
    parser.add_argument("--output_dir", type=str, default='output/')
    parser.add_argument("--cache_dir", default=None,)
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--resolution", type=int, default=512,)
    parser.add_argument("--checkpointing_steps", type=int, default=500,)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
    parser.add_argument("--gradient_checkpointing", action="store_true",)

    parser.add_argument("--dataloader_num_workers", type=int, default=0,)
    parser.add_argument(
        "--allow_tf32", action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    # The help text previously labelled "tensorboard" as the default although
    # the actual default is "wandb"; the text now matches the code.
    parser.add_argument(
        "--report_to", type=str, default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
    parser.add_argument("--set_grads_to_none", action="store_true",)

    # distributed execution
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # parse_args(None) reads sys.argv[1:], so no explicit branch is needed.
    return parser.parse_args(input_args)
107
+
108
+
109
class PlainDataset(data.Dataset):
    """Dataset that serves low-resolution (LR) images from one or more folders.

    Adapted from the Real-ESRGAN training dataset, but reduced to plain image
    loading: no augmentation and no degradation kernels are produced here.

    Args:
        opt (dict): Dataset config. Keys read by this class:
            io_backend (dict): ``type`` plus any further kwargs, forwarded to
                ``FileClient``.
            lr_path (str | list[str], optional): A single folder (all
                ``*.png``, ``*.jpg`` and ``*.jpeg`` files are collected) or a
                list of folders (only ``*.<image_type>`` files are collected,
                folder by folder).
            image_type (str, optional): Extension used for the list form of
                ``lr_path``. Defaults to ``'png'``; the default is written
                back into ``opt``.
    """

    def __init__(self, opt):
        super(PlainDataset, self).__init__()
        self.opt = opt
        self.file_client = None  # created lazily on first __getitem__
        self.io_backend_opt = opt['io_backend']

        if 'image_type' not in opt:
            opt['image_type'] = 'png'

        # support multiple type of data: file path and meta data, remove support of lmdb
        self.lr_paths = []
        if 'lr_path' in opt:
            lr_path = opt['lr_path']
            if isinstance(lr_path, str):
                # Single folder: gather the common image extensions and sort
                # the combined list.
                folder = Path(lr_path)
                self.lr_paths.extend(sorted(
                    str(p)
                    for ext in ('png', 'jpg', 'jpeg')
                    for p in folder.glob('*.' + ext)
                ))
            else:
                # Several folders: each folder is sorted on its own and the
                # folders keep the order given in the config.
                pattern = '*.' + opt['image_type']
                for folder in lr_path:
                    self.lr_paths.extend(sorted(str(p) for p in Path(folder).glob(pattern)))

    def __getitem__(self, index):
        """Load one LR image.

        Returns:
            dict: ``{'lr': tensor, 'lr_path': str}`` where ``lr_path`` is the
            file that was actually read (it can differ from ``index`` when a
            read had to be retried on another file).

        Raises:
            IOError: If no image could be read after three attempts.
        """
        if self.file_client is None:
            # Pop 'type' from a copy: the config dict is shared with the
            # caller (and with other datasets built from the same config), so
            # mutating it would break any later FileClient construction.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        # -------------------------------- Load lr images -------------------------------- #
        lr_path = self.lr_paths[index]

        # avoid errors caused by high latency in reading files
        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                # NOTE(review): 'gt' is the client key passed to the backend;
                # presumably irrelevant for the disk backend - confirm.
                lr_img_bytes = self.file_client.get(lr_path, 'gt')
            except (IOError, OSError) as err:
                if attempt == max_retries:
                    # Fail loudly instead of falling through with
                    # ``lr_img_bytes`` unbound.
                    raise IOError(
                        f'Failed to read an LR image after {max_retries} attempts '
                        f'(last path: {lr_path})'
                    ) from err
                # change another file to read
                index = random.randint(0, len(self) - 1)
                lr_path = self.lr_paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break

        # Decode to a float32 array (per the basicsr helper: HWC, BGR - TODO confirm).
        img_lr = imfrombytes(lr_img_bytes, float32=True)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_lr = img2tensor([img_lr], bgr2rgb=True, float32=True)[0]

        return {'lr': img_lr, 'lr_path': lr_path}

    def __len__(self):
        """Return the number of collected LR image paths."""
        return len(self.lr_paths)
186
+
187
+
188
def lr_proc(config, batch, device):
    """Upscale, normalise and pad a batch of LR images for the model.

    Args:
        config: Object exposing ``sf``, the integer upscaling factor.
        batch (dict): Must contain ``'lr'``, a 4-D image tensor (assumed to be
            N x C x H x W in [0, 1] - TODO confirm against the data loader).
        device: Target device for all returned tensors.

    Returns:
        tuple: ``(im_lr, ori_lr, (ori_h, ori_w))`` where
            ``im_lr`` is the bicubically upscaled input, mapped to [-1, 1] and
            padded on the bottom/right to a multiple of 64;
            ``ori_lr`` is the unmodified float input on ``device``;
            ``(ori_h, ori_w)`` is the upscaled size *before* padding.
    """
    # Move to the requested device instead of a hard-coded ``.cuda()`` so the
    # function also works on CPU-only hosts and on non-default GPUs.
    im_lr = batch['lr'].to(device)
    im_lr = im_lr.to(memory_format=torch.contiguous_format).float()

    ori_lr = im_lr

    im_lr = F.interpolate(
        im_lr,
        size=(im_lr.size(-2) * config.sf,
              im_lr.size(-1) * config.sf),
        mode='bicubic',
    )

    im_lr = im_lr.contiguous()
    # [0, 1] -> [-1, 1]; bicubic interpolation can overshoot, hence the clamp.
    im_lr = im_lr * 2 - 1.0
    im_lr = torch.clamp(im_lr, -1.0, 1.0)

    ori_h, ori_w = im_lr.size(-2), im_lr.size(-1)

    # Pad height/width up to the next multiple of 64.
    # NOTE(review): reflect padding requires the pad size to be smaller than
    # the padded dimension, so very small inputs may raise here - confirm.
    pad_h = (math.ceil(ori_h / 64)) * 64 - ori_h
    pad_w = (math.ceil(ori_w / 64)) * 64 - ori_w
    im_lr = F.pad(im_lr, pad=(0, pad_w, 0, pad_h), mode='reflect')

    return im_lr, ori_lr, (ori_h, ori_w)