"""Utilities for loading images, exporting cameras / poses / point clouds via
the COLMAP writer helpers, and computing co-visibility masks between views."""

import collections
import math
import os
import re
import time
from pathlib import Path
from typing import List, NamedTuple, Tuple

import cv2
import numpy as np
import open3d as o3d
import PIL.Image
import roma
import scipy
import scipy.linalg  # `import scipy` alone does not guarantee the submodule is loaded
import torch
import torchvision.transforms as tvf
import torchvision.transforms.functional as tf
from PIL.ImageOps import exif_transpose
from plyfile import PlyData, PlyElement
from tqdm import tqdm

from dust3r.utils.device import to_numpy
from dust3r.utils.image import _resize_pil_image
from field_construction.scene.colmap_loader import (qvec2rotmat,
                                                    read_extrinsics_binary,
                                                    rotmat2qvec,
                                                    write_cameras_binary,
                                                    write_cameras_text,
                                                    write_images_binary,
                                                    write_images_text)

try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    heif_support_enabled = True
except ImportError:
    heif_support_enabled = False

# Converts a PIL image to a tensor and maps each channel from [0, 1] to [-1, 1].
ImgNorm = tvf.Compose([
    tvf.ToTensor(),
    tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


def save_time(time_dir, process_name, sub_time):
    """Append the duration of one processing stage to ``train_time.txt``.

    Args:
        time_dir: Directory (str or Path) holding ``train_time.txt``; it is
            created when missing.
        process_name: Label written in front of the duration.
        sub_time: Elapsed time in seconds.
    """
    time_dir = Path(time_dir)
    time_dir.mkdir(parents=True, exist_ok=True)
    minutes, seconds = divmod(sub_time, 60)
    formatted_time = f"{int(minutes)} min {int(seconds)} sec"
    with open(time_dir / 'train_time.txt', 'a') as f:
        f.write(f'{process_name}: {formatted_time}\n')


def split_train_test(image_files, llffhold=8, n_views=None, verbose=True):
    """Split an ordered list of image files into train and test subsets.

    Twelve test indices are spread evenly over ``[1, len(image_files) - 2]``;
    the remaining indices form the training pool, from which ``n_views``
    evenly spaced entries are kept.

    Args:
        image_files: Ordered sequence of image file paths.
        llffhold: Unused; kept for interface compatibility.
        n_views: Number of training views to keep. ``None`` or ``0`` keeps
            the whole training pool (same "full views" convention as
            ``init_filestructure``).
        verbose: Print the selected indices.

    Returns:
        ``(train_img_files, test_img_files)``.
    """
    num_files = len(image_files)
    test_idx = np.linspace(1, num_files - 2, num=12, dtype=int)
    test_lookup = set(test_idx.tolist())
    train_idx = [i for i in range(num_files) if i not in test_lookup]

    # Previously the default n_views=None crashed inside np.linspace.
    if n_views:
        sparse_idx = np.linspace(0, len(train_idx) - 1, num=n_views, dtype=int)
        train_idx = [train_idx[i] for i in sparse_idx]

    if verbose:
        print(">> Spliting Train-Test Set: ")
        print(" - train_set_indices: ", train_idx)
        print(" - test_set_indices: ", test_idx)

    train_img_files = [image_files[i] for i in train_idx]
    test_img_files = [image_files[i] for i in test_idx]
    return train_img_files, test_img_files


def get_sorted_image_files(image_dir: str) -> Tuple[List[str], str]:
    """
    Get sorted image files from the given directory.

    Files are ordered by the first integer found in their stem; files without
    any digit go last. Ties are broken by path so the order does not depend
    on the directory listing order.

    Args:
        image_dir (str): Path to the directory containing images.

    Returns:
        Tuple[List[str], str]: A tuple containing:
            - List of sorted image file paths
            - File suffix (including the dot) of the first sorted file

    Raises:
        IndexError: If the directory contains no supported image file.
    """
    # Suffixes are lower-cased before the lookup, so only lower-case entries are needed.
    allowed_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff'}

    def sort_key(file_path):
        match = re.search(r'\d+', Path(file_path).stem)
        number = int(match.group()) if match else float('inf')
        return number, file_path

    image_files = [
        str(f) for f in Path(image_dir).iterdir()
        if f.is_file() and f.suffix.lower() in allowed_extensions
    ]
    if not image_files:
        # Same exception type the bare `suffixes[0]` lookup used to raise, with a useful message.
        raise IndexError(f'no image files found in {image_dir}')

    sorted_files = sorted(image_files, key=sort_key)
    return sorted_files, Path(sorted_files[0]).suffix


def rigid_points_registration(pts1, pts2, conf=None):
    """Estimate a similarity transform between two point sets with roma.

    Both inputs are flattened to ``(-1, 3)`` and ``conf`` is forwarded as the
    per-point weights.

    Returns:
        ``(s, R, T)`` -- note the order differs from roma's ``(R, T, s)``.
    """
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf, compute_scaling=True)
    return s, R, T  # return un-scaled (R, T)


def init_filestructure(save_path, n_views=None):
    """Create the output directory layout for a reconstruction run.

    Args:
        save_path: Root output directory (str or Path).
        n_views: Number of views; ``None`` or ``0`` selects the full-view
            layout under ``sparse_0``.

    Returns:
        ``(save_path, sparse_0_path, sparse_1_path)`` as ``Path`` objects.
    """
    save_path = Path(save_path)
    if n_views is not None and n_views != 0:
        sparse_root = save_path / f'sparse_{n_views}'
        print(f'>> Doing {n_views} views reconstrution!')
    else:
        sparse_root = save_path / 'sparse_0'
        print('>> Doing full views reconstrution!')

    sparse_0_path = sparse_root / '0'
    sparse_1_path = sparse_root / '1'
    save_path.mkdir(exist_ok=True, parents=True)
    sparse_0_path.mkdir(exist_ok=True, parents=True)
    sparse_1_path.mkdir(exist_ok=True, parents=True)
    return save_path, sparse_0_path, sparse_1_path


def load_images(folder_or_list, size=512, square_ok=False, verbose=True):
    """Open and convert all images in a list or folder to proper input format for DUSt3R.

    Args:
        folder_or_list: Either a directory path (str) whose sorted listing is
            loaded, or a list of image paths (str).
        size: Target size. ``224`` resizes the short side and center-crops a
            square; any other value resizes the long side and crops so that
            both sides are multiples of 16.
        square_ok: When False, square images are cropped to a 4:3 aspect.
        verbose: Print progress information.

    Returns:
        ``(imgs, (W1, H1))`` where ``imgs`` is a list of dicts with keys
        ``img``, ``true_shape``, ``idx`` and ``instance``, and ``(W1, H1)``
        is the original size of the LAST image that was loaded.

    Raises:
        ValueError: If ``folder_or_list`` is neither a str nor a list.
        AssertionError: If no supported image was found.
    """
    if isinstance(folder_or_list, str):
        if verbose:
            print(f'>> Loading images from {folder_or_list}')
        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
    elif isinstance(folder_or_list, list):
        if verbose:
            print(f'>> Loading a list of {len(folder_or_list)} images')
        root, folder_content = '', folder_or_list
    else:
        raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')

    # Names are lower-cased before matching, so only lower-case suffixes are listed.
    supported_images_extensions = ['.jpg', '.jpeg', '.png']
    if heif_support_enabled:
        supported_images_extensions += ['.heic', '.heif']
    supported_images_extensions = tuple(supported_images_extensions)

    imgs = []
    for path in folder_content:
        if not path.lower().endswith(supported_images_extensions):
            continue
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with PIL.Image.open(os.path.join(root, path)) as raw_img:
            img = exif_transpose(raw_img).convert('RGB')
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to `size`
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not square_ok and W == H:
                # halfw is a multiple of 8, so this integer division is exact.
                halfh = 3 * halfw // 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))

        W2, H2 = img.size
        if verbose:
            print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
        imgs.append(dict(img=ImgNorm(img)[None],
                         true_shape=np.int32([img.size[::-1]]),
                         idx=len(imgs),
                         instance=str(len(imgs))))

    assert imgs, 'no images foud at ' + root
    if verbose:
        print(f' (Found {len(imgs)} images)')
    return imgs, (W1, H1)


# Lightweight records passed to the COLMAP writer helpers imported above.
CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple("Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple("Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])

CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
# Lookup tables keyed by numeric id and by model name.
CAMERA_MODEL_IDS = {camera_model.model_id: camera_model for camera_model in CAMERA_MODELS}
CAMERA_MODEL_NAMES = {camera_model.model_name: camera_model for camera_model in CAMERA_MODELS}


def save_extrinsic(sparse_path, extrinsics_w2c, img_files, image_suffix):
    """Write ``images.bin`` and ``images.txt`` from per-image extrinsic matrices.

    Args:
        sparse_path: Output directory (Path).
        extrinsics_w2c: Iterable of matrices; the top-left 3x3 block is used
            as rotation and the first three rows of column 3 as translation.
        img_files: Image paths; only the stem is used for the stored name.
        image_suffix: Suffix appended to each stem to form the image name.
    """
    images_bin_file = sparse_path / 'images.bin'
    images_txt_file = sparse_path / 'images.txt'
    images = {}
    # Image and camera ids start from 1; each image gets its own camera id.
    for i, (w2c, img_file) in enumerate(zip(extrinsics_w2c, img_files), start=1):
        name = Path(img_file).stem + image_suffix
        rotation_matrix = w2c[:3, :3]
        qvec = rotmat2qvec(rotation_matrix)
        tvec = w2c[:3, 3]
        images[i] = BaseImage(
            id=i,
            qvec=qvec,
            tvec=tvec,
            camera_id=i,
            name=name,
            xys=[],  # Empty list as we don't have 2D point information
            point3D_ids=[]  # Empty list as we don't have 3D point IDs
        )
    write_images_binary(images, images_bin_file)
    write_images_text(images, images_txt_file)


