vmem / extern /CUT3R /src /dust3r /datasets /base /base_multiview_dataset.py
liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
import PIL
import numpy as np
import torch
import random
import itertools
from dust3r.datasets.base.easy_dataset import EasyDataset
from dust3r.datasets.utils.transforms import ImgNorm, SeqColorJitter
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
import dust3r.datasets.utils.cropping as cropping
from dust3r.datasets.utils.corr import extract_correspondences_from_pts3d
def get_ray_map(c2w1, c2w2, intrinsics, h, w):
    """Build an (h, w, 6) ray map of camera 2 expressed in the frame of camera 1.

    Args:
        c2w1: 4x4 pose of the reference camera (presumably camera-to-world).
        c2w2: 4x4 pose of the camera whose rays are computed.
        intrinsics: 3x3 camera matrix of camera 2.
        h, w: image height and width in pixels.

    Returns:
        Array of shape (h, w, 6): the first 3 channels are the ray origin (translation
        of ``inv(c2w1) @ c2w2``, identical for every pixel), the last 3 channels are
        unit-normalized ray directions.

    NOTE(review): the back-projected pixel ``K^-1 [u, v, 1]`` is transformed as a
    homogeneous *point* (4th coordinate 1), so the relative translation is added
    before normalization. The "direction" is therefore the unit vector from the
    camera-1 origin to that point, not ``R @ d``. This is kept as-is because other
    code may rely on this exact encoding — confirm before changing.
    """
    # relative pose: camera 2 expressed in camera 1's frame
    rel_c2w = np.linalg.inv(c2w1) @ c2w2

    # homogeneous pixel coordinates (u along the width, v along the height), as 3 x (h*w)
    u, v = np.meshgrid(np.arange(w), np.arange(h), indexing="xy")
    pixels_h = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3).T

    cam_dirs = np.linalg.inv(intrinsics) @ pixels_h
    cam_points_h = np.vstack([cam_dirs, np.ones_like(cam_dirs[0])])
    directions = (rel_c2w @ cam_points_h).T[:, :3].reshape(h, w, 3)
    directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)

    origins = np.broadcast_to(rel_c2w[:3, 3], (h, w, 3))
    return np.concatenate([origins, directions], axis=-1)
