# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Base class for colmap / kapture
# --------------------------------------------------------
import os
import numpy as np
from tqdm import tqdm
import collections
import pickle
import PIL.Image
import torch
from scipy.spatial.transform import Rotation
import torchvision.transforms as tvf

from kapture.core import CameraType
from kapture.io.csv import kapture_from_dir
from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file

from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d
from dust3r_visloc.datasets.base_dataset import BaseVislocDataset

from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import colmap_to_opencv_intrinsics

KaptureSensor = collections.namedtuple('Sensor', 'sensor_params camera_params')

def kapture_to_opencv_intrinsics(sensor):
    """
    Convert from Kapture to OpenCV parameters.
    Warning: we assume that the camera and pixel coordinates follow Colmap conventions here.

    Args:
        sensor: Kapture sensor
    """
    sensor_type = sensor.sensor_params[0]
    if sensor_type == "SIMPLE_PINHOLE":
        # Simple pinhole model.
        # We still call OpenCV undistortion however, for code simplicity.
        w, h, f, cx, cy = sensor.camera_params
        k1 = 0
        k2 = 0
        p1 = 0
        p2 = 0
        fx = fy = f
    elif sensor_type == "PINHOLE":
        w, h, fx, fy, cx, cy = sensor.camera_params
        k1 = 0
        k2 = 0
        p1 = 0
        p2 = 0
    elif sensor_type == "SIMPLE_RADIAL":
        w, h, f, cx, cy, k1 = sensor.camera_params
        k2 = 0
        p1 = 0
        p2 = 0
        fx = fy = f
    elif sensor_type == "RADIAL":
        w, h, f, cx, cy, k1, k2 = sensor.camera_params
        p1 = 0
        p2 = 0
        fx = fy = f
    elif sensor_type == "OPENCV":
        w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params
    else:
        raise NotImplementedError(f"Sensor type {sensor_type} is not supported yet.")

    cameraMatrix = np.asarray([[fx, 0, cx],
                               [0, fy, cy],
                               [0, 0, 1]], dtype=np.float32)

    # We assume that Kapture data comes from Colmap: the origin is different.
    cameraMatrix = colmap_to_opencv_intrinsics(cameraMatrix)

    distCoeffs = np.asarray([k1, k2, p1, p2], dtype=np.float32)
    return cameraMatrix, distCoeffs, (w, h)

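# Illustrative check (assumed values, not from the original file): a 640x480
# SIMPLE_PINHOLE sensor with f=500 and the principal point at the image center.
# colmap_to_opencv_intrinsics shifts the principal point by half a pixel, since
# the center of the top-left pixel is (0.5, 0.5) in Colmap but (0, 0) in OpenCV.
# >>> s = KaptureSensor(("SIMPLE_PINHOLE",), (640, 480, 500.0, 320.0, 240.0))
# >>> K, dist, (w, h) = kapture_to_opencv_intrinsics(s)
# >>> float(K[0, 2]), float(K[1, 2])
# (319.5, 239.5)
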
def K_from_colmap(elems):
    sensor = KaptureSensor(elems, tuple(map(float, elems[1:])))
    cameraMatrix, distCoeffs, (w, h) = kapture_to_opencv_intrinsics(sensor)
    res = dict(resolution=(w, h),
               intrinsics=cameraMatrix,
               distortion=distCoeffs)
    return res

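# Example (hypothetical values): each line of Colmap's cameras.txt reads
# "CAMERA_ID MODEL WIDTH HEIGHT PARAMS[]", and K_from_colmap receives everything
# after CAMERA_ID as strings:
# >>> K_from_colmap(['SIMPLE_RADIAL', '640', '480', '500', '320', '240', '0.01'])
# {'resolution': (640.0, 480.0), 'intrinsics': array(...), 'distortion': array(...)}
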
def pose_from_qwxyz_txyz(elems):
    qw, qx, qy, qz, tx, ty, tz = map(float, elems)
    pose = np.eye(4)
    pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix()
    pose[:3, 3] = (tx, ty, tz)
    return np.linalg.inv(pose)  # returns cam2world

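# Convention note: Colmap's images.txt stores the world-to-camera rotation as a
# quaternion in (qw, qx, qy, qz) order, whereas scipy's Rotation.from_quat expects
# scalar-last (x, y, z, w), hence the reordering above; the final np.linalg.inv
# converts the world2cam pose into the cam2world matrix used everywhere below.
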
class BaseVislocColmapDataset(BaseVislocDataset):
    def __init__(self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False):
        super().__init__()
        self.topk = topk
        self.num_views = self.topk + 1
        self.image_path = image_path
        self.cache_sfm = cache_sfm

        self._load_sfm(map_path)

        kdata_query = kapture_from_dir(query_path)
        assert kdata_query.records_camera is not None and kdata_query.trajectories is not None
        kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)
                                   for timestamp, sensor_id in kdata_query.records_camera.key_pairs()}
        self.query_data = {'kdata': kdata_query, 'searchindex': kdata_query_searchindex}

        self.pairs = get_ordered_pairs_from_file(pairsfile_path)
        self.scenes = kdata_query.records_camera.data_list()

    def _load_sfm(self, sfm_dir):
        sfm_cache_path = os.path.join(sfm_dir, 'dust3r_cache.pkl')
        if os.path.isfile(sfm_cache_path) and self.cache_sfm:
            with open(sfm_cache_path, "rb") as f:
                data = pickle.load(f)
            self.img_infos = data['img_infos']
            self.points3D = data['points3D']
            return

        # load cameras
        with open(os.path.join(sfm_dir, 'cameras.txt'), 'r') as f:
            raw = f.read().splitlines()[3:]  # skip header

        intrinsics = {}
        for camera in tqdm(raw):
            camera = camera.split(' ')
            intrinsics[int(camera[0])] = K_from_colmap(camera[1:])

        # load images
        with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f:
            raw = f.read().splitlines()
        raw = [line for line in raw if not line.startswith('#')]  # skip header

        self.img_infos = {}
        for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2):
            image = image.split(' ')
            points = points.split(' ')

            img_name = image[-1]
            current_points2D = {int(i): (float(x), float(y))
                                for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'}
            self.img_infos[img_name] = dict(intrinsics[int(image[-2])],
                                            path=img_name,
                                            camera_pose=pose_from_qwxyz_txyz(image[1:-2]),
                                            sparse_pts2d=current_points2D)

        # load 3D points
        with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f:
            raw = f.read().splitlines()
        raw = [line for line in raw if not line.startswith('#')]  # skip header

        self.points3D = {}
        for point in tqdm(raw):
            point = point.split()
            self.points3D[int(point[0])] = tuple(map(float, point[1:4]))

        if self.cache_sfm:
            to_save = {
                'img_infos': self.img_infos,
                'points3D': self.points3D
            }
            with open(sfm_cache_path, "wb") as f:
                pickle.dump(to_save, f)

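    # Format reminder (Colmap text export): images.txt alternates two lines per image,
    # "IMAGE_ID QW QX QY QZ TX TY TZ CAMERA_ID NAME" followed by
    # "X Y POINT3D_ID X Y POINT3D_ID ...", which is why the loop above pairs
    # raw[0::2] with raw[1::2]; a POINT3D_ID of -1 marks a keypoint that was never
    # triangulated and is skipped.
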
    def __len__(self):
        return len(self.scenes)

    def _get_view_query(self, imgname):
        kdata, searchindex = map(self.query_data.get, ['kdata', 'searchindex'])

        timestamp, camera_id = searchindex[imgname]

        camera_params = kdata.sensors[camera_id].camera_params
        if kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_PINHOLE:
            W, H, f, cx, cy = camera_params
            k1 = 0
            fx = fy = f
        elif kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_RADIAL:
            W, H, f, cx, cy, k1 = camera_params
            fx = fy = f
        else:
            raise NotImplementedError(f'Camera type {kdata.sensors[camera_id].camera_type} is not supported yet.')
        W, H = int(W), int(H)

        intrinsics = np.float32([(fx, 0, cx),
                                 (0, fy, cy),
                                 (0, 0, 1)])
        intrinsics = colmap_to_opencv_intrinsics(intrinsics)
        distortion = [k1, 0, 0, 0]

        if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories:
            cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)
        else:
            cam_to_world = np.eye(4, dtype=np.float32)

        # Load RGB image
        rgb_image = PIL.Image.open(os.path.join(self.image_path, imgname)).convert('RGB')
        rgb_image.load()
        resize_func, _, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)
        rgb_tensor = resize_func(ImgNorm(rgb_image))

        view = {
            'intrinsics': intrinsics,
            'distortion': distortion,
            'cam_to_world': cam_to_world,
            'rgb': rgb_image,
            'rgb_rescaled': rgb_tensor,
            'to_orig': to_orig,
            'idx': 0,
            'image_name': imgname
        }
        return view

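    # Note: query views deliberately carry no ground-truth 3D points; only map views
    # (built in _get_view_map below) attach the sparse 2D->3D correspondences that
    # come from the Colmap reconstruction.
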
    def _get_view_map(self, imgname, idx):
        infos = self.img_infos[imgname]

        rgb_image = PIL.Image.open(os.path.join(self.image_path, infos['path'])).convert('RGB')
        rgb_image.load()
        W, H = rgb_image.size

        intrinsics = infos['intrinsics']
        intrinsics = colmap_to_opencv_intrinsics(intrinsics)
        distortion_coefs = infos['distortion']

        pts2d = infos['sparse_pts2d']
        sparse_pos2d = np.float32(list(pts2d.values())).reshape((-1, 2))  # pts2d from colmap
        sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]).reshape((-1, 3))

        # store full resolution 2D->3D
        sparse_pos2d_cv2 = sparse_pos2d.copy()
        sparse_pos2d_cv2[:, 0] -= 0.5
        sparse_pos2d_cv2[:, 1] -= 0.5
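        # The half-pixel shift above mirrors colmap_to_opencv_intrinsics: Colmap places
        # the center of the top-left pixel at (0.5, 0.5), OpenCV at (0, 0), so keypoints
        # must be shifted before being rounded to integer array indices below.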
        sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64)
        valid = (sparse_pos2d_int[:, 0] >= 0) & (sparse_pos2d_int[:, 0] < W) & (
            sparse_pos2d_int[:, 1] >= 0) & (sparse_pos2d_int[:, 1] < H)
        sparse_pos2d_int = sparse_pos2d_int[valid]
        # nan => invalid
        pts3d = np.full((H, W, 3), np.nan, dtype=np.float32)
        pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid]
        pts3d = torch.from_numpy(pts3d)

        cam_to_world = infos['camera_pose']  # cam2world

        # also store resized resolution 2D->3D
        resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)
        rgb_tensor = resize_func(ImgNorm(rgb_image))

        HR, WR = rgb_tensor.shape[1:]
        _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR)
        pts3d_rescaled = torch.from_numpy(pts3d_rescaled)
        valid_rescaled = torch.from_numpy(valid_rescaled)

        view = {
            'intrinsics': intrinsics,
            'distortion': distortion_coefs,
            'cam_to_world': cam_to_world,
            'rgb': rgb_image,
            'pts3d': pts3d,
            'valid': pts3d.sum(dim=-1).isfinite(),
            'rgb_rescaled': rgb_tensor,
            'pts3d_rescaled': pts3d_rescaled,
            'valid_rescaled': valid_rescaled,
            'to_orig': to_orig,
            'idx': idx,
            'image_name': imgname
        }
        return view

    def __getitem__(self, idx):
        assert self.maxdim is not None and self.patch_size is not None
        query_image = self.scenes[idx]
        map_images = [p[0] for p in self.pairs[query_image][:self.topk]]
        views = []
        views.append(self._get_view_query(query_image))
        for rank, map_image in enumerate(map_images):  # avoid shadowing the dataset index
            views.append(self._get_view_map(map_image, rank + 1))
        return views

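
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the directory layout below is a
    # hypothetical example, and maxdim / patch_size must match the model that will
    # consume the views (this file only assumes BaseVislocDataset leaves them unset).
    dataset = BaseVislocColmapDataset(image_path='data/scene/images',
                                      map_path='data/scene/sfm',  # cameras.txt / images.txt / points3D.txt
                                      query_path='data/scene/kapture/query',
                                      pairsfile_path='data/scene/pairs.txt',
                                      topk=1)
    dataset.maxdim = 512     # assumed attribute inherited from BaseVislocDataset
    dataset.patch_size = 16  # assumed attribute inherited from BaseVislocDataset
    views = dataset[0]       # [query_view, map_view_1, ..., map_view_topk]
    print(len(views), views[0]['image_name'])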