Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,761 Bytes
4c954ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
from pathlib import Path
import cv2
import PIL
import numpy as np
import torch
import torch.utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import glob
from .transforms.homographic_transforms import sample_homography
from kornia.geometry import warp_perspective,transform_points
# Keyword arguments forwarded verbatim to sample_homography() whenever a random
# warp is generated for an image. The meaning of each key is defined by
# sample_homography (not visible in this file).
# NOTE(review): max_angle looks like radians (~pi/2) — confirm against
# sample_homography before changing it.
homography_params = {
    'translation': True,
    'rotation': True,
    'scaling': True,
    'perspective': True,
    'scaling_amplitude': 0.2,
    'perspective_amplitude_x': 0.2,
    'perspective_amplitude_y': 0.2,
    'patch_ratio': 0.85,
    'max_angle': 1.57,
    'allow_artifacts': True,
}
class Hybrid_Dataset(torch.utils.data.Dataset):
    """Dataset of (reference image, warped image, homography) triplets.

    On construction, every ``*.png`` / ``*.jpg`` directly inside
    ``images_root`` is paired with a ``.npz`` file of the same stem in the
    same folder. A missing ``.npz`` (or every one, when ``overwrite`` is true)
    is generated by warping the image with a randomly sampled homography and
    saving the result, so later runs reuse the exact same warp.

    Args:
        datacfg: Optional configuration object; stored as ``self.conf`` and
            not otherwise used by this class.
        images_root: Directory that holds the input images (and receives the
            generated ``.npz`` files).
        overwrite: When true, regenerate every ``.npz`` even if it exists.

    Raises:
        ValueError: If no image is found under ``images_root`` or an image
            file cannot be decoded.
    """

    def __init__(self, datacfg=None, images_root=None, overwrite=False):
        self.conf = datacfg
        self.root = images_root
        # NOTE: no RNG seed is set here; reproducibility across runs comes
        # only from reusing the cached .npz files.
        self.files = sorted(
            glob.glob(f'{images_root}/*.png') + glob.glob(f'{images_root}/*.jpg')
        )
        self.size = (512, 512)
        self.overwrite = overwrite
        if not self.files:
            raise ValueError(f'Could not find any images in the path of {self.root}. Please check the input images root path.')

        # Keep npz_files index-aligned with self.files (one entry per image,
        # same order) so that item ``idx`` always belongs to image ``idx``.
        # NOTE: "a.png" and "a.jpg" share the same "a.npz" cache entry.
        self.npz_files = []
        for file in tqdm(self.files):
            npz_file = Path(file).with_suffix('.npz')
            if self.overwrite or not npz_file.exists():
                self._cache_pair(file, npz_file)
            self.npz_files.append(npz_file)

    def _cache_pair(self, file, npz_file):
        """Warp one image with a random homography and save the pair to ``npz_file``."""
        image = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f'Could not read image {file}.')
        image = cv2.resize(image, self.size)
        image = np.array(image, dtype=np.float32) / 255.0
        # sample_homography returns an indexable result; its first element is
        # used as the warp matrix (presumably 3x3 — TODO confirm).
        H = sample_homography(self.size, **homography_params)[0]
        warped_image = cv2.warpPerspective(image, H, self.size)
        warped_image = np.array(warped_image, dtype=np.float32)
        np.savez(npz_file, ref_image=image, target_image=warped_image, homo_mat=H)

    def get_dataset(self):
        """Return the list of ``.npz`` paths, in the same order as the images."""
        return self.npz_files

    def get_images(self):
        """Return the sorted list of input image paths."""
        return self.files

    def len_dataset(self):
        """Return the number of images in the dataset."""
        return len(self.files)

    def __len__(self):
        """Return the number of items, so ``len(ds)`` and DataLoader samplers work."""
        return self.len_dataset()

    def __getitem__(self, idx):
        """Load item ``idx`` as a dict of arrays.

        Returns:
            A dict with every array stored in the ``.npz`` file; files written
            by this class hold ``ref_image``, ``target_image`` and ``homo_mat``.
        """
        # Read everything eagerly inside the context manager so the underlying
        # zip archive is closed before returning.
        with np.load(self.npz_files[idx]) as data:
            return {key: data[key] for key in data.files}
|