|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
from tqdm import tqdm |
|
import collections |
|
import pickle |
|
import PIL.Image |
|
import torch |
|
from scipy.spatial.transform import Rotation |
|
import torchvision.transforms as tvf |
|
|
|
from kapture.core import CameraType |
|
from kapture.io.csv import kapture_from_dir |
|
from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file |
|
|
|
from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d |
|
from dust3r_visloc.datasets.base_dataset import BaseVislocDataset |
|
from dust3r.datasets.utils.transforms import ImgNorm |
|
from dust3r.utils.geometry import colmap_to_opencv_intrinsics |
|
|
|
# Lightweight stand-in for a Kapture sensor record:
#   sensor_params: sequence whose first item is the camera model name,
#   camera_params: numeric values (width, height, then the model parameters).
KaptureSensor = collections.namedtuple('Sensor', ['sensor_params', 'camera_params'])
|
|
|
|
|
def kapture_to_opencv_intrinsics(sensor):
    """
    Build OpenCV-style camera parameters from a Kapture sensor description.
    Warning: the camera and pixel coordinates are assumed to follow Colmap conventions.

    Args:
        sensor: Kapture sensor. ``sensor.sensor_params[0]`` names the camera model and
            ``sensor.camera_params`` holds width, height and the model parameters.

    Returns:
        Tuple ``(cameraMatrix, distCoeffs, (w, h))`` where ``cameraMatrix`` is the 3x3 matrix
        returned by ``colmap_to_opencv_intrinsics`` and ``distCoeffs`` is the float32 array
        ``[k1, k2, p1, p2]``.

    Raises:
        NotImplementedError: if the camera model is not one of SIMPLE_PINHOLE, PINHOLE,
            SIMPLE_RADIAL, RADIAL or OPENCV.
    """
    sensor_type = sensor.sensor_params[0]

    # Distortion terms that a model does not define stay at zero, so every model
    # can be expressed with the same four coefficients.
    k1 = k2 = p1 = p2 = 0

    if sensor_type == "OPENCV":
        w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params
    elif sensor_type == "PINHOLE":
        w, h, fx, fy, cx, cy = sensor.camera_params
    elif sensor_type == "RADIAL":
        # single focal length shared by both axes
        w, h, fx, cx, cy, k1, k2 = sensor.camera_params
        fy = fx
    elif sensor_type == "SIMPLE_RADIAL":
        w, h, fx, cx, cy, k1 = sensor.camera_params
        fy = fx
    elif sensor_type == "SIMPLE_PINHOLE":
        w, h, fx, cx, cy = sensor.camera_params
        fy = fx
    else:
        raise NotImplementedError(f"Sensor type {sensor_type} is not supported yet.")

    colmap_matrix = np.array([[fx, 0, cx],
                              [0, fy, cy],
                              [0, 0, 1]], dtype=np.float32)
    # The parameters are assumed to come from Colmap, whose pixel origin differs from OpenCV's.
    camera_matrix = colmap_to_opencv_intrinsics(colmap_matrix)

    dist_coeffs = np.array([k1, k2, p1, p2], dtype=np.float32)
    return camera_matrix, dist_coeffs, (w, h)
|
|
|
|
|
def K_from_colmap(elems):
    """
    Turn the fields of one camera entry into a dict of camera parameters.

    Args:
        elems: sequence whose first item is the camera model name and whose remaining
            items (width, height, model parameters) can be converted with ``float``.

    Returns:
        dict with keys ``resolution`` (``(w, h)``), ``intrinsics`` and ``distortion``,
        as produced by ``kapture_to_opencv_intrinsics``.
    """
    numeric_params = tuple(float(value) for value in elems[1:])
    sensor = KaptureSensor(elems, numeric_params)
    intrinsics, distortion, (width, height) = kapture_to_opencv_intrinsics(sensor)
    return {
        'resolution': (width, height),
        'intrinsics': intrinsics,
        'distortion': distortion,
    }
|
|
|
|
|
def pose_from_qwxyz_txyz(elems):
    """
    Build a 4x4 pose from seven values ``qw, qx, qy, qz, tx, ty, tz``.

    The quaternion and translation are assembled into a rigid transform and the
    inverse of that transform is returned.

    Args:
        elems: iterable of exactly seven items convertible with ``float``.

    Returns:
        4x4 numpy array, inverse of the transform described by ``elems``.
    """
    qw, qx, qy, qz, tx, ty, tz = (float(value) for value in elems)

    # scipy expects scalar-last quaternions, hence the reordering
    rotation = Rotation.from_quat([qx, qy, qz, qw])

    transform = np.eye(4)
    transform[:3, 3] = [tx, ty, tz]
    transform[:3, :3] = rotation.as_matrix()
    return np.linalg.inv(transform)
|
|
|
|
|
class BaseVislocColmapDataset(BaseVislocDataset):
    """
    Visual-localization dataset whose map is a text model (cameras.txt, images.txt,
    points3D.txt) and whose queries are described by a kapture directory.

    Each item is a list of views: the query view first, followed by up to ``topk``
    map views taken from the pairs file.

    ``self.maxdim`` and ``self.patch_size`` are read but never set in this class;
    presumably they are provided by ``BaseVislocDataset`` or by the caller before
    indexing (``__getitem__`` asserts that both are not None).
    """

    def __init__(self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False):
        """
        Args:
            image_path: directory that map and query image names are joined to.
            map_path: directory holding cameras.txt, images.txt and points3D.txt.
            query_path: kapture directory describing the query images.
            pairsfile_path: pairs file read with ``get_ordered_pairs_from_file``.
            topk: number of map images kept for each query.
            cache_sfm: when True, the parsed map is read from / written to
                ``dust3r_cache.pkl`` inside ``map_path``.
        """
        super().__init__()
        self.topk = topk
        self.num_views = self.topk + 1
        self.image_path = image_path
        self.cache_sfm = cache_sfm

        self._load_sfm(map_path)

        kdata_query = kapture_from_dir(query_path)
        assert kdata_query.records_camera is not None and kdata_query.trajectories is not None

        # image name -> (timestamp, sensor_id), to find a query record from its file name
        kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)
                                   for timestamp, sensor_id in kdata_query.records_camera.key_pairs()}
        self.query_data = {'kdata': kdata_query, 'searchindex': kdata_query_searchindex}

        self.pairs = get_ordered_pairs_from_file(pairsfile_path)
        self.scenes = kdata_query.records_camera.data_list()

    def _load_rgb(self, relative_path):
        """Read an image located under ``self.image_path`` and return it as an RGB PIL image."""
        # convert() decodes the pixels into a new in-memory image, so the file
        # handle can be released as soon as the `with` block exits.
        with PIL.Image.open(os.path.join(self.image_path, relative_path)) as image_file:
            return image_file.convert('RGB')

    def _load_sfm(self, sfm_dir):
        """
        Fill ``self.img_infos`` and ``self.points3D`` from the text model in ``sfm_dir``.

        ``self.img_infos`` maps an image name to a dict with the camera entries returned by
        ``K_from_colmap`` plus ``path``, ``camera_pose`` and ``sparse_pts2d`` (3D point id ->
        (x, y)). ``self.points3D`` maps a 3D point id to its (x, y, z) tuple.

        When ``self.cache_sfm`` is set, the parsed result is loaded from, or saved to,
        ``dust3r_cache.pkl`` in ``sfm_dir``.
        """
        sfm_cache_path = os.path.join(sfm_dir, 'dust3r_cache.pkl')
        if self.cache_sfm and os.path.isfile(sfm_cache_path):
            # NOTE(review): pickle.load can execute arbitrary code contained in the file;
            # only enable cache_sfm for map directories you trust.
            with open(sfm_cache_path, "rb") as f:
                data = pickle.load(f)
            self.img_infos = data['img_infos']
            self.points3D = data['points3D']
            return

        # load cameras: skip comment and blank lines, like the other two files skip
        # comments, instead of assuming the header is exactly three lines long
        with open(os.path.join(sfm_dir, 'cameras.txt'), 'r') as f:
            camera_lines = [line for line in f.read().splitlines()
                            if line.strip() and not line.lstrip().startswith('#')]

        intrinsics = {}
        for camera_line in tqdm(camera_lines):
            camera_fields = camera_line.split()
            # fields: camera id, model name, width, height, model parameters
            intrinsics[int(camera_fields[0])] = K_from_colmap(camera_fields[1:])

        # load images: entries come as pairs of lines (image description, 2D observations).
        # Blank lines must be kept here: an image without observations has an empty second line.
        with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f:
            image_lines = [line for line in f.read().splitlines() if not line.startswith('#')]

        self.img_infos = {}
        for image, points in tqdm(zip(image_lines[0::2], image_lines[1::2]), total=len(image_lines) // 2):
            image = image.split(' ')
            points = points.split(' ')

            img_name = image[-1]
            # observations are (x, y, point3D id) triplets; id -1 marks a point without 3D match
            current_points2D = {int(i): (float(x), float(y))
                                for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'}
            self.img_infos[img_name] = dict(intrinsics[int(image[-2])],
                                            path=img_name,
                                            camera_pose=pose_from_qwxyz_txyz(image[1: -2]),
                                            sparse_pts2d=current_points2D)

        # load 3D points
        with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f:
            point_lines = [line for line in f.read().splitlines() if not line.startswith('#')]

        self.points3D = {}
        for point_line in tqdm(point_lines):
            point_fields = point_line.split()
            self.points3D[int(point_fields[0])] = tuple(map(float, point_fields[1:4]))

        if self.cache_sfm:
            to_save = {
                'img_infos': self.img_infos,
                'points3D': self.points3D,
            }
            with open(sfm_cache_path, "wb") as f:
                pickle.dump(to_save, f)

    def __len__(self):
        """Number of query images."""
        return len(self.scenes)

    def _get_view_query(self, imgname):
        """
        Build the view dict of a query image.

        Args:
            imgname: image name as stored in the kapture camera records.

        Returns:
            dict with keys ``intrinsics``, ``distortion`` (list ``[k1, 0, 0, 0]``),
            ``cam_to_world`` (identity when the record has no trajectory entry), ``rgb``,
            ``rgb_rescaled``, ``to_orig``, ``idx`` (always 0) and ``image_name``.

        Raises:
            NotImplementedError: if the camera is neither SIMPLE_PINHOLE nor SIMPLE_RADIAL.
        """
        kdata = self.query_data['kdata']
        searchindex = self.query_data['searchindex']

        timestamp, camera_id = searchindex[imgname]

        sensor = kdata.sensors[camera_id]
        camera_params = sensor.camera_params
        if sensor.camera_type == CameraType.SIMPLE_PINHOLE:
            W, H, f, cx, cy = camera_params
            k1 = 0
            fx = fy = f
        elif sensor.camera_type == CameraType.SIMPLE_RADIAL:
            W, H, f, cx, cy, k1 = camera_params
            fx = fy = f
        else:
            raise NotImplementedError('not implemented')

        W, H = int(W), int(H)
        intrinsics = np.float32([(fx, 0, cx),
                                 (0, fy, cy),
                                 (0, 0, 1)])
        intrinsics = colmap_to_opencv_intrinsics(intrinsics)
        distortion = [k1, 0, 0, 0]

        if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories:
            cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)
        else:
            cam_to_world = np.eye(4, dtype=np.float32)

        rgb_image = self._load_rgb(imgname)
        # the resize is computed from the sensor resolution, not from the decoded image size
        resize_func, _, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)
        rgb_tensor = resize_func(ImgNorm(rgb_image))

        view = {
            'intrinsics': intrinsics,
            'distortion': distortion,
            'cam_to_world': cam_to_world,
            'rgb': rgb_image,
            'rgb_rescaled': rgb_tensor,
            'to_orig': to_orig,
            'idx': 0,
            'image_name': imgname
        }
        return view

    def _get_view_map(self, imgname, idx):
        """
        Build the view dict of a map image, including its sparse 3D points.

        Args:
            imgname: key of the image in ``self.img_infos``.
            idx: value stored under ``idx`` in the returned view.

        Returns:
            dict with the same keys as a query view plus ``pts3d`` (H x W x 3 tensor, NaN
            where no 3D point projects), ``valid`` (mask of finite entries of ``pts3d``),
            ``pts3d_rescaled`` and ``valid_rescaled`` (outputs of ``rescale_points3d``).
        """
        infos = self.img_infos[imgname]

        rgb_image = self._load_rgb(infos['path'])
        W, H = rgb_image.size

        # K_from_colmap already passed this matrix through colmap_to_opencv_intrinsics,
        # so it must not be converted a second time here (the query view converts exactly
        # once). Copy it so that callers cannot modify the entry stored in self.img_infos.
        intrinsics = np.array(infos['intrinsics'])
        distortion_coefs = infos['distortion']

        pts2d = infos['sparse_pts2d']
        sparse_pos2d = np.float32(list(pts2d.values())).reshape((-1, 2))
        sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]).reshape((-1, 3))

        # shift the observations by half a pixel (Colmap -> OpenCV pixel convention),
        # then keep only those that round to a pixel inside the image
        sparse_pos2d_cv2 = sparse_pos2d - 0.5
        sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64)
        valid = (sparse_pos2d_int[:, 0] >= 0) & (sparse_pos2d_int[:, 0] < W) & (
            sparse_pos2d_int[:, 1] >= 0) & (sparse_pos2d_int[:, 1] < H)
        sparse_pos2d_int = sparse_pos2d_int[valid]

        # dense map of 3D points, NaN wherever no sparse point lands
        pts3d = np.full((H, W, 3), np.nan, dtype=np.float32)
        pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid]
        pts3d = torch.from_numpy(pts3d)

        cam_to_world = infos['camera_pose']

        # also compute the 3D point map at the resolution of the resized image
        resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)
        rgb_tensor = resize_func(ImgNorm(rgb_image))

        HR, WR = rgb_tensor.shape[1:]
        _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR)
        pts3d_rescaled = torch.from_numpy(pts3d_rescaled)
        valid_rescaled = torch.from_numpy(valid_rescaled)

        view = {
            'intrinsics': intrinsics,
            'distortion': distortion_coefs,
            'cam_to_world': cam_to_world,
            'rgb': rgb_image,
            "pts3d": pts3d,
            "valid": pts3d.sum(dim=-1).isfinite(),
            'rgb_rescaled': rgb_tensor,
            "pts3d_rescaled": pts3d_rescaled,
            "valid_rescaled": valid_rescaled,
            'to_orig': to_orig,
            'idx': idx,
            'image_name': imgname
        }
        return view

    def __getitem__(self, idx):
        """
        Return the views of the ``idx``-th query: the query view (``idx`` 0) followed by
        its first ``self.topk`` map images from the pairs file (``idx`` 1, 2, ...).
        """
        assert self.maxdim is not None and self.patch_size is not None
        query_image = self.scenes[idx]
        map_images = [p[0] for p in self.pairs[query_image][:self.topk]]

        views = [self._get_view_query(query_image)]
        # a separate name keeps the `idx` argument from being overwritten by the loop
        for view_idx, map_image in enumerate(map_images, start=1):
            views.append(self._get_view_map(map_image, view_idx))
        return views
|
|