liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
raw
history blame
7.64 kB
import os.path as osp
import json
import itertools
from collections import deque
import sys
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
import cv2
import numpy as np
import time
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2
class Co3d_Multi(BaseMultiViewDataset):
    """Multi-view dataset built from per-frame CO3D files stored under ``ROOT``.

    The sequence list is read from ``selected_seqs_{split}.json`` and is
    expected to be a two-level mapping ``{obj: {instance: [frame, ...]}}``.
    Each frame is described by four files under ``ROOT/obj/instance``: an
    image (``images/frameXXXXXX.jpg``), a metadata archive
    (``images/frameXXXXXX.npz``), a 16-bit depth map
    (``depths/frameXXXXXX.jpg.geometric.png``) and an object mask
    (``masks/frameXXXXXX.png``).

    ``self.split``, ``self.num_views`` and ``self.allow_repeat`` are read here
    but not set here; presumably the base-class ``__init__`` provides them.
    """

    def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
        """Load the sequence list and build the reference-frame index.

        Args:
            mask_bg: ``True`` to always zero the depth outside the object
                mask, ``False`` to never do so, ``"rand"`` to decide per
                sample (masking is chosen with probability 0.1).
            *args: Forwarded to the base-class constructor.
            ROOT: Directory holding the JSON sequence list and frame files.
            **kwargs: Forwarded to the base-class constructor.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, "rand")
        self.mask_bg = mask_bg
        self.is_metric = False
        self.dataset_label = "Co3d_v2"

        # load all scenes
        with open(osp.join(self.ROOT, f"selected_seqs_{self.split}.json"), "r") as f:
            scenes_by_obj = json.load(f)

        # Flatten {obj: {instance: frames}} into {(obj, instance): frames},
        # skipping objects that have no instances.
        self.scenes = {
            (obj, instance): frames
            for obj, instances in scenes_by_obj.items()
            if len(instances) > 0
            for instance, frames in instances.items()
        }
        self.scene_list = list(self.scenes.keys())

        # Minimum number of usable frames a scene must offer.
        self.cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )

        # One (scene key, frame) entry per admissible reference frame.
        # NOTE: when len(values) < cut_off - 1 the slice bound is negative and
        # still yields entries; such scenes are rejected (and resampled) in
        # _get_views via the len(image_pool) < cut_off check.
        self.all_ref_imgs = [
            (key, value)
            for key, values in self.scenes.items()
            for value in values[: len(values) - self.cut_off + 1]
        ]

        # Per scene: {resolution: [bool per pool position]} marking frames
        # found to have no valid depth at that resolution.
        self.invalidate = {scene: {} for scene in self.scene_list}
        # Per scene: True once the scene was found to be unusable.
        self.invalid_scenes = {scene: False for scene in self.scene_list}

    def __len__(self):
        """Return the number of reference-frame entries."""
        return len(self.all_ref_imgs)

    # The path helpers format the frame number with ``d`` rather than ``n``:
    # ``n`` is locale-aware and may insert digit-group separators into the
    # file name when a non-"C" numeric locale is active.

    def _get_metadatapath(self, obj, instance, view_idx):
        """Return the path of the ``.npz`` metadata archive of a frame."""
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06d}.npz")

    def _get_impath(self, obj, instance, view_idx):
        """Return the path of the ``.jpg`` image of a frame."""
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06d}.jpg")

    def _get_depthpath(self, obj, instance, view_idx):
        """Return the path of the depth PNG of a frame."""
        return osp.join(
            self.ROOT, obj, instance, "depths", f"frame{view_idx:06d}.jpg.geometric.png"
        )

    def _get_maskpath(self, obj, instance, view_idx):
        """Return the path of the object-mask PNG of a frame."""
        return osp.join(self.ROOT, obj, instance, "masks", f"frame{view_idx:06d}.png")

    def _read_depthmap(self, depthpath, input_metadata):
        """Read a depth image and rescale it to float32 depth values.

        The raw pixel values are divided by 65535 and multiplied by the
        frame's ``maximum_depth`` entry (NaN is replaced via ``nan_to_num``).

        Args:
            depthpath: Path of the depth image.
            input_metadata: Mapping providing ``"maximum_depth"``.

        Returns:
            The rescaled depth map.
        """
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        max_depth = np.nan_to_num(input_metadata["maximum_depth"])
        return (depthmap.astype(np.float32) / 65535) * max_depth

    def _get_views(self, idx, resolution, rng, num_views):
        """Assemble ``num_views`` views for the reference entry ``idx``.

        Frames whose cropped depth map has no positive value are flagged in
        ``self.invalidate`` and replaced by a neighbouring unflagged frame.
        Scenes that run out of usable frames are flagged in
        ``self.invalid_scenes`` and a different reference entry is drawn at
        random. A sequence whose views all come from the same image is
        rejected and redrawn, unless a single view was requested.

        Args:
            idx: Index into ``self.all_ref_imgs``.
            resolution: Target resolution; used as a dictionary key and passed
                to ``_crop_resize_if_necessary``.
            rng: Random generator offering ``integers`` and ``choice``
                (presumably a ``numpy.random.Generator``).
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.
        """
        invalid_seq = True
        # NOTE(review): ref_img_idx is an element of the scene's frame list
        # (see all_ref_imgs), not a position in it; confirm this is what
        # get_seq_from_start_id expects.
        scene_info, ref_img_idx = self.all_ref_imgs[idx]

        while invalid_seq:
            # Draw a new reference entry while the current scene is flagged.
            while self.invalid_scenes[scene_info]:
                idx = rng.integers(low=0, high=len(self.all_ref_imgs))
                scene_info, ref_img_idx = self.all_ref_imgs[idx]

            obj, instance = scene_info
            image_pool = self.scenes[obj, instance]
            if len(image_pool) < self.cut_off:
                print("Invalid scene!")
                self.invalid_scenes[scene_info] = True
                continue

            imgs_idxs, ordered_video = self.get_seq_from_start_id(
                num_views, ref_img_idx, image_pool, rng
            )

            # Lazily create the per-resolution invalid-frame flags.
            flags_by_resolution = self.invalidate[obj, instance]
            if resolution not in flags_by_resolution:
                flags_by_resolution[resolution] = [False] * len(image_pool)
            invalid_flags = flags_by_resolution[resolution]

            # decide now if we mask the bg (in "rand" mode: with probability 0.1)
            if self.mask_bg == "rand":
                mask_bg = bool(rng.choice(2, p=[0.9, 0.1]))
            else:
                mask_bg = bool(self.mask_bg)

            views = []
            imgs_idxs = deque(imgs_idxs)
            # some images (few) have zero depth
            while imgs_idxs:
                if len(image_pool) - sum(invalid_flags) < self.cut_off:
                    # Too few usable frames left: give up on this scene.
                    print("Invalid scene!")
                    self.invalid_scenes[scene_info] = True
                    break

                # pop() takes from the right end of the deque.
                im_idx = imgs_idxs.pop()
                if invalid_flags[im_idx]:
                    # search for a valid image, walking away from im_idx in a
                    # random direction (with wrap-around)
                    ordered_video = False
                    random_direction = 2 * rng.choice(2) - 1
                    for offset in range(1, len(image_pool)):
                        tentative_im_idx = (
                            im_idx + (random_direction * offset)
                        ) % len(image_pool)
                        if not invalid_flags[tentative_im_idx]:
                            im_idx = tentative_im_idx
                            break

                view_idx = image_pool[im_idx]
                impath = self._get_impath(obj, instance, view_idx)
                depthpath = self._get_depthpath(obj, instance, view_idx)
                metadata_path = self._get_metadatapath(obj, instance, view_idx)

                # np.load on an .npz returns an archive object that keeps the
                # file open; the context manager closes it deterministically.
                with np.load(metadata_path) as input_metadata:
                    # load camera params
                    camera_pose = input_metadata["camera_pose"].astype(np.float32)
                    intrinsics = input_metadata["camera_intrinsics"].astype(
                        np.float32
                    )
                    # load image and depth
                    rgb_image = imread_cv2(impath)
                    depthmap = self._read_depthmap(depthpath, input_metadata)

                if mask_bg:
                    # load object mask
                    maskpath = self._get_maskpath(obj, instance, view_idx)
                    maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(
                        np.float32
                    )
                    maskmap = (maskmap / 255.0) > 0.1
                    # update the depthmap with mask
                    depthmap *= maskmap

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
                )

                num_valid = (depthmap > 0.0).sum()
                if num_valid == 0:
                    # problem, invalidate image and retry: the index is pushed
                    # back so the next iteration searches for a replacement
                    invalid_flags[im_idx] = True
                    imgs_idxs.append(im_idx)
                    continue

                # generate img mask and raymap mask
                img_mask, ray_mask = self.get_img_and_ray_masks(
                    self.is_metric, len(views), rng
                )

                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap,
                        camera_pose=camera_pose,
                        camera_intrinsics=intrinsics,
                        dataset=self.dataset_label,
                        label=osp.join(obj, instance),
                        instance=osp.split(impath)[1],
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.9, dtype=np.float32),
                        img_mask=img_mask,
                        ray_mask=ray_mask,
                        camera_only=False,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )

            if len(views) == num_views:
                # Reject sequences made of one image repeated num_views times.
                # A single requested view is trivially "all the same image",
                # so it is accepted; otherwise this loop could never finish.
                distinct_images = {view["instance"] for view in views}
                if num_views == 1 or len(distinct_images) > 1:
                    invalid_seq = False

        return views