def save_intrinsics(sparse_path, focals, org_imgs_shape, imgs_shape, save_focals=False):
    """Write ``cameras.bin`` and ``cameras.txt`` with one PINHOLE camera per focal.

    Focals are rescaled from the processed resolution to the original one and
    the principal point is placed at the center of the original image.

    Args:
        sparse_path: Output directory (Path).
        focals: Iterable of focal lengths at the processed resolution.
        org_imgs_shape: ``(width, height)`` of the original images.
        imgs_shape: Shape of the processed images; index 2 is read as width
            and index 1 as height.
        save_focals: Also store the unscaled focals as ``non_scaled_focals.npy``.
    """
    org_width, org_height = org_imgs_shape
    scale_factor_x = org_width / imgs_shape[2]
    scale_factor_y = org_height / imgs_shape[1]
    cameras_bin_file = sparse_path / 'cameras.bin'
    cameras_txt_file = sparse_path / 'cameras.txt'

    cameras = {}
    for i, focal in enumerate(focals, start=1):  # Start enumeration from 1
        cameras[i] = Camera(
            id=i,
            model="PINHOLE",
            width=org_width,
            height=org_height,
            params=[focal * scale_factor_x, focal * scale_factor_y, org_width / 2, org_height / 2]
        )
        print(f' - scaling focal: ({focal}, {focal}) --> ({focal*scale_factor_x}, {focal*scale_factor_y})')
    write_cameras_binary(cameras, cameras_bin_file)
    write_cameras_text(cameras, cameras_txt_file)
    if save_focals:
        np.save(sparse_path / 'non_scaled_focals.npy', focals)


