from pathlib import Path
import cv2
import PIL
import numpy as np
import torch
import torch.utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import glob
from .transforms.homographic_transforms import sample_homography
from kornia.geometry import warp_perspective, transform_points

# Keyword arguments forwarded to ``sample_homography`` when synthesising the
# warped (target) view of each image.
homography_params = {
    'translation': True,
    'rotation': True,
    'scaling': True,
    'perspective': True,
    'scaling_amplitude': 0.2,
    'perspective_amplitude_x': 0.2,
    'perspective_amplitude_y': 0.2,
    'patch_ratio': 0.85,
    'max_angle': 1.57,
    'allow_artifacts': True,
}


class Hybrid_Dataset(torch.utils.data.Dataset):
    """Image pairs related by a randomly sampled homography, cached as ``.npz``.

    Preparation of the cache:

    * Every ``*.png`` / ``*.jpg`` directly under ``images_root`` is read as
      grayscale, resized to 512x512 and scaled to ``[0, 1]``.
    * A homography is drawn with ``sample_homography(self.size,
      **homography_params)`` and used to warp the image.
    * The reference image, warped image and homography are written next to
      the source image as ``<stem>.npz`` (keys ``ref_image``,
      ``target_image``, ``homo_mat``).
    * Existing ``.npz`` files are reused unless ``overwrite`` is set, so each
      image keeps the same homography across epochs and runs.

    Sample ``i`` always corresponds to ``self.files[i]``.

    Args:
        datacfg: Data configuration; stored as ``self.conf`` but not otherwise
            used by this class.
        images_root: Directory containing the source images.
        overwrite: If True, regenerate every ``.npz`` even if it already exists.

    Raises:
        ValueError: If no ``.png``/``.jpg`` images are found in ``images_root``.
        OSError: If OpenCV cannot decode one of the images.
    """

    def __init__(self, datacfg=None, images_root=None, overwrite=False):
        self.conf = datacfg
        self.root = images_root
        self.size = (512, 512)
        self.overwrite = overwrite

        self.files = sorted(
            glob.glob(f'{images_root}/*.png') + glob.glob(f'{images_root}/*.jpg')
        )
        if len(self.files) == 0:
            raise ValueError(
                f'Could not find any images in the path of {self.root}. '
                'Please check the input images root path.'
            )

        # Exactly one cached sample per image, in the same sorted order as
        # ``self.files``. Stray .npz files in the folder are ignored, so
        # ``npz_files[i]`` always belongs to ``files[i]``.
        # NOTE(review): ``a.png`` and ``a.jpg`` in the same folder would share
        # ``a.npz``.
        self.npz_files = []
        for file in tqdm(self.files):
            npz_file = Path(file).with_suffix('.npz')
            if self.overwrite or not npz_file.exists():
                self._generate_pair(file, npz_file)
            self.npz_files.append(str(npz_file))

    def _generate_pair(self, file, npz_file):
        """Build the reference/warped pair for ``file`` and save it to ``npz_file``."""
        image = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # rather than raising.
            raise OSError(f'Could not read image {file}.')
        image = cv2.resize(image, self.size)
        image = np.asarray(image, dtype=np.float32) / 255.0

        H = sample_homography(self.size, **homography_params)[0]
        warped_image = np.asarray(
            cv2.warpPerspective(image, H, self.size), dtype=np.float32
        )
        np.savez(npz_file, ref_image=image, target_image=warped_image, homo_mat=H)

    def get_dataset(self):
        """Return the cached ``.npz`` paths, aligned with ``get_images()``."""
        return self.npz_files

    def get_images(self):
        """Return the sorted source image paths."""
        return self.files

    def len_dataset(self):
        """Return the number of source images."""
        return len(self.files)

    def __len__(self):
        """Return the number of samples (one per source image)."""
        return len(self.npz_files)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict mapping array names to arrays.

        Files generated by this class contain the keys ``ref_image``,
        ``target_image`` and ``homo_mat``.
        """
        npz_file = self.npz_files[idx]
        # np.load on an .npz keeps the archive open and reads lazily.
        # Materialise the arrays and close it so file handles are not leaked.
        with np.load(npz_file) as data:
            return {key: data[key] for key in data.files}