File size: 3,872 Bytes
4c954ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
""" Rotated Day-Night Image Matching dataset. """

import os
import numpy as np
import torch
import cv2
import csv
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

from ..config.project_config import Config as cfg

def read_timestamps(text_file):
    """
    Read a space-separated text file containing the timestamps of images.

    Every non-empty row is expected to hold at least four fields, in this
    order: image name, date, hour and minute. Blank lines are ignored.

    Args:
        text_file: path of the timestamp file to read.

    Returns:
        A dictionary of parallel lists with the keys 'name', 'date', 'hour',
        'minute' and 'time'. Entry i of every list describes the i-th
        non-empty row of the file. 'name' and 'date' are kept as strings,
        'hour' and 'minute' are ints, and 'time' is the time of the day in
        hours as a float (hour + minute / 60).

    Raises:
        IndexError: if a non-empty row has fewer than four fields.
        ValueError: if the hour or minute field is not an integer.
    """
    timestamps = {'name': [], 'date': [], 'hour': [],
                  'minute': [], 'time': []}
    # newline='' leaves the handling of line endings to the csv module.
    with open(text_file, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=' ')
        for row in reader:
            # A blank line (typically a trailing newline at the end of the
            # file) is parsed as an empty row: skip it instead of failing
            # on row[0].
            if not row:
                continue
            hour = int(row[2])
            minute = int(row[3])
            timestamps['name'].append(row[0])
            timestamps['date'].append(row[1])
            timestamps['hour'].append(hour)
            timestamps['minute'].append(minute)
            timestamps['time'].append(hour + minute / 60.)
    return timestamps

class RDNIM(torch.utils.data.Dataset):
    """Rotated Day-Night Image Matching dataset.

    Every sample pairs the reference image of a sequence with one image of
    the same sequence, together with the homography stored in the text file
    next to that image and the time of the day at which it was taken.

    The dataset root is read from ``cfg.rdnim_dataroot`` and is expected to
    contain three folders:
        time_stamps/  one timestamp file per sequence,
        references/   one folder per sequence holding '<reference>.jpg',
        images/       one folder per sequence holding '<name>.jpg' images,
                      each with a '<name>.txt' homography file beside it.
    """

    default_conf = {
        # NOTE(review): not read anywhere in this class, the root directory
        # comes from cfg.rdnim_dataroot -- confirm whether it is still used.
        'dataset_dir': 'RDNIM',
        # Base name (without extension) of the reference image to use.
        'reference': 'day',
    }

    def __init__(self, conf):
        """Index all the image pairs of the dataset.

        Args:
            conf: mapping of options overriding `default_conf`. The loader
                options 'test_batch_size' and 'num_workers' are read later
                by `get_data_loader`.

        Raises:
            KeyError: if an image sequence has no timestamp file or no
                reference folder.
            ValueError: if an image is not listed in the timestamp file of
                its sequence.
        """
        # Merge the user options over the defaults and keep the result:
        # get_data_loader() reads it back through self.conf.
        self.conf = {**self.default_conf, **(conf or {})}
        self._root_dir = Path(cfg.rdnim_dataroot)
        ref = self.conf['reference']

        # Directory listings are sorted so that the order of the samples
        # does not depend on the file system.

        # Extract the timestamps, indexed by sequence name
        timestamps = {
            f.stem: read_timestamps(str(f))
            for f in sorted(Path(self._root_dir, 'time_stamps').iterdir())}

        # Extract the reference images paths, indexed by sequence name
        references = {
            seq.stem: str(Path(seq, ref + '.jpg'))
            for seq in sorted(Path(self._root_dir, 'references').iterdir())}

        # Extract the images paths and the homographies
        self._files = []
        for seq in sorted(Path(self._root_dir, 'images').iterdir()):
            if not seq.is_dir():
                # Ignore stray files lying next to the sequence folders.
                continue
            seq_id = seq.stem
            images = sorted(p for p in seq.iterdir() if p.suffix == '.jpg')
            for img in images:
                seq_stamps = timestamps[seq_id]
                timestamp = seq_stamps['time'][
                    seq_stamps['name'].index(img.name)]
                # The homography is stored next to the image, with the
                # same base name and a .txt extension.
                H = np.loadtxt(str(img.with_suffix('.txt'))).astype(float)
                self._files.append({
                    'img': str(img),
                    'ref': str(references[seq_id]),
                    'H': H,
                    'timestamp': timestamp})

    @staticmethod
    def _read_gray(path):
        """Load the image at `path` in grayscale, normalized in [0, 1].

        Returns:
            A float tensor holding the image with a leading axis added.

        Raises:
            OSError: if the image cannot be read.
        """
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread does not raise on a missing or unreadable file,
            # it returns None: fail here with the offending path.
            raise OSError('Could not read image: ' + str(path))
        img = img.astype(float) / 255.
        return torch.tensor(img[None], dtype=torch.float)

    def __getitem__(self, item):
        """Load the sample at index `item`.

        Returns:
            A dictionary with the keys:
                'image': reference image, grayscale float tensor in [0, 1]
                    with a leading axis added,
                'warped_image': sequence image, in the same format,
                'H': float tensor of the homography read from the text file
                    (presumably 3x3 -- confirm against the data),
                'timestamp': time of the day of the sequence image, in hours,
                'image_path', 'warped_image_path': paths of the two images.

        Raises:
            OSError: if one of the two images cannot be read.
        """
        entry = self._files[item]
        img0_path = entry['ref']
        img1_path = entry['img']

        img0 = self._read_gray(img0_path)
        img1 = self._read_gray(img1_path)
        H = torch.tensor(entry['H'], dtype=torch.float)

        return {'image': img0, 'warped_image': img1, 'H': H,
                'timestamp': entry['timestamp'],
                'image_path': img0_path, 'warped_image_path': img1_path}

    def __len__(self):
        """Return the number of image pairs."""
        return len(self._files)

    def get_dataset(self, split):
        """Return the dataset of the given split (only 'test' exists)."""
        assert split in ['test']
        return self

    def get_data_loader(self, split, shuffle=False):
        """Return a data loader for a given split.

        The batch size is read from the '<split>_batch_size' option and
        defaults to 1. The number of workers is read from 'num_workers' and
        defaults to the batch size.
        """
        assert split in ['test']
        batch_size = self.conf.get(split + '_batch_size', 1)
        num_workers = self.conf.get('num_workers', batch_size)
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle,
                          pin_memory=True, num_workers=num_workers)