def save_points3D(sparse_path, imgs, pts3d, confs, masks=None, use_masks=True, save_all_pts=False,
                  save_txt_path=None, depth_threshold=0.1, max_pts_num=150 * 10**10):
    """Export the (optionally masked / downsampled) point cloud and its confidences.

    Writes ``confidence.npy``, ``confidence_dsp.npy`` and ``points3D.ply`` to
    ``sparse_path``; optionally ``points3D_all.npy`` / ``pointsColor_all.npy``
    and a ``pts_num.txt`` statistics file.

    Args:
        sparse_path: Output directory (Path).
        imgs: Per-view colors, multiplied by 255 before export.
        pts3d: Per-view 3D points.
        confs: Per-view confidences.
        masks: Per-view boolean masks selecting the points to keep; required
            when ``use_masks`` is True.
        use_masks: Apply ``masks`` to points, colors and confidences.
        save_all_pts: Also save the unfiltered points and colors.
        save_txt_path: Directory (str or Path) receiving ``pts_num.txt``;
            ``None`` skips the statistics file.
        depth_threshold: Only recorded in ``pts_num.txt``.
        max_pts_num: If more points remain, sample this many without
            replacement, weighted by confidence.

    Returns:
        Number of points written to the PLY file.

    Raises:
        ValueError: If ``use_masks`` is True but ``masks`` is None.
    """
    points3D_ply_file = sparse_path / 'points3D.ply'

    # Convert inputs to numpy arrays
    imgs = to_numpy(imgs)
    pts3d = to_numpy(pts3d)
    confs = to_numpy(confs)
    if confs is not None:
        np.save(sparse_path / 'confidence.npy', confs)

    # Process points and colors
    if use_masks:
        if masks is None:
            raise ValueError('use_masks=True requires masks to be provided')
        masks = to_numpy(masks)
        pts = np.concatenate([p[m] for p, m in zip(pts3d, masks)])
        col = np.concatenate([p[m] for p, m in zip(imgs, masks)])
        # NOTE(review): confidences are indexed with flattened masks, unlike
        # points/colors -- presumably confs arrives as (N, H*W); confirm with caller.
        confs = np.concatenate([p[m] for p, m in zip(confs, masks.reshape(masks.shape[0], -1))])
    else:
        pts = np.array(pts3d)
        col = np.array(imgs)
        confs = np.array(confs)

    pts = pts.reshape(-1, 3)
    col = col.reshape(-1, 3) * 255.
    confs = confs.reshape(-1, 1)
    co_mask_dsp_pts_num = pts.shape[0]

    if pts.shape[0] > max_pts_num:
        print(f'Downsampling points from {pts.shape[0]} to {max_pts_num}')
        # Rescale confidences to [1, 2] so every point keeps a non-zero weight.
        confs_min = np.min(confs)
        confs_range = np.max(confs) - confs_min
        if confs_range > 0:
            confs = (confs - confs_min) / confs_range
        else:
            # All confidences equal: avoid 0/0 (NaN weights) and sample uniformly.
            confs = np.zeros_like(confs, dtype=float)
        confs = confs + 1
        weights = confs.reshape(-1) / np.sum(confs)
        indices = np.random.choice(pts.shape[0], max_pts_num, replace=False, p=weights)
        pts = pts[indices]
        col = col[indices]
        confs = confs[indices]
    conf_dsp_pts_num = pts.shape[0]

    np.save(sparse_path / 'confidence_dsp.npy', confs)

    storePly(points3D_ply_file, pts, col)
    if save_all_pts:
        np.save(sparse_path / 'points3D_all.npy', pts3d)
        np.save(sparse_path / 'pointsColor_all.npy', imgs)

    # Write pts_num.txt (skipped when no directory was given; the default
    # save_txt_path=None used to crash here).
    if save_txt_path is not None:
        pts_num_file = Path(save_txt_path) / 'pts_num.txt'
        vanilla_pts_num = pts3d.reshape(-1, 3).shape[0]
        with open(pts_num_file, 'a') as f:
            f.write(f"Depth threshold: {depth_threshold}\n")
            f.write(f"Vanilla points num: {vanilla_pts_num}\n")
            f.write(f"Co_Mask DSP points num: {co_mask_dsp_pts_num}\n")
            f.write(f"Co_Mask DSP ratio: {co_mask_dsp_pts_num / vanilla_pts_num}\n")
            if co_mask_dsp_pts_num > max_pts_num:
                f.write(f"Conf_Mask DSP points num: {conf_dsp_pts_num}\n")
                f.write(f"Conf_Mask DSP ratio: {conf_dsp_pts_num / vanilla_pts_num}\n")
            f.write("\n")

    return pts.shape[0]


