import time

import numpy as np
import torch

from torchvision import transforms
from torchvision.datasets.cityscapes import Cityscapes
from torch.utils.data import Dataset


def resize_with_padding(img, target_size, padding_value, interpolation):
    """Resize a PIL image to fit inside target_size (h, w) while preserving
    its aspect ratio, then pad the remainder with padding_value (letterbox)."""
    target_h, target_w = target_size
    width, height = img.size
    aspect = width / height

    if aspect > (target_w / target_h):
        # Image is wider than the target: match the width, scale the height down.
        new_w = target_w
        new_h = int(target_w / aspect)
    else:
        # Image is taller than the target: match the height, scale the width down.
        new_h = target_h
        new_w = int(target_h * aspect)

    img = transforms.functional.resize(img, (new_h, new_w), interpolation)

    # Split the leftover pixels evenly between the two sides; the pad tuple
    # is (left, top, right, bottom).
    pad_h = target_h - new_h
    pad_w = target_w - new_w
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)

    return transforms.functional.pad(img, padding, fill=padding_value)
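
# Example (a sketch, not part of the original file): letterboxing a 1024x2048
# Cityscapes frame into (192, 640) matches the height and pads the width:
#
#   from PIL import Image
#   img = Image.new("RGB", (2048, 1024))  # PIL size is (w, h)
#   out = resize_with_padding(img, (192, 640), padding_value=0,
#                             interpolation=transforms.InterpolationMode.BILINEAR)
#   assert out.size == (640, 192)         # 192x384 resize + 128px side pads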

class CityscapesSeg(Dataset):
    """Cityscapes semantic segmentation dataset that returns numpy dicts."""

    def __init__(self, root, image_set, image_size=(192, 640)):
        super().__init__()
        self.split = image_set
        self.root = root

        transform = transforms.Compose([
            # Alternative: aspect-preserving letterbox via resize_with_padding.
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=0, interpolation=transforms.InterpolationMode.BILINEAR)),

            # Cityscapes frames are 1024x2048 (1:2), so resizing to 320x640
            # preserves the aspect ratio; the center crop then trims to image_size.
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])

        target_transform = transforms.Compose([
            # Alternative letterbox path; padding_value=-1 marks padded pixels
            # so they can be ignored by the loss.
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=-1, interpolation=transforms.InterpolationMode.NEAREST)),

            # Nearest-neighbor resize keeps label IDs intact (no blending
            # between neighboring classes).
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(image_size),
            transforms.PILToTensor(),
            transforms.Lambda(lambda x: x.long()),
        ])

        self.inner_loader = Cityscapes(self.root, image_set,
                                       mode="fine",
                                       target_type="semantic",
                                       transform=transform,
                                       target_transform=target_transform)

    def __getitem__(self, index):
        _start_time = time.time()
        image, target = self.inner_loader[index]  # (3, h, w) / (1, h, w)

        image = 2.0 * image - 1.0   # rescale from [0, 1] to [-1, 1]
        poses = torch.eye(4)        # (4, 4) identity placeholder pose
        projs = torch.eye(3)        # (3, 3) identity placeholder intrinsics
        target = target.squeeze(0)  # (h, w) label map

        _proc_time = time.time() - _start_time

        data = {
            "imgs": [image.numpy()],
            "poses": [poses.numpy()],
            "projs": [projs.numpy()],
            "segs": [target.numpy()],
            "t__get_item__": np.array([_proc_time]),
            "index": [np.array([index])],
        }
        return data

    def __len__(self):
        return len(self.inner_loader)
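
if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original file:
    # it assumes Cityscapes has been extracted under ./data/cityscapes with
    # the standard leftImg8bit/gtFine layout (the path is a placeholder).
    from torch.utils.data import DataLoader

    dataset = CityscapesSeg("./data/cityscapes", image_set="val")
    loader = DataLoader(dataset, batch_size=2, num_workers=0)

    batch = next(iter(loader))
    print(batch["imgs"][0].shape)  # torch.Size([2, 3, 192, 640])
    print(batch["segs"][0].shape)  # torch.Size([2, 192, 640])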