Spaces:
Sleeping
Sleeping
File size: 7,346 Bytes
6342ac4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import json
import os
from glob import glob
from utils import load_img, modcrop
# Integer class id assigned to each degradation keyword.
DEG_MAP = dict(
    noise=0,
    blur=1,
    rain=2,
    haze=3,
    lol=4,
    sr=5,
    en=6,
)

# Task name associated with each degradation keyword.
DEG2TASK = dict(
    noise="denoising",
    blur="deblurring",
    rain="deraining",
    haze="dehazing",
    lol="lol",
    sr="sr",
    en="enhancement",
)
def augment_prompt(prompt):
    """
    Build a randomized natural-language instruction from a degradation prompt.

    By default the prompt is wrapped as "<verb phrase> <prompt> <ending>", with
    the verb phrase and ending drawn through ``np.random.choice``. If the prompt
    contains "lol", "sr" or "en" (checked in that order, as substrings), the
    wrapped text is discarded and a canned instruction for that task is drawn
    instead.

    Args:
        prompt: text describing the degradation (a string).

    Returns:
        The instruction string, with newlines removed, runs of spaces collapsed
        to a single space, and no leading/trailing whitespace.
    """
    ### special prompts
    lol_prompts = ["fix the illumination", "increase the exposure of the photo", "the image is too dark to see anything, correct the photo", "poor illumination, improve the shot", "brighten dark regions", "make it HDR", "improve the light of the image", "Can you make the image brighter?"]
    sr_prompts = ["I need to enhance the size and quality of this image.", "My photo is lacking size and clarity; can you improve it?", "I'd appreciate it if you could upscale this photo.", "My picture is too little, enlarge it.", "upsample this image", "increase the resolution of this photo", "increase the number of pixels", "upsample this photo", "Add details to this image", "improve the quality of this photo"]
    en_prompts = ["make my image look like DSLR", "improve the colors of my image", "improve the contrast of this photo", "apply tonemapping", "enhance the colors of the image", "retouch the photo like a photograper"]

    # Both draws happen unconditionally (even when a special prompt replaces the
    # result below) so the sequence of random draws stays the same as before.
    init = np.random.choice(["Remove the", "Reduce the", "Clean the", "Fix the", "Remove", "Improve the", "Correct the",])
    end = np.random.choice(["please", "fast", "now", "in the photo", "in the picture", "in the image", ""])
    newp = f"{init} {prompt} {end}"

    # NOTE(review): these are substring tests, so e.g. "en" also matches any
    # prompt that merely contains the letters "en" - confirm against callers.
    if "lol" in prompt:
        newp = np.random.choice(lol_prompts)
    elif "sr" in prompt:
        newp = np.random.choice(sr_prompts)
    elif "en" in prompt:
        newp = np.random.choice(en_prompts)

    # Drop newlines first, then collapse every run of spaces; a single
    # replace("  ", " ") pass would leave runs of three or more only shortened.
    newp = newp.replace("\n", "")
    while "  " in newp:
        newp = newp.replace("  ", " ")
    return newp.strip()
def get_deg_name(path):
    """
    Map a dataset path to its degradation keyword.

    The path is scanned for known dataset/degradation substrings; the first
    group (in the order listed below) with a match wins. Paths that match
    nothing are treated as "noise".
    """
    # Order matters: earlier rows take precedence over later ones.
    keyword_table = (
        ("blur", ("gopro", "GoPro", "blur", "Blur", "RealBlur")),
        ("haze", ("SOTS", "haze", "sots", "RESIDE")),
        ("lol", ("LOL",)),
        ("en", ("fiveK",)),
        ("sr", ("super", "classicalSR")),
        ("rain", ("Rain100", "rain13k", "Rain13k")),
    )
    for degradation, keywords in keyword_table:
        if any(keyword in path for keyword in keywords):
            return degradation
    return "noise"
def crop_img(image, base=16):
    """
    Mod crop the image to ensure the dimension is divisible by base. Also done by SwinIR, Restormer and others.

    The crop is centered: the surplus rows/columns are removed from both
    sides, with the extra one (for an odd surplus) removed from the bottom/right.

    Args:
        image: array whose first two axes are height and width. Only those two
            axes are sliced, so both (H, W) and (H, W, C) inputs work.
        base: the height and width of the result are multiples of this value.

    Returns:
        The cropped slice of ``image`` (trailing axes are left untouched).
    """
    height, width = image.shape[0], image.shape[1]
    surplus_h = height % base
    surplus_w = width % base
    top = surplus_h // 2
    left = surplus_w // 2
    # Slice only the two spatial axes so arrays without a channel axis are accepted too.
    return image[top:top + height - surplus_h, left:left + width - surplus_w]
################# DATASETS
class RefDegImage(Dataset):
    """
    Paired dataset for image restoration: each item is a low-quality image
    together with its high-quality reference.
    Tasks: synthetic denoising, deblurring, super-res, etc.
    """

    def __init__(self, hq_img_paths, lq_img_paths, augmentations=None, val=False, name="test", deg_name="noise", deg_class=0):
        """
        Store the paired path lists and the dataset metadata.

        hq_img_paths / lq_img_paths must have the same length; entry i of one
        is paired with entry i of the other. ``name``, ``deg_name`` and
        ``deg_class`` are kept as attributes (``name``, ``degradation``,
        ``deg_class``).
        """
        assert len(hq_img_paths) == len(lq_img_paths)

        self.hq_paths = hq_img_paths
        self.lq_paths = lq_img_paths
        self.totensor = torchvision.transforms.ToTensor()
        self.val = val
        self.name = name
        self.degradation = deg_name
        self.deg_class = deg_class
        # No augmentations during validation/test
        self.augs = None if self.val else augmentations

    def __len__(self):
        """Number of image pairs."""
        return len(self.hq_paths)

    def __getitem__(self, idx):
        """
        Load pair ``idx`` and return (hq tensor, lq tensor, hq path).
        """
        hq_path, lq_path = self.hq_paths[idx], self.lq_paths[idx]

        hq_image = load_img(hq_path)
        lq_image = load_img(lq_path)

        if self.val:
            # In validation mode both images are center-cropped by crop_img so
            # that height and width are multiples of its default base.
            hq_image, lq_image = crop_img(hq_image), crop_img(lq_image)

        # NOTE(review): presumably load_img returns numpy arrays (they are cast
        # with .astype) - value range is not visible here, confirm in utils.
        hq_tensor = self.totensor(hq_image.astype(np.float32))
        lq_tensor = self.totensor(lq_image.astype(np.float32))
        return hq_tensor, lq_tensor, hq_path
def _testset_name(path_hq, path_lq):
    """Derive the display name of a test set from its reference/degraded directory paths."""
    if ("denoising" in path_hq) or ("jpeg" in path_hq):
        # Name is "<last component of path_hq>_<level>", where <level> is the
        # last "_"-separated token of path_lq's last component (up to the first ".").
        level = path_lq.split("/")[-1].split("_")[-1].split(".")[0]
        return path_hq.split("/")[-1] + f"_{level}"
    if "Rain" in path_hq:
        if "Rain100L" in path_hq:
            return "Rain100L"
        # NOTE(review): relies on the dataset name being the 4th "/"-separated
        # component of path_hq; raises IndexError for shorter paths.
        return path_hq.split("/")[3]
    if ("gopro" in path_hq) or ("GoPro" in path_hq):
        return "GoPro"
    if "LOL" in path_hq:
        return "LOL"
    if "SOTS" in path_hq:
        return "SOTS"
    if "fiveK" in path_hq:
        return "MIT5K"
    # Explicit raise (instead of `assert False`) so the check survives `python -O`.
    raise AssertionError(f"{path_hq} - unknown dataset")


def create_testsets(testsets, debug=False):
    """
    Given a list of testsets create pytorch datasets for each.
    The method requires the paths to references and noisy images.

    Args:
        testsets: non-empty sequence of entries where entry[0] is the directory
            of reference (high-quality) images and entry[1] the directory of
            degraded (low-quality) images.
        debug: when true, print a banner, the number of test sets and each
            pair of paths.

    Returns:
        A list with one RefDegImage (validation mode) per entry, in input order.

    Raises:
        AssertionError: if ``testsets`` is empty, a path matches no known
            dataset, or a dataset does not have the expected number of images.
    """
    assert len(testsets) > 0

    if debug:
        print(20*'****')
        print("Creating Testsets", len(testsets))

    datasets = []
    for testdt in testsets:
        path_hq, path_lq = testdt[0], testdt[1]
        if debug:
            print(path_hq, path_lq)

        dataset_name = _testset_name(path_hq, path_lq)

        # Default pairing: sorted directory listings matched by position.
        hq_img_paths = sorted(glob(os.path.join(path_hq, "*")))
        lq_img_paths = sorted(glob(os.path.join(path_lq, "*")))

        if "SOTS" in path_hq:
            # Haze removal SOTS test dataset
            dataset_name = "SOTS"
            hq_img_paths = sorted(glob(os.path.join(path_hq, "*.jpg")))
            assert len(hq_img_paths) == 500
            # Degraded files are located by substituting "GT" with "IN" in each reference path.
            lq_img_paths = [file.replace("GT", "IN") for file in hq_img_paths]

        if "fiveK" in path_hq:
            dataset_name = "MIT5K"
            # The test split is a list of image ids, one per line.
            testf = "test-data/mit5k/test.txt"
            with open(testf, "r") as f:
                test_ids = [line.strip() for line in f]
            hq_img_paths = [os.path.join(path_hq, f"{x}.jpg") for x in test_ids]
            lq_img_paths = [x.replace("expertC", "input") for x in hq_img_paths]
            assert len(hq_img_paths) == 498

        # NOTE(review): only the lowercase "gopro" spelling triggers this count
        # check, although "GoPro" is also recognised when naming the dataset.
        if "gopro" in path_hq:
            assert len(hq_img_paths) == 1111

        if "LOL" in path_hq:
            assert len(hq_img_paths) == 15

        assert len(hq_img_paths) == len(lq_img_paths)

        deg_name = get_deg_name(path_hq)
        deg_class = DEG_MAP[deg_name]

        valdts = RefDegImage(hq_img_paths=hq_img_paths,
                             lq_img_paths=lq_img_paths,
                             val=True, name=dataset_name, deg_name=deg_name, deg_class=deg_class)
        datasets.append(valdts)

    assert len(datasets) == len(testsets)
    print(20*'****')
    return datasets