# Save images and masks
def save_images_and_masks(sparse_0_path, n_views, imgs, overlapping_masks, image_files, image_suffix):
    """Save each image and its overlapping mask next to the sparse model.

    Images go to ``imgs_{n_views}`` and masks to
    ``overlapping_masks_{n_views}`` under ``sparse_0_path``; both use the
    stem of the matching entry in ``image_files`` plus ``image_suffix``.

    Args:
        sparse_0_path: Output directory (Path).
        n_views: Value embedded in the two sub-directory names.
        imgs: Per-view images; scaled by 255 before writing.
        overlapping_masks: Per-view 2D masks; written as 3-channel 0/255 images.
        image_files: Source image paths (only the stem is used).
        image_suffix: File suffix used for both outputs.
    """
    images_path = sparse_0_path / f'imgs_{n_views}'
    overlapping_masks_path = sparse_0_path / f'overlapping_masks_{n_views}'
    images_path.mkdir(exist_ok=True, parents=True)
    overlapping_masks_path.mkdir(exist_ok=True, parents=True)

    for image, name, overlapping_mask in zip(imgs, image_files, overlapping_masks):
        imgname = Path(name).stem
        image_save_path = images_path / f"{imgname}{image_suffix}"
        overlapping_mask_save_path = overlapping_masks_path / f"{imgname}{image_suffix}"

        # Save overlapping masks
        overlapping_mask = np.repeat(np.expand_dims(overlapping_mask, -1), 3, axis=2) * 255
        PIL.Image.fromarray(overlapping_mask.astype(np.uint8)).save(overlapping_mask_save_path)

        # Save images. The conversion swaps the first and third channel;
        # presumably imgs are RGB and cv2.imwrite expects BGR -- verify against caller.
        bgr_image = cv2.cvtColor(image * 255, cv2.COLOR_BGR2RGB)
        cv2.imwrite(str(image_save_path), bgr_image)


def cal_co_vis_mask(points, depths, curr_depth_map, depth_threshold, camera_intrinsics, extrinsics_w2c):
    """Mark pixels of the current view that are already covered by other points.

    ``points`` are projected into the current view; a pixel is flagged when a
    projected point lands on it and that point's entry in ``depths`` differs
    from ``curr_depth_map`` at the pixel by less than ``depth_threshold``.

    Args:
        points: ``(N, 3)`` array of 3D points.
        depths: ``(N,)`` depth value associated with each point.
        curr_depth_map: ``(h, w)`` depth map of the current view.
        depth_threshold: Maximum absolute depth difference for a match.
        camera_intrinsics: Intrinsic matrix of the current view.
        extrinsics_w2c: Extrinsic matrix of the current view.

    Returns:
        Boolean ``(h, w)`` mask, True where a consistent point was found.
    """
    h, w = curr_depth_map.shape
    overlapping_mask = np.zeros((h, w), dtype=bool)

    # Project 3D points into the current image.
    # NOTE(review): the camera-space depth returned by project_points is
    # ignored, so points behind the camera are not rejected -- confirm intended.
    points_2d, _ = project_points(points, camera_intrinsics, extrinsics_w2c)

    # Check if points are within image bounds
    valid_points = (points_2d[:, 0] >= 0) & (points_2d[:, 0] < w) & \
                   (points_2d[:, 1] >= 0) & (points_2d[:, 1] < h)

    # Check depth consistency using vectorized operations
    valid_points_2d = points_2d[valid_points].astype(int)
    valid_depths = depths[valid_points]

    # Extract x and y coordinates
    x_coords, y_coords = valid_points_2d[:, 0], valid_points_2d[:, 1]

    # Compute depth differences
    depth_differences = np.abs(valid_depths - curr_depth_map[y_coords, x_coords])

    # Create a mask for points where the depth difference is below the threshold
    consistent_depth_mask = depth_differences < depth_threshold

    # Update the overlapping mask using the consistent depth mask
    overlapping_mask[y_coords[consistent_depth_mask], x_coords[consistent_depth_mask]] = True
    return overlapping_mask


