File size: 4,346 Bytes
4c954ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
""" YorkUrban dataset for VP estimation evaluation. """

import os
import csv
import numpy as np
import torch
import cv2
import scipy.io
from skimage.io import imread
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

from ..config.project_config import Config as cfg


def unproject_vp_to_world(vp, K):
    """ Convert vanishing points from homogeneous image coordinates to
        unit-norm world directions.

    Args:
        vp: homogeneous image points, either an (N, 3) array-like or a
            single (3,) point. An empty input yields an empty (0, 3) result.
        K: (3, 3) camera intrinsics matrix.

    Returns:
        (N, 3) float array whose rows are K^-1 @ vp_i with the second (y)
        component negated, each normalized to unit length.
    """
    vp = np.asarray(vp, dtype=float)
    if vp.size == 0:
        # e.g. np.array([]) built from an annotation file with no VP; the
        # matrix product below would fail on its (0,) shape.
        return np.zeros((0, 3))
    # Accept a single (3,) VP as a one-row stack.
    vp = np.atleast_2d(vp)
    # Row-wise K^-1 @ vp_i, written as one right-multiplication.
    proj_vp = vp @ np.linalg.inv(K).T
    # Negate y (presumably to turn the downward image y axis into an upward
    # one -- confirm against the evaluation's convention).
    proj_vp[:, 1] *= -1
    proj_vp /= np.linalg.norm(proj_vp, axis=1, keepdims=True)
    return proj_vp

class NYU(torch.utils.data.Dataset):
    """ NYU images with ground-truth vanishing points (VPs) and labelled
        lines, for VP estimation evaluation.

    Args:
        mode: 'val' keeps the last 49 images, 'test' keeps all the others,
            and any other value keeps all 1449 images.
        config: optional mapping (anything with a `.get` method) read by
            `get_data_loader` for its '<split>_batch_size' and
            'num_workers' entries.
    """

    def __init__(self, mode='test', config=None):
        # get_data_loader reads its settings from here.
        self.conf = config if config is not None else {}

        # Build the per-image file paths.
        # NOTE(review): images use a 1-based index (nyu_rgb_0001.png) while
        # the VP / line annotations use a 0-based one (vps_0000.csv);
        # presumably index i of each list refers to the same image -- confirm
        # against the dataset layout.
        num_imgs = 1449
        self.root_dir = cfg.nyu_dataroot
        self.img_paths = [
            os.path.join(self.root_dir, 'images',
                         'nyu_rgb_' + str(i + 1).zfill(4) + '.png')
            for i in range(num_imgs)]
        self.vps_paths = [
            os.path.join(self.root_dir, 'vps',
                         'vps_' + str(i).zfill(4) + '.csv')
            for i in range(num_imgs)]
        self.lines_paths = [
            os.path.join(self.root_dir, 'labelled_lines',
                         'labelled_lines_' + str(i).zfill(4) + '.csv')
            for i in range(num_imgs)]
        self.img_names = [str(i).zfill(4) for i in range(num_imgs)]

        # Separate validation (the last `val_size` images) and test (the
        # rest). `val_size` must be positive for the slices to select those
        # subsets; a negative value swaps them.
        val_size = 49
        subset = self._split_slice(mode, val_size)
        self.img_paths = self.img_paths[subset]
        self.vps_paths = self.vps_paths[subset]
        self.lines_paths = self.lines_paths[subset]
        self.img_names = self.img_names[subset]

        # Hard-coded RGB camera intrinsics, shared by every image.
        fx_rgb = 5.1885790117450188e+02
        fy_rgb = 5.1946961112127485e+02
        cx_rgb = 3.2558244941119034e+02
        cy_rgb = 2.5373616633400465e+02
        self.K = torch.tensor([[fx_rgb, 0, cx_rgb],
                               [0, fy_rgb, cy_rgb],
                               [0, 0, 1]])

    @staticmethod
    def _split_slice(mode, val_size):
        """ Return the slice selecting the images of `mode`: the last
            `val_size` for 'val', all the others for 'test', everything
            otherwise. """
        if mode == 'val':
            return slice(-val_size, None)
        if mode == 'test':
            return slice(None, -val_size)
        return slice(None)

    @staticmethod
    def _read_csv_rows(path):
        """ Return the rows of the space-delimited CSV file at `path`,
            skipping its header row. """
        with open(path, newline='') as csv_file:
            rows = list(csv.reader(csv_file, delimiter=' '))
        return rows[1:]

    def get_dataset(self, split):
        """ Return this dataset itself; `split` is ignored because the
            subset of images is fixed by `mode` at construction time. """
        return self

    def __getitem__(self, idx):
        """ Load image `idx` together with its annotations.

        Returns:
            dict with
              'image': the image as returned by cv2.imread (BGR channel
                  order, not normalized),
              'image_path': its path, 'name': its file stem,
              'gt_lines': float tensor, one [row[1], row[2], 1] entry per
                  row of the labelled-lines file,
              'vps': (N, 3) float tensor of unit VP directions computed by
                  unproject_vp_to_world,
              'K': the (3, 3) intrinsics tensor.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        img_path = self.img_paths[idx]
        name = str(Path(img_path).stem)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing / unreadable file by returning
            # None rather than raising.
            raise FileNotFoundError('Could not read image: ' + img_path)

        # Ground-truth VPs: columns 1 and 2 of each row, made homogeneous,
        # then turned into unit world directions. The reshape keeps the
        # array (N, 3) even when the file lists no VP.
        vps = np.array(
            [[float(row[1]), float(row[2]), 1.]
             for row in self._read_csv_rows(self.vps_paths[idx])],
            dtype=float).reshape(-1, 3)
        vps = unproject_vp_to_world(vps, self.K.numpy())

        # NOTE(review): this mirrors the VP parsing and keeps only columns 1
        # and 2 plus a homogeneous 1 per labelled line; if the file stores
        # full segment endpoints, more columns are needed -- confirm the
        # CSV schema.
        lines = [[float(row[1]), float(row[2]), 1.]
                 for row in self._read_csv_rows(self.lines_paths[idx])]

        return {'image': img,
                'image_path': img_path,
                'name': name,
                'gt_lines': torch.tensor(lines, dtype=torch.float),
                'vps': torch.tensor(vps, dtype=torch.float),
                'K': self.K}

    def __len__(self):
        return len(self.img_paths)

    def get_data_loader(self, split, shuffle=False):
        """ Return a DataLoader over this dataset.

        Args:
            split: 'val', 'test' or 'export'. It only selects which
                '<split>_batch_size' config entry is used; the images
                themselves are fixed by `mode` in __init__.
            shuffle: whether the loader shuffles the samples.
        """
        assert split in ['val', 'test', 'export']
        batch_size = self.conf.get(split + '_batch_size')
        # Falls back to the batch size when 'num_workers' is not configured.
        num_workers = self.conf.get('num_workers', batch_size)
        return DataLoader(self.get_dataset(split), batch_size=batch_size,
                          shuffle=shuffle, pin_memory=True,
                          num_workers=num_workers)