Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,259 Bytes
b3cdf05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import cv2
import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random
import numpy as np
import math
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from PIL import Image
@DATASET_REGISTRY.register()
class DAPEDataset(Dataset):
    """Paired GT / bicubic-LR image dataset.

    For every directory in ``opt['root']`` the dataset collects ground-truth
    images from ``<root>/gt`` and their low-quality counterparts from
    ``<root>/sr_bicubic``, using the glob patterns in ``opt['ext']``.
    GT and LR files are paired by position after sorting, so this assumes
    both sub-directories contain identically named files (TODO confirm for
    every dataset root in use).

    Each item is a dict with:
        'gt', 'lq'         -- RGB images converted with ``ToTensor`` and resized
                              to 512x512 (shape (3, 512, 512)).
        'gt_ram', 'lq_ram' -- the same tensors resized to
                              (image_size, image_size) and normalized with
                              ImageNet mean/std.
        'lq_path'          -- filesystem path of the LR image.

    Args:
        opt: dict-like config with keys ``'root'`` (a directory or a list of
            directories) and ``'ext'`` (a glob pattern or a list of patterns).
        image_size: side length of the ``*_ram`` outputs. Defaults to 384.

    Raises:
        ValueError: if the number of GT and LR files found differs.
    """

    def __init__(self, opt, image_size=384):
        self.opt = opt
        self.root = opt['root']
        self.image_size = image_size

        # Accept a single string as well as a list; iterating a bare string
        # would otherwise walk it character by character.
        roots = [self.root] if isinstance(self.root, str) else self.root
        exts = opt['ext']
        if isinstance(exts, str):
            exts = [exts]

        gt_lists = []
        lr_lists = []
        for root_dir in roots:
            gt_path = os.path.join(root_dir, 'gt')
            lr_path = os.path.join(root_dir, 'sr_bicubic')
            print(f'gt_path: {gt_path}')
            for ext in exts:
                # glob returns files in arbitrary (filesystem) order; sorting
                # both sides keeps index i of the GT list aligned with index i
                # of the LR list, which __getitem__ relies on.
                gt_lists += sorted(glob.glob(os.path.join(gt_path, ext)))
                lr_lists += sorted(glob.glob(os.path.join(lr_path, ext)))
        self.lr_lists = lr_lists
        self.gt_lists = gt_lists

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.gt_lists) != len(self.lr_lists):
            raise ValueError(
                f'GT/LR file count mismatch: {len(self.gt_lists)} GT files vs '
                f'{len(self.lr_lists)} LR files'
            )
        print(f'=========================Dataset Length {len(self.gt_lists)}=========================')

        # Full-resolution branch: PIL image -> tensor, then resize to 512x512.
        self.img_preproc = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((512, 512)),
        ])
        # Secondary branch applied on top of the 512x512 tensor: resize to
        # image_size and normalize with ImageNet statistics.
        self.ram_preproc = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, index):
        """Load the index-th GT/LR pair and return the preprocessed tensors."""
        gt_file = self.gt_lists[index]
        lr_file = self.lr_lists[index]
        # Context managers release the file handles promptly; convert()
        # returns a new, fully loaded image that outlives the `with` block.
        with Image.open(gt_file) as img:
            gt_image = img.convert('RGB')
        with Image.open(lr_file) as img:
            lr_image = img.convert('RGB')

        lr_image, gt_image = self.img_preproc(lr_image), self.img_preproc(gt_image)
        lr_image_ram, gt_image_ram = self.ram_preproc(lr_image), self.ram_preproc(gt_image)
        return {
            'gt': gt_image,
            'lq': lr_image,
            'gt_ram': gt_image_ram,
            'lq_ram': lr_image_ram,
            'lq_path': lr_file,
        }

    def __len__(self):
        """Return the number of GT/LR pairs."""
        return len(self.gt_lists)
|