def normalize_depth(depth_map):
    """Normalize the depth map to a range between 0 and 1.

    A constant depth map (max == min) yields all zeros instead of the NaNs
    produced by a 0/0 division.
    """
    depth_min = np.min(depth_map)
    depth_range = np.max(depth_map) - depth_min
    shifted = depth_map - depth_min
    if depth_range == 0:
        # `shifted` is all zeros here; dividing by 1.0 keeps the same dtype
        # promotion as the regular path.
        return shifted / 1.0
    return shifted / depth_range


def compute_co_vis_masks(sorted_conf_indices, depthmaps, pointmaps, camera_intrinsics, extrinsics_w2c,
                         image_sizes, depth_threshold=0.1):
    """Compute, for every view, the pixels already covered by earlier views.

    Views are visited in the order given by ``sorted_conf_indices``. For each
    view after the first, the points of all previously visited views are
    projected into it and depth-consistent pixels are flagged.

    Args:
        sorted_conf_indices: View indices in processing order (must support
            slicing and be usable as an array index).
        depthmaps: Per-view depth maps, indexable by view.
        pointmaps: Per-view point maps; reshaped to ``(num_images, h, w, 3)``.
        camera_intrinsics: Per-view intrinsic matrices.
        extrinsics_w2c: Per-view extrinsic matrices.
        image_sizes: 4-tuple ``(num_images, h, w, _)``.
        depth_threshold: Threshold forwarded to ``cal_co_vis_mask``.

    Returns:
        Boolean array ``(num_images, h, w)``; True marks co-visible
        (redundant) pixels. The first visited view stays all False.
    """
    num_images, h, w, _ = image_sizes
    pointmaps = pointmaps.reshape(num_images, h, w, 3)
    overlapping_masks = np.zeros((num_images, h, w), dtype=bool)

    for i, curr_map_idx in tqdm(enumerate(sorted_conf_indices), total=len(sorted_conf_indices)):
        # The first view has nothing before it: its mask stays all False.
        if i == 0:
            continue

        # Points and depths of every view processed before the current one.
        idx_before = sorted_conf_indices[:i]
        points_before = pointmaps[idx_before].reshape(-1, 3)
        depths_before = depthmaps[idx_before].reshape(-1)

        # Current view's depth map.
        curr_depth_map = depthmaps[curr_map_idx].reshape(h, w)

        # Normalize depths for comparison (each side is normalized independently).
        depths_before = normalize_depth(depths_before)
        curr_depth_map = normalize_depth(curr_depth_map)

        before_mask = cal_co_vis_mask(points_before, depths_before, curr_depth_map, depth_threshold,
                                      camera_intrinsics[curr_map_idx], extrinsics_w2c[curr_map_idx])

        # white/True means co-visible redundant area: we need to remove
        overlapping_masks[curr_map_idx] = before_mask

    return overlapping_masks


def project_points(points_3d, intrinsics, extrinsics):
    """Project 3D points into an image.

    Args:
        points_3d: ``(N, 3)`` array of points.
        intrinsics: Matrix applied to the first three camera-space coordinates.
        extrinsics: Matrix applied to the homogeneous ``(N, 4)`` points.

    Returns:
        ``(points_2d, depths)``: ``(N, 2)`` pixel coordinates and the ``(N,)``
        third camera-space coordinate of every point.
    """
    # Convert to homogeneous coordinates
    points_3d_homogeneous = np.hstack((points_3d, np.ones((points_3d.shape[0], 1))))
    # Apply extrinsic matrix
    points_camera = np.dot(extrinsics, points_3d_homogeneous.T).T
    # Apply intrinsic matrix
    points_2d_homogeneous = np.dot(intrinsics, points_camera[:, :3].T).T
    # Convert to 2D coordinates
    points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:]
    depths = points_camera[:, 2]
    return points_2d, depths


def read_colmap_gt_pose(gt_pose_path, llffhold=8):
    """Load poses from ``<gt_pose_path>/sparse/0/images.bin``, sorted by image name.

    Each 4x4 pose holds the transposed rotation built from the stored
    quaternion and the stored translation vector as-is.

    Args:
        gt_pose_path: Dataset root (str or Path).
        llffhold: Unused; kept for interface compatibility.

    Returns:
        Array stacking one 4x4 pose per image.
    """
    # os.path.join accepts both str and Path (plain `+` failed for Path).
    images_bin = os.path.join(gt_pose_path, 'sparse', '0', 'images.bin')
    colmap_cam_extrinsics = read_extrinsics_binary(images_bin)

    all_pose = []
    for extr in sorted(colmap_cam_extrinsics.values(), key=lambda image: image.name):
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)
        pose = np.eye(4, 4)
        pose[:3, :3] = R
        pose[:3, 3] = T
        all_pose.append(pose)
    colmap_pose = np.array(all_pose)
    return colmap_pose


def readImages(renders_dir, gt_dir):
    """Load matching rendered / ground-truth image pairs onto the GPU.

    Every file name found in ``renders_dir`` is opened from both directories.
    Files are visited in sorted name order so the result is deterministic.

    Args:
        renders_dir: Directory with rendered images (str or Path).
        gt_dir: Directory with ground-truth images of the same names.

    Returns:
        ``(renders, gts, image_names)``: two lists of CUDA tensors limited to
        the first three channels, with a leading batch dimension of 1, plus
        the file names.
    """
    renders_dir, gt_dir = Path(renders_dir), Path(gt_dir)
    renders = []
    gts = []
    image_names = []
    for fname in sorted(os.listdir(renders_dir)):
        # to_tensor copies the pixel data, so both files can be closed afterwards.
        with PIL.Image.open(renders_dir / fname) as render, PIL.Image.open(gt_dir / fname) as gt:
            renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
            gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
        image_names.append(fname)
    return renders, gts, image_names


def align_pose(pose1, pose2):
    """Center, normalize and compute the orthogonal Procrustes rotation of two point sets.

    Args:
        pose1: 2D array-like of shape ``(n, d)``.
        pose2: 2D array-like of the same shape.

    Returns:
        ``(mtx1, mtx2, R)``: the centered, unit-norm copy of ``pose1``; the
        centered, unit-norm copy of ``pose2`` multiplied by the Procrustes
        scale; and the orthogonal matrix from
        ``scipy.linalg.orthogonal_procrustes(mtx1, mtx2)``. ``R`` is returned
        but not applied to either matrix.

    Raises:
        ValueError: If the inputs are not 2D, differ in shape, are empty, or
            one of them collapses to a single point.
    """
    mtx1 = np.array(pose1, dtype=np.double, copy=True)
    mtx2 = np.array(pose2, dtype=np.double, copy=True)

    if mtx1.ndim != 2 or mtx2.ndim != 2:
        raise ValueError("Input matrices must be two-dimensional")
    if mtx1.shape != mtx2.shape:
        raise ValueError("Input matrices must be of same shape")
    if mtx1.size == 0:
        raise ValueError("Input matrices must be >0 rows and >0 cols")

    # translate all the data to the origin
    mtx1 -= np.mean(mtx1, 0)
    mtx2 -= np.mean(mtx2, 0)

    norm1 = np.linalg.norm(mtx1)
    norm2 = np.linalg.norm(mtx2)
    if norm1 == 0 or norm2 == 0:
        raise ValueError("Input matrices must contain >1 unique points")

    # change scaling of data (in rows) such that trace(mtx*mtx') = 1
    mtx1 /= norm1
    mtx2 /= norm2

    # transform mtx2 to minimize disparity
    R, s = scipy.linalg.orthogonal_procrustes(mtx1, mtx2)
    mtx2 = mtx2 * s
    return mtx1, mtx2, R


def storePly(path, xyz, rgb):
    """Write a point cloud with zero normals to a PLY file.

    Args:
        path: Destination file.
        xyz: ``(N, 3)`` point coordinates.
        rgb: ``(N, 3)`` colors; stored as unsigned 8-bit values.
    """
    # Define the dtype for the structured array
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)