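"""BDD100K semantic segmentation dataset.

Loads images and panoptic bitmasks from the BDD100K 10k split, remaps BDD
label ids to KITTI-360 ids via each label's trainId, and returns samples
with identity camera poses and projections.
"""
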
import os
import time
from collections import namedtuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from datasets.kitti_360.labels import trainId2label


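# Cityscapes-style label definition. `to_cs27` appears to map each class into
# a 27-class id space (values 0-26 in BDD_LABEL below; 255 = no mapping).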
Label = namedtuple(
    "Label",
    [
        "name",
        "id",
        "trainId",
        "category",
        "categoryId",
        "hasInstances",
        "ignoreInEval",
        "color",
        "to_cs27",
    ],
)

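# BDD100K labels, ordered by id (0-40) so that list index == label id.
# `trainId` follows the Cityscapes 19-class convention; 255 marks ignored classes.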
BDD_LABEL = [
    Label("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0), 255),
    Label("dynamic", 1, 255, "void", 0, False, True, (111, 74, 0), 255),
    Label("ego vehicle", 2, 255, "void", 0, False, True, (0, 0, 0), 255),
    Label("ground", 3, 255, "void", 0, False, True, (81, 0, 81), 255),
    Label("static", 4, 255, "void", 0, False, True, (0, 0, 0), 255),
    Label("parking", 5, 255, "flat", 1, False, True, (250, 170, 160), 2),
    Label("rail track", 6, 255, "flat", 1, False, True, (230, 150, 140), 3),
    Label("road", 7, 0, "flat", 1, False, False, (128, 64, 128), 0),
    Label("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232), 1),
    Label("bridge", 9, 255, "construction", 2, False, True, (150, 100, 100), 8),
    Label("building", 10, 2, "construction", 2, False, False, (70, 70, 70), 4),
    Label("fence", 11, 4, "construction", 2, False, False, (190, 153, 153), 6),
    Label("garage", 12, 255, "construction", 2, False, True, (180, 100, 180), 255),
    Label("guard rail", 13, 255, "construction", 2, False, True, (180, 165, 180), 7),
    Label("tunnel", 14, 255, "construction", 2, False, True, (150, 120, 90), 9),
    Label("wall", 15, 3, "construction", 2, False, False, (102, 102, 156), 5),
    Label("banner", 16, 255, "object", 3, False, True, (250, 170, 100), 255),
    Label("billboard", 17, 255, "object", 3, False, True, (220, 220, 250), 255),
    Label("lane divider", 18, 255, "object", 3, False, True, (255, 165, 0), 255),
    Label("parking sign", 19, 255, "object", 3, False, False, (220, 20, 60), 255),
    Label("pole", 20, 5, "object", 3, False, False, (153, 153, 153), 10),
    Label("polegroup", 21, 255, "object", 3, False, True, (153, 153, 153), 11),
    Label("street light", 22, 255, "object", 3, False, True, (220, 220, 100), 255),
    Label("traffic cone", 23, 255, "object", 3, False, True, (255, 70, 0), 255),
    Label("traffic device", 24, 255, "object", 3, False, True, (220, 220, 220), 255),
    Label("traffic light", 25, 6, "object", 3, False, False, (250, 170, 30), 12),
    Label("traffic sign", 26, 7, "object", 3, False, False, (220, 220, 0), 13),
    Label("traffic sign frame", 27, 255, "object", 3, False, True, (250, 170, 250), 255),
    Label("terrain", 28, 9, "nature", 4, False, False, (152, 251, 152), 15),
    Label("vegetation", 29, 8, "nature", 4, False, False, (107, 142, 35), 14),
    Label("sky", 30, 10, "sky", 5, False, False, (70, 130, 180), 16),
    Label("person", 31, 11, "human", 6, True, False, (220, 20, 60), 17),
    Label("rider", 32, 12, "human", 6, True, False, (255, 0, 0), 18),
    Label("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32), 26),
    Label("bus", 34, 15, "vehicle", 7, True, False, (0, 60, 100), 21),
    Label("car", 35, 13, "vehicle", 7, True, False, (0, 0, 142), 19),
    Label("caravan", 36, 255, "vehicle", 7, True, True, (0, 0, 90), 22),
    Label("motorcycle", 37, 17, "vehicle", 7, True, False, (0, 0, 230), 25),
    Label("trailer", 38, 255, "vehicle", 7, True, True, (0, 0, 110), 23),
    Label("train", 39, 16, "vehicle", 7, True, False, (0, 80, 100), 24),
    Label("truck", 40, 14, "vehicle", 7, True, False, (0, 0, 70), 20),
]


def resize_with_padding(img, target_size, padding_value, interpolation):
    """Resize `img` to fit inside `target_size` (h, w) while preserving its
    aspect ratio, then pad symmetrically to exactly the target size."""
    target_h, target_w = target_size
    width, height = img.size
    aspect = width / height

    if aspect > (target_w / target_h):  # wider than the target: width is the limit
        new_w = target_w
        new_h = int(target_w / aspect)
    else:
        new_h = target_h
        new_w = int(target_h * aspect)

    img = transforms.functional.resize(img, (new_h, new_w), interpolation)

    # Split the leftover space between both sides to center the image;
    # padding order is (left, top, right, bottom).
    pad_h = target_h - new_h
    pad_w = target_w - new_w
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)

    return transforms.functional.pad(img, padding, fill=padding_value)


class BDDSeg(Dataset):
    """Segmentation dataset over the BDD100K 10k split.

    Expects `root` to contain `images/10k/<split>` and
    `labels/pan_seg/bitmasks/<split>`.
    """

    def __init__(self, root, image_set, image_size=(192, 640)):
        super().__init__()
        self.split = image_set
        self.root = root

        # Resize to 320x640, then center-crop to the requested image size.
        self.image_transform = transforms.Compose([
            # Alternative: aspect-preserving resize with padding.
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=0, interpolation=transforms.InterpolationMode.BILINEAR)),
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])

        # Labels use nearest-neighbor resizing so ids are never interpolated.
        self.target_transform = transforms.Compose([
            # Alternative: aspect-preserving resize with padding.
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=-1, interpolation=transforms.InterpolationMode.NEAREST)),
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(image_size),
            transforms.PILToTensor(),
            transforms.Lambda(lambda x: x.long()),  # int64 ids for index lookups
        ])

        self.images, self.targets = [], []

        image_dir = os.path.join(self.root, "images/10k", self.split)
        target_dir = os.path.join(self.root, "labels/pan_seg/bitmasks", self.split)
        for file_name in os.listdir(image_dir):
            image_path = os.path.join(image_dir, file_name)

            # Bitmasks share the image's basename but are always PNG.
            target_filename = os.path.splitext(file_name)[0] + ".png"
            target_path = os.path.join(target_dir, target_filename)
            assert os.path.isfile(target_path), f"missing bitmask: {target_path}"

            self.images.append(image_path)
            self.targets.append(target_path)

        # Lookup table from BDD label id to KITTI-360 id via the shared trainId.
        # BDD_LABEL is sorted by id, so the list index equals the label id.
        self.class_mapping = torch.tensor([trainId2label[c.trainId].id for c in BDD_LABEL]).int()

    def __getitem__(self, index):
        _start_time = time.time()

        image = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.targets[index])

        image = self.image_transform(image)
        target = self.target_transform(target)

        image = 2.0 * image - 1.0   # rescale from [0, 1] to [-1, 1]
        poses = torch.eye(4)        # (4, 4) identity camera pose
        projs = torch.eye(3)        # (3, 3) identity projection
        target = target[0]          # R channel of the bitmask holds the category id
        target = self.class_mapping[target]  # remap BDD ids to KITTI-360 ids

        _proc_time = time.time() - _start_time

        data = {
            "imgs": [image.numpy()],
            "poses": [poses.numpy()],
            "projs": [projs.numpy()],
            "segs": [target.numpy()],
            "t__get_item__": np.array([_proc_time]),
            "index": [np.array([index])],
        }
        return data

    def __len__(self):
        return len(self.images)
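

# Minimal usage sketch. Assumption: the dataset root below is a placeholder and
# must point at a local BDD100K download containing the 10k images and the
# pan_seg bitmasks.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = BDDSeg(root="/path/to/bdd100k", image_set="val", image_size=(192, 640))
    loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)

    batch = next(iter(loader))
    print(len(dataset), "samples")
    print("imgs:", batch["imgs"][0].shape)  # (4, 3, 192, 640), values in [-1, 1]
    print("segs:", batch["segs"][0].shape)  # (4, 192, 640), KITTI-360 label ids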