# SceneDINO: datasets/realestate10k/realestate10k_dataset.py
import os
import pickle
import time
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ColorJitter

from scenedino.common.augmentation import get_color_aug_fn


class RealEstate10kDataset(Dataset):
def __init__(
self,
data_path: str,
split_path: Optional[str] = None,
target_image_size=(256, 384),
frame_count=2,
dilation=1,
keyframe_offset=0,
color_aug=False,
):
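        """Dataset of RealEstate10k video clips.

        Args:
            data_path: root folder containing "<split>.pickle" metadata and
                frames under "frames/<split>/<key>/<timestamp>.jpg".
            split_path: optional index file of "<key> <frame id>" lines; if
                given, the dataset runs on the "test" split, otherwise "train".
            target_image_size: (height, width) to which frames are resized.
            frame_count: number of frames returned per sample.
            dilation: spacing between neighbouring frame ids, or "random" for a
                per-sample spacing drawn from [1, 30).
            keyframe_offset: shifts the keyframe inside the sampled window.
            color_aug: if True, apply random color jitter to all frames of a sample.
        """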
self.data_path = data_path
self.split_path = split_path
self.target_image_size = target_image_size
self.frame_count = frame_count
self.dilation = dilation
self.keyframe_offset = keyframe_offset
self.color_aug = color_aug
if self.split_path is None:
self.split = "train"
else:
self.split = "test"
self._seq_data = self._load_seq_data(self.data_path, self.split)
self._seq_keys = list(self._seq_data.keys())
if isinstance(self.dilation, int):
self._left_offset = (
(self.frame_count - 1) // 2 + self.keyframe_offset
) * self.dilation
dilation = self.dilation
else:
self._left_offset = 0
dilation = 0
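            # Non-integer dilation (e.g. "random"): index every frame and rely on the
            # per-sample clamping in __getitem__ to keep neighbour ids in range.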
if self.split == "train":
self._key_id_pairs = self._full_index(
self._seq_keys,
self._seq_data,
self._left_offset,
(self.frame_count - 1) * dilation,
)
else:
self._key_id_pairs = self._load_index(split_path)
self._skip = 0
self.length = len(self._key_id_pairs)

    @staticmethod
def _load_seq_data(data_path, split):
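        # "<split>.pickle" maps each sequence key to a dict with "timestamps",
        # "poses" and "intrinsics" arrays.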
file_path = Path(data_path) / f"{split}.pickle"
with open(file_path, "rb") as f:
seq_data = pickle.load(f)
return seq_data

    @staticmethod
def _full_index(seq_keys, seq_data, left_offset, extra_frames):
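        # Enumerate every (key, keyframe id) pair whose sampling window of
        # frame_count frames fits inside the sequence.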
key_id_pairs = []
for k in seq_keys:
seq_len = len(seq_data[k]["timestamps"])
seq_key_id_pairs = [
(k, i + left_offset) for i in range(seq_len - extra_frames)
]
key_id_pairs += seq_key_id_pairs
return key_id_pairs

    @staticmethod
def _load_index(index_path):
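        # Index file format: one "<sequence key> <frame id>" pair per line.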
        def get_key_id(s):
            parts = s.split(" ")
            key = parts[0]
            frame_id = int(parts[1])
            return key, frame_id

        with open(index_path, "r") as f:
            lines = [line for line in f.readlines() if line.strip()]
        key_id_pairs = list(map(get_key_id, lines))
        return key_id_pairs

    def load_images(self, key, ids):
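        # Frames are stored as <data_path>/frames/<split>/<key>/<timestamp>.jpg;
        # load each one, convert BGR -> RGB and scale to [0, 1].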
imgs = []
        for frame_id in ids:
            timestamp = self._seq_data[key]["timestamps"][frame_id]
            img_path = os.path.join(
                self.data_path, "frames", self.split, key, f"{timestamp}.jpg"
            )
            img = cv2.imread(img_path)
            if img is None:
                raise FileNotFoundError(f"Could not read frame: {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
            imgs += [img]
return imgs

    def process_img(self, img: np.ndarray, color_aug_fn=None):
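        # Resize to (height, width), reorder HWC -> CHW, optionally color-jitter,
        # then map [0, 1] -> [-1, 1].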
if self.target_image_size:
img = cv2.resize(
img,
(self.target_image_size[1], self.target_image_size[0]),
interpolation=cv2.INTER_LINEAR,
)
img = np.transpose(img, (2, 0, 1))
if color_aug_fn is not None:
img = color_aug_fn(torch.tensor(img))
img = img * 2 - 1
return img

    @staticmethod
def process_pose(pose):
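        # Pad the 3x4 [R|t] matrix to a homogeneous 4x4 and invert it. RealEstate10k
        # stores world-to-camera extrinsics, so the inverse is camera-to-world
        # (see the note at the call site in __getitem__).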
pose = np.concatenate(
(pose.astype(np.float32), np.array([[0, 0, 0, 1]], dtype=np.float32)),
axis=0,
)
pose = np.linalg.inv(pose)
return pose

    @staticmethod
def process_projs(proj):
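        # RealEstate10k intrinsics come as (fx, fy, cx, cy) normalized by image size.
        # Rescale them into a 3x3 K that projects onto [-1, 1] x [-1, 1] image coordinates.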
K = np.eye(3, dtype=np.float32)
K[0, 0] = 2 * proj[0]
K[1, 1] = 2 * proj[1]
K[0, 2] = 2 * proj[2] - 1
K[1, 2] = 2 * proj[3] - 1
return K

    def __getitem__(self, index: int):
_start_time = time.time()
if index >= self.length:
raise IndexError()
if self._skip != 0:
index += self._skip
if self.color_aug:
color_aug_fn = get_color_aug_fn(
ColorJitter.get_params(
brightness=(0.8, 1.2),
contrast=(0.8, 1.2),
saturation=(0.8, 1.2),
hue=(-0.1, 0.1),
)
)
else:
color_aug_fn = None
key, index = self._key_id_pairs[index]
seq_len = len(self._seq_data[key]["timestamps"])
if self.dilation == "random":
dilation = torch.randint(1, 30, (1,)).item()
left_offset = self._left_offset
if self.frame_count > 2:
left_offset = dilation * (self.frame_count // 2)
else:
dilation = self.dilation
left_offset = self._left_offset
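        # Assemble the frame ids: the keyframe first, followed by its frame_count - 1
        # neighbours spaced by dilation, each clamped to the valid range [0, seq_len - 1].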
ids = [index] + [
max(min(i, seq_len - 1), 0)
for i in range(
index - left_offset,
index - left_offset + self.frame_count * dilation,
dilation,
)
if i != index
]
imgs = self.load_images(key, ids)
imgs = [self.process_img(img, color_aug_fn=color_aug_fn) for img in imgs]
# These poses are camera to world !!
poses = [self.process_pose(self._seq_data[key]["poses"][i, :, :]) for i in ids]
projs = [
self.process_projs(self._seq_data[key]["intrinsics"][i, :]) for i in ids
]
        depths = [np.ones_like(imgs[0][:1, :, :])]  # dummy depth placeholder; RealEstate10k has no ground-truth depth
        _proc_time = time.time() - _start_time
data = {
"imgs": imgs,
"projs": projs,
"poses": poses,
"depths": depths,
"t__get_item__": np.array([_proc_time]),
}
return data

    def __len__(self) -> int:
return self.length
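

if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming a local RealEstate10k layout at
    # "data/realestate10k" ("train.pickle" plus frames/train/<key>/<timestamp>.jpg).
    # The path is a placeholder, not something shipped with the repository.
    dataset = RealEstate10kDataset(
        data_path="data/realestate10k",  # hypothetical location; adjust to your setup
        target_image_size=(256, 384),
        frame_count=2,
        dilation=1,
        color_aug=True,
    )
    sample = dataset[0]
    print(len(dataset), sample["imgs"][0].shape, sample["poses"][0].shape)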