class BaseMultiViewDataset(EasyDataset):
    """Define all basic options shared by multi-view datasets.

    Subclasses implement ``_get_views`` and are expected to define ``self.scenes``
    (used by ``__len__``).

    Usage:
        class MyDataset (BaseMultiViewDataset):
            def _get_views(self, idx, resolution, rng, num_views):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(
        self,
        *,  # only keyword arguments
        num_views=None,
        split=None,
        resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
        transform=ImgNorm,
        aug_crop=False,
        n_corres=0,
        nneg=0,
        seed=None,
        allow_repeat=False,
        seq_aug_crop=False,
    ):
        """Store the dataset options.

        Args (keyword-only):
            num_views: maximum number of views per sample (required).
            split: split name, stored as-is.
            resolution: square size, (width, height), or a list of (width, height).
            transform: image transform, or its name as a string (resolved with
                ``eval``). Passing the ``SeqColorJitter`` class makes ``__getitem__``
                build a fresh ``SeqColorJitter()`` for every sample.
            aug_crop: if > 1, the resize target is enlarged by a random integer
                offset in [0, aug_crop) before the final crop.
            n_corres: correspondences per view: ``"all"``, a single int, or a list
                with one entry per view (length ``num_views``).
            nneg: forwarded to ``extract_correspondences_from_pts3d``; must be 0
                when ``n_corres == "all"``.
            seed: if not None, the RNG is re-seeded with ``seed + idx`` on every
                ``__getitem__`` call.
            allow_repeat: allow samples that reuse frames when too few are available.
            seq_aug_crop: share one ``aug_crop`` offset across all views of a sample.
        """
        assert num_views is not None, "undefined num_views"
        self.num_views = num_views
        self.split = split
        self._set_resolutions(resolution)

        self.n_corres = n_corres
        self.nneg = nneg
        assert (
            self.n_corres == "all"
            or isinstance(self.n_corres, int)
            or (
                isinstance(self.n_corres, list) and len(self.n_corres) == self.num_views
            )
        ), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
        assert (
            self.nneg == 0 or self.n_corres != "all"
        ), "nneg should be 0 if n_corres is all"

        self.is_seq_color_jitter = False
        if isinstance(transform, str):
            # NOTE(security): eval() of a config string; only acceptable for trusted configs.
            transform = eval(transform)
        if transform == SeqColorJitter:
            transform = SeqColorJitter()
            self.is_seq_color_jitter = True
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed
        self.allow_repeat = allow_repeat
        self.seq_aug_crop = seq_aug_crop

    def __len__(self):
        """Number of samples; relies on ``self.scenes`` being set by the subclass."""
        return len(self.scenes)

    @staticmethod
    def efficient_random_intervals(
        start,
        num_elements,
        interval_range,
        fixed_interval_prob=0.8,
        weights=None,
        seed=42,
    ):
        """Return ``num_elements`` increasing positions starting at ``start``.

        With probability ``fixed_interval_prob`` a single interval is drawn from
        ``interval_range`` (optionally weighted) and repeated; otherwise every gap is
        drawn independently.

        NOTE(review): ``seed`` is unused; draws come from the global ``random`` module.
        """
        if random.random() < fixed_interval_prob:
            intervals = random.choices(interval_range, weights=weights) * (
                num_elements - 1
            )
        else:
            intervals = [
                random.choices(interval_range, weights=weights)[0]
                for _ in range(num_elements - 1)
            ]
        return list(itertools.accumulate([start] + intervals))

    def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1):
        """Enumerate candidate view sets around frame ``i`` based on timestamps.

        Candidates are the frames whose timestamp differs from ``timestamps[i]`` by
        less than ``interval``. The result is a list of sorted index lists, each of
        length ``num_views``:

        - disjoint random subsets of the candidates;
        - one subset that completes the leftover candidates;
        - one subset drawn with replacement, only if ``allow_repeat``;
        - fixed-stride windows around ``i`` for every feasible stride.

        Returns ``[]`` when there are too few candidates. That means fewer than
        ``num_views``, or fewer than ``num_views // 3`` when ``allow_repeat`` is set.

        NOTE(review): random draws use the global ``np.random`` state, not ``self._rng``.
        """
        time_diffs = np.abs(timestamps - timestamps[i])
        ids_candidate = np.where(time_diffs < interval)[0]
        ids_candidate = np.sort(ids_candidate)
        # The threshold depends on allow_repeat. Applying the no-repeat threshold
        # unconditionally would make the repeat fallback below unreachable.
        if self.allow_repeat:
            not_enough = len(ids_candidate) < num_views // 3
        else:
            not_enough = len(ids_candidate) < num_views
        if not_enough:
            return []

        ids_sel_list = []
        ids_candidate_left = ids_candidate.copy()
        while len(ids_candidate_left) >= num_views:
            ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False)
            ids_sel_list.append(sorted(ids_sel))
            ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel)

        if len(ids_candidate_left) > 0 and len(ids_candidate) >= num_views:
            ids_sel = np.concatenate(
                [
                    ids_candidate_left,
                    np.random.choice(
                        np.setdiff1d(ids_candidate, ids_candidate_left),
                        num_views - len(ids_candidate_left),
                        replace=False,
                    ),
                ]
            )
            ids_sel_list.append(sorted(ids_sel))

        if self.allow_repeat:
            ids_sel_list.append(
                sorted(np.random.choice(ids_candidate, num_views, replace=True))
            )

        # add sequences with fixed intervals (all possible intervals)
        pos_i = np.where(ids_candidate == i)[0][0]
        curr_interval = 1
        stop = len(ids_candidate) < num_views
        while not stop:
            pos_sel = [pos_i]
            count = 0
            # grow the window alternately to the right and to the left of pos_i
            while len(pos_sel) < num_views:
                if count % 2 == 0:
                    curr_pos_i = pos_sel[-1] + curr_interval
                    if curr_pos_i >= len(ids_candidate):
                        stop = True
                        break
                    pos_sel.append(curr_pos_i)
                else:
                    curr_pos_i = pos_sel[0] - curr_interval
                    if curr_pos_i < 0:
                        stop = True
                        break
                    pos_sel.insert(0, curr_pos_i)
                count += 1

            if not stop and len(pos_sel) == num_views:
                ids_sel = sorted([ids_candidate[pos] for pos in pos_sel])
                if ids_sel not in ids_sel_list:
                    ids_sel_list.append(ids_sel)

            curr_interval += 1
        return ids_sel_list

    @staticmethod
    def blockwise_shuffle(x, rng, block_shuffle):
        """Shuffle ``x`` with ``rng`` and return a list.

        If ``block_shuffle`` is None the whole sequence is permuted. Otherwise each
        consecutive block of ``block_shuffle`` items is permuted internally, and the
        blocks keep their order.
        """
        if block_shuffle is None:
            return rng.permutation(x).tolist()
        else:
            assert block_shuffle > 0
            blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)]
            shuffled_blocks = [rng.permutation(block).tolist() for block in blocks]
            shuffled_list = [item for block in shuffled_blocks for item in block]
            return shuffled_list

    def get_seq_from_start_id(
        self,
        num_views,
        id_ref,
        ids_all,
        rng,
        min_interval=1,
        max_interval=25,
        video_prob=0.5,
        fix_interval_prob=0.5,
        block_shuffle=None,
    ):
        """
        args:
            num_views: number of views to return
            id_ref: the reference id (first id)
            ids_all: all the ids (a list; ``ids_all.index`` is used)
            rng: numpy random Generator
            min_interval: minimum interval between two views (lowered to the
                feasible maximum if it cannot be honoured)
            max_interval: maximum interval between two views
            video_prob: probability of returning an ordered (video-like) sample
            fix_interval_prob: probability that a video-like sample uses one
                constant interval
            block_shuffle: block size for ``blockwise_shuffle`` when the sample is
                not video-like (None shuffles everything)
        returns:
            pos: positions of the views in ids_all, i.e., index for ids_all
            is_video: True if the sample is treated as a video (kept in order,
                not shuffled)
        """
        assert min_interval > 0, f"min_interval should be > 0, got {min_interval}"
        assert (
            min_interval <= max_interval
        ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}"
        assert id_ref in ids_all

        pos_ref = ids_all.index(id_ref)
        if num_views == 1:
            # a single view is just the reference; avoids dividing by num_views - 1
            return [pos_ref], True

        all_possible_pos = np.arange(pos_ref, len(ids_all))
        remaining_sum = len(ids_all) - 1 - pos_ref

        if remaining_sum >= num_views - 1:
            if remaining_sum == num_views - 1:
                assert ids_all[-num_views] == id_ref
                return [pos_ref + i for i in range(num_views)], True
            max_interval = min(max_interval, 2 * remaining_sum // (num_views - 1))
            # keep the interval range non-empty when min_interval exceeds the cap
            min_interval = min(min_interval, max_interval)
            intervals = [
                rng.choice(range(min_interval, max_interval + 1))
                for _ in range(num_views - 1)
            ]

            # if video or collection
            if rng.random() < video_prob:
                # if fixed interval or random
                if rng.random() < fix_interval_prob:
                    # regular interval
                    fixed_interval = rng.choice(
                        range(
                            1,
                            min(remaining_sum // (num_views - 1) + 1, max_interval + 1),
                        )
                    )
                    intervals = [fixed_interval for _ in range(num_views - 1)]
                is_video = True
            else:
                is_video = False

            pos = list(itertools.accumulate([pos_ref] + intervals))
            # drop positions past the end, then refill with unused positions
            pos = [p for p in pos if p < len(ids_all)]
            pos_candidates = [p for p in all_possible_pos if p not in pos]
            pos = (
                pos
                + rng.choice(
                    pos_candidates, num_views - len(pos), replace=False
                ).tolist()
            )

            pos = (
                sorted(pos)
                if is_video
                else self.blockwise_shuffle(pos, rng, block_shuffle)
            )
        else:
            # fewer frames than views remain after pos_ref: sample with repetition
            # assert self.allow_repeat
            uniq_num = max(remaining_sum, 1)  # >= 1 so the repeat count is defined
            new_pos_ref = rng.choice(np.arange(pos_ref + 1))
            new_remaining_sum = len(ids_all) - 1 - new_pos_ref
            # clamp to >= 1 so range(1, new_max_interval + 1) is never empty
            new_max_interval = max(
                1, min(max_interval, new_remaining_sum // max(uniq_num - 1, 1))
            )
            new_intervals = [
                rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1)
            ]

            revisit_random = rng.random()
            video_random = rng.random()

            if rng.random() < fix_interval_prob and video_random < video_prob:
                # regular interval
                fixed_interval = rng.choice(range(1, new_max_interval + 1))
                new_intervals = [fixed_interval for _ in range(uniq_num - 1)]
            pos = list(itertools.accumulate([new_pos_ref] + new_intervals))

            is_video = False
            if revisit_random < 0.5 or video_prob == 1.0:  # revisit, video / collection
                is_video = video_random < video_prob
                pos = (
                    self.blockwise_shuffle(pos, rng, block_shuffle)
                    if not is_video
                    else pos
                )
                num_full_repeat = num_views // uniq_num
                pos = (
                    pos * num_full_repeat
                    + pos[: num_views - len(pos) * num_full_repeat]
                )
            elif revisit_random < 0.9:  # random
                pos = rng.choice(pos, num_views, replace=True)
            else:  # ordered
                pos = sorted(rng.choice(pos, num_views, replace=True))
        assert len(pos) == num_views
        return pos, is_video

    def get_img_and_ray_masks(self, is_metric, v, rng, p=(0.8, 0.15, 0.05)):
        """Decide whether view ``v`` is given as an image, a ray map, or both.

        View 0 and non-metric samples always get the image only. Otherwise one
        ``rng.random()`` draw selects image only (probability ``p[0]``), ray map
        only (``p[1]``), or both (the remaining mass).

        Returns:
            (img_mask, raymap_mask): two booleans.
        """
        # generate img mask and raymap mask
        if v == 0 or (not is_metric):
            img_mask = True
            raymap_mask = False
        else:
            rand_val = rng.random()
            if rand_val < p[0]:
                img_mask = True
                raymap_mask = False
            elif rand_val < p[0] + p[1]:
                img_mask = False
                raymap_mask = True
            else:
                img_mask = True
                raymap_mask = True
        return img_mask, raymap_mask

    def get_stats(self):
        """Short human-readable size summary used by ``__repr__``."""
        return f"{len(self)} groups of views"

    def __repr__(self):
        # all whitespace is stripped, yielding a single compact line
        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
        return (
            f"""{type(self).__name__}({self.get_stats()},
            {self.num_views=},
            {self.split=},
            {self.seed=},
            resolutions={resolutions_str},
            {self.transform=})""".replace(
                "self.", ""
            )
            .replace("\n", "")
            .replace(" ", "")
        )

    def _get_views(self, idx, resolution, rng, num_views):
        """Return a list of ``num_views`` view dicts for sample ``idx`` (subclass hook)."""
        raise NotImplementedError()

    def _n_corres_for_view(self, v):
        """Correspondence count requested for view ``v`` (``"all"``, an int, or the per-view list entry)."""
        if isinstance(self.n_corres, list):
            return self.n_corres[v]
        return self.n_corres

    def __getitem__(self, idx):
        """Load and post-process one multi-view sample.

        ``idx`` is either an int (only allowed with a single resolution) or an
        ``(idx, ar_idx, nview)`` triple selecting the resolution index and the
        number of views.

        Each view dict returned by ``_get_views`` is completed with the following
        keys: ``idx``, ``true_shape``, transformed ``img``, ``sky_mask``,
        ``camera_pose``, ``ray_map``, ``pts3d``, ``valid_mask``, and ``rng``. When
        correspondences are requested, ``corres`` and ``valid_corres`` are added too.
        """
        # print("Receiving:" , idx)
        if isinstance(idx, (tuple, list, np.ndarray)):
            # the idx is specifying the aspect-ratio
            idx, ar_idx, nview = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0
            nview = self.num_views

        assert nview >= 1 and nview <= self.num_views
        # set-up the rng
        if self.seed is not None:  # reseed for each __getitem__ (seed=0 included)
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, "_rng"):
            seed = torch.randint(0, 2**32, (1,)).item()
            self._rng = np.random.default_rng(seed=seed)

        if self.aug_crop > 1 and self.seq_aug_crop:
            # one resize offset shared by every view of this sample
            self.delta_target_resolution = self._rng.integers(0, self.aug_crop)

        # over-loaded code
        resolution = self._resolutions[
            ar_idx
        ]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        views = self._get_views(idx, resolution, self._rng, nview)
        assert len(views) == nview

        if "camera_pose" not in views[0]:
            # identity, not all-ones: this pose is inverted in get_ray_map and an
            # all-ones matrix is singular
            views[0]["camera_pose"] = np.eye(4, dtype=np.float32)
        first_view_camera_pose = views[0]["camera_pose"]
        transform = SeqColorJitter() if self.is_seq_color_jitter else self.transform

        for v, view in enumerate(views):
            assert (
                "pts3d" not in view
            ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view["idx"] = (idx, ar_idx, v)

            # encode the image
            width, height = view["img"].size
            view["true_shape"] = np.int32((height, width))
            view["img"] = transform(view["img"])
            # negative depth values mark sky pixels
            view["sky_mask"] = view["depthmap"] < 0

            assert "camera_intrinsics" in view
            if "camera_pose" not in view:
                view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(
                    view["camera_pose"]
                ).all(), f"NaN in camera pose for view {view_name(view)}"

            ray_map = get_ray_map(
                first_view_camera_pose,
                view["camera_pose"],
                view["camera_intrinsics"],
                height,
                width,
            )
            view["ray_map"] = ray_map.astype(np.float32)

            assert "valid_mask" not in view
            assert np.isfinite(
                view["depthmap"]
            ).all(), f"NaN in depthmap for view {view_name(view)}"
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

            view["pts3d"] = pts3d
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

        # n_corres may be "all", an int, or a per-view list; comparing "all" or a
        # list with `> 0` directly would raise TypeError
        n_corres_per_view = [self._n_corres_for_view(v) for v in range(nview)]
        if any(n == "all" or n > 0 for n in n_corres_per_view):
            ref_view = views[0]
            for view, n_corres in zip(views, n_corres_per_view):
                corres1, corres2, valid = extract_correspondences_from_pts3d(
                    ref_view, view, n_corres, self._rng, nneg=self.nneg
                )
                view["corres"] = (corres1, corres2)
                view["valid_corres"] = valid

        # last thing done!
        for view in views:
            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
        return views

    def _set_resolutions(self, resolutions):
        """Normalize ``resolutions`` into ``self._resolutions``, a list of (width, height) int pairs."""
        assert resolutions is not None, "undefined resolution"

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(
                width, int
            ), f"Bad type for {width=} {type(width)=}, should be int"
            assert isinstance(
                height, int
            ), f"Bad type for {height=} {type(height)=}, should be int"
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(
        self, image, depthmap, intrinsics, resolution, rng=None, info=None
    ):
        """Crop around the principal point, rescale, then crop to ``resolution``.

        Steps:
        - crop the largest window centred on the principal point (cx, cy);
        - rescale with ``cropping.rescale_image_depthmap`` to ``resolution``. When
          ``aug_crop > 1``, a random integer offset is added to the target first.
          The original notes describe this step as high-quality Lanczos
          down-scaling, preferred over bilinear interpolation for downsizing;
        - crop to exactly ``resolution``, updating the intrinsics.

        Args:
            image: PIL image, or an array accepted by ``PIL.Image.fromarray``.
            depthmap: depth map aligned with ``image``.
            intrinsics: 3x3 camera matrix of ``image``.
            resolution: target (width, height).
            rng: numpy Generator; needed when ``aug_crop > 1`` and not ``seq_aug_crop``.
            info: only used in assertion messages.

        Returns:
            (image, depthmap, intrinsics) after cropping and resizing.

        Raises:
            AssertionError: if the principal point lies within W/5 (or H/5) of the border.
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        assert min_margin_x > W / 5, f"Bad principal point in view={info}"
        assert min_margin_y > H / 5, f"Bad principal point in view={info}"
        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )

        # rescale towards the (optionally enlarged) target resolution
        target_resolution = np.array(resolution)
        if self.aug_crop > 1:
            target_resolution += (
                rng.integers(0, self.aug_crop)
                if not self.seq_aug_crop
                else self.delta_target_resolution
            )
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(
            image, depthmap, intrinsics, target_resolution
        )

        # actual cropping (if necessary) with bilinear interpolation
        intrinsics2 = cropping.camera_matrix_of_crop(
            intrinsics, image.size, resolution, offset_factor=0.5
        )
        crop_bbox = cropping.bbox_from_intrinsics_in_out(
            intrinsics, intrinsics2, resolution
        )
        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )
        return image, depthmap, intrinsics2
def is_good_type(key, v):
    """Check that a view entry has an accepted type.

    Args:
        key: name of the entry (not used in the check; kept for API compatibility).
        v: the value stored under ``key``.

    Returns:
        ``(True, None)`` if ``v`` is a ``str``/``int``/``tuple``, or has a ``dtype``
        among float32 (numpy or torch), bool, int32, int64 and uint8. Otherwise
        ``(False, err_msg)``. Values without a ``dtype`` (e.g. float, None, list)
        are reported as bad instead of raising ``AttributeError``.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    dtype = getattr(v, "dtype", None)
    if dtype is None:
        return False, f"bad {type(v)=}"
    if dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
        return False, f"bad {v.dtype=}"
    return True, None
def view_name(view, batch_index=None):
    """Return a ``"dataset/label/instance"`` identifier for ``view``.

    When ``batch_index`` is given (anything other than ``None`` or ``slice(None)``),
    each of the three fields is indexed with it first. This names one element of a
    batched view.
    """
    fields = [view["dataset"], view["label"], view["instance"]]
    if batch_index not in (None, slice(None)):
        fields = [field[batch_index] for field in fields]
    db, label, instance = fields
    return f"{db}/{label}/{instance}"
def transpose_to_landscape(view):
    """Convert a portrait view to landscape in place by swapping its x and y axes.

    Only acts when ``view["true_shape"]`` = (height, width) has width < height;
    other views are left untouched. The changes are:

    - ``img`` is transposed as channel-first (3, H, W);
    - ``valid_mask``, ``depthmap`` and ``sky_mask`` are transposed as (H, W);
    - ``pts3d`` is transposed as (H, W, 3) and ``ray_map`` as (H, W, 6);
    - the first two rows of ``camera_intrinsics`` are swapped;
    - if present, the two coordinates of each ``corres`` array are swapped.

    ``true_shape`` itself is not modified.
    """
    height, width = view["true_shape"]

    if width >= height:
        return

    # rectify portrait to landscape
    assert view["img"].shape == (3, height, width)
    view["img"] = view["img"].swapaxes(1, 2)

    assert view["valid_mask"].shape == (height, width)
    view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)

    assert view["depthmap"].shape == (height, width)
    view["depthmap"] = view["depthmap"].swapaxes(0, 1)

    assert view["pts3d"].shape == (height, width, 3)
    view["pts3d"] = view["pts3d"].swapaxes(0, 1)

    # transpose x and y pixels
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]

    assert view["ray_map"].shape == (height, width, 6)
    view["ray_map"] = view["ray_map"].swapaxes(0, 1)

    assert view["sky_mask"].shape == (height, width)
    view["sky_mask"] = view["sky_mask"].swapaxes(0, 1)

    if "corres" in view:
        # transpose correspondences x and y. Tuples (as built by
        # BaseMultiViewDataset.__getitem__) cannot be item-assigned, so rebuild
        # them; mutable containers are updated in place as before.
        corres = view["corres"]
        if isinstance(corres, tuple):
            view["corres"] = (
                corres[0][:, [1, 0]],
                corres[1][:, [1, 0]],
                *corres[2:],
            )
        else:
            corres[0] = corres[0][:, [1, 0]]
            corres[1] = corres[1][:, [1, 0]]