import time

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets.cityscapes import Cityscapes
def resize_with_padding(img, target_size, padding_value, interpolation):
    """Resize a PIL image to fit inside target_size (h, w) while preserving its
    aspect ratio, then pad symmetrically to exactly the target size."""
    target_h, target_w = target_size
    width, height = img.size
    aspect = width / height
    if aspect > (target_w / target_h):
        # Image is wider than the target: fit to width, pad top/bottom.
        new_w = target_w
        new_h = int(target_w / aspect)
    else:
        # Image is taller than the target: fit to height, pad left/right.
        new_h = target_h
        new_w = int(target_h * aspect)
    img = transforms.functional.resize(img, (new_h, new_w), interpolation)
    pad_h = target_h - new_h
    pad_w = target_w - new_w
    # torchvision padding order: (left, top, right, bottom).
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
    return transforms.functional.pad(img, padding, fill=padding_value)
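
# Worked example (for illustration, not part of the original code): a standard
# 2048x1024 Cityscapes frame resized into a (192, 640) target has aspect 2.0,
# which is below the target aspect 640/192 ≈ 3.33, so it is fitted to the
# target height (resized to 192x384) and padded by 128 px on the left and right.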
class CityscapesSeg(Dataset):
    """Cityscapes semantic segmentation dataset wrapping torchvision's loader."""

    def __init__(self, root, image_set, image_size=(192, 640)):
        super().__init__()
        self.split = image_set
        self.root = root
        transform = transforms.Compose([
            # Alternative: aspect-preserving resize with padding instead of resize + center crop.
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=0, interpolation=transforms.InterpolationMode.BILINEAR)),
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        target_transform = transforms.Compose([
            # Labels must use nearest-neighbor resizing so class ids are not interpolated.
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=-1, interpolation=transforms.InterpolationMode.NEAREST)),
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(image_size),
            transforms.PILToTensor(),
            transforms.Lambda(lambda x: x.long()),
        ])
        self.inner_loader = Cityscapes(self.root, image_set,
                                       mode="fine",
                                       target_type="semantic",
                                       transform=transform,
                                       target_transform=target_transform)
    def __getitem__(self, index):
        _start_time = time.time()
        image, target = self.inner_loader[index]  # (3, h, w) / (1, h, w)
        # Rescale image from [0, 1] to [-1, 1].
        image = 2.0 * image - 1.0
        # No camera metadata is loaded here; use identity pose and projection.
        poses = torch.eye(4)  # (4, 4)
        projs = torch.eye(3)  # (3, 3)
        target = target.squeeze(0)  # (h, w)
        _proc_time = time.time() - _start_time
        data = {
            "imgs": [image.numpy()],
            "poses": [poses.numpy()],
            "projs": [projs.numpy()],
            "segs": [target.numpy()],
            "t__get_item__": np.array([_proc_time]),
            "index": [np.array([index])],
        }
        return data
    def __len__(self):
        return len(self.inner_loader)
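
# Minimal usage sketch (assumes a local Cityscapes download with the standard
# leftImg8bit/ and gtFine/ layout; "./data/cityscapes" is a placeholder path,
# not part of the original code):
if __name__ == "__main__":
    dataset = CityscapesSeg("./data/cityscapes", "val")
    sample = dataset[0]
    print(len(dataset))
    print(sample["imgs"][0].shape)  # (3, 192, 640), values in [-1, 1]
    print(sample["segs"][0].shape)  # (192, 640), raw Cityscapes label ids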