File size: 2,761 Bytes
4c954ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from pathlib import Path
import cv2
import PIL
import numpy as np
import torch
import torch.utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import glob

from .transforms.homographic_transforms import sample_homography
from kornia.geometry import warp_perspective,transform_points


# Keyword arguments forwarded verbatim to `sample_homography` when each
# (reference, warped) pair is generated.
homography_params = dict(
    translation=True,
    rotation=True,
    scaling=True,
    perspective=True,
    scaling_amplitude=0.2,
    perspective_amplitude_x=0.2,
    perspective_amplitude_y=0.2,
    patch_ratio=0.85,
    max_angle=1.57,  # ~= pi/2; presumably radians -- confirm against sample_homography
    allow_artifacts=True,
)

class Hybrid_Dataset(torch.utils.data.Dataset):
    """Dataset of (reference, warped) grayscale image pairs and their homographies.

    On construction, every ``*.png`` / ``*.jpg`` directly inside ``images_root``
    is loaded as grayscale, resized to ``self.size``, scaled to [0, 1] float32,
    and warped by a homography drawn from ``sample_homography``. The result is
    cached next to the image as an ``.npz`` file (same stem) with the keys
    ``ref_image``, ``target_image`` and ``homo_mat``. Existing caches are reused
    unless ``overwrite`` is True.

    No random seed is set here, so freshly generated caches differ between
    runs; reproducibility comes from reusing the cached ``.npz`` files.
    """

    def __init__(self, datacfg=None, images_root=None, overwrite=False):
        """Collect images under ``images_root`` and build/reuse their caches.

        Args:
            datacfg: Dataset configuration. Stored as ``self.conf`` but not
                otherwise read by this class.
            images_root: Directory searched (non-recursively) for images.
            overwrite: If True, regenerate every ``.npz`` cache even when one
                already exists.

        Raises:
            ValueError: If no ``.png``/``.jpg`` images are found, or if an
                image cannot be decoded.
        """
        self.conf = datacfg
        self.root = images_root
        self.size = (512, 512)
        self.overwrite = overwrite

        self.files = sorted(
            glob.glob(f'{images_root}/*.png') + glob.glob(f'{images_root}/*.jpg')
        )
        if not self.files:
            raise ValueError(f'Could not find any images in the path of {self.root}. Please check the input images root path.')

        # Exactly one cache path per image, in the same sorted order as
        # self.files, so npz_files[i] always belongs to files[i]. Globbing the
        # directory for *.npz instead would pick up unrelated files and lose
        # the ordering.
        # NOTE(review): "x.png" and "x.jpg" in the same folder share "x.npz";
        # the second image then reuses the first one's cache.
        self.npz_files = []
        for file in tqdm(self.files):
            npz_file = str(Path(file).with_suffix('.npz'))
            if self.overwrite or not Path(npz_file).exists():
                self._write_pair(file, npz_file)
            self.npz_files.append(npz_file)

    def _write_pair(self, file, npz_file):
        """Build one (reference, warped) pair from ``file`` and save it to ``npz_file``."""
        image = cv2.imread(file, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise ValueError(f'Could not read image {file}.')
        image = cv2.resize(image, self.size)
        image = np.asarray(image, dtype=np.float32) / 255.0

        # self.size is square, so its (w, h) vs (h, w) interpretation by
        # cv2 and sample_homography coincides -- revisit if it becomes non-square.
        H = sample_homography(self.size, **homography_params)[0]
        warped_image = cv2.warpPerspective(image, H, self.size)
        warped_image = np.asarray(warped_image, dtype=np.float32)

        np.savez(npz_file, ref_image=image, target_image=warped_image, homo_mat=H)

    def get_dataset(self):
        """Return the cached ``.npz`` paths, index-aligned with ``get_images()``."""
        return self.npz_files

    def get_images(self):
        """Return the sorted list of source image paths."""
        return self.files

    def len_dataset(self):
        """Return the number of source images."""
        return len(self.files)

    def __len__(self):
        """Return the number of samples (required by ``DataLoader`` samplers)."""
        return len(self.npz_files)

    def __getitem__(self, idx):
        """Return the cached pair at ``idx`` as a dict of NumPy arrays.

        The dict holds every array stored in the ``.npz`` file; caches written
        by ``_write_pair`` contain ``ref_image``, ``target_image`` and
        ``homo_mat``.
        """
        # Read everything eagerly inside `with` so the .npz handle is closed
        # before returning instead of leaking one open file per sample.
        with np.load(self.npz_files[idx]) as data:
            return {key: data[key] for key in data.files}