# LangScene-X/utils/sfm_utils.py
import collections
import os
import re
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import PIL.Image
import roma
import scipy.linalg  # explicit submodule import; `import scipy` alone does not guarantee scipy.linalg
import torchvision.transforms as tvf
import torchvision.transforms.functional as tf
from PIL.ImageOps import exif_transpose
from plyfile import PlyData, PlyElement
from tqdm import tqdm
from dust3r.utils.device import to_numpy
from dust3r.utils.image import _resize_pil_image
from field_construction.scene.colmap_loader import (qvec2rotmat,
read_extrinsics_binary,
rotmat2qvec,
write_cameras_binary,
write_cameras_text,
write_images_binary,
write_images_text)
try:
from pillow_heif import register_heif_opener
register_heif_opener()
heif_support_enabled = True
except ImportError:
heif_support_enabled = False

# DUSt3R input normalization: scales RGB from [0, 1] to [-1, 1]
ImgNorm = tvf.Compose([
    tvf.ToTensor(),
    tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def save_time(time_dir, process_name, sub_time):
    """Append a human-readable timing entry for `process_name` to train_time.txt."""
    if isinstance(time_dir, str):
        time_dir = Path(time_dir)
    time_dir.mkdir(parents=True, exist_ok=True)
    minutes, seconds = divmod(sub_time, 60)
    formatted_time = f"{int(minutes)} min {int(seconds)} sec"
    with open(time_dir / 'train_time.txt', 'a') as f:
        f.write(f'{process_name}: {formatted_time}\n')

def split_train_test(image_files, llffhold=8, n_views=None, verbose=True):
    # Note: `llffhold` is accepted for API compatibility but unused here; the
    # test set is always 12 frames sampled evenly over the sequence.
    test_idx = np.linspace(1, len(image_files) - 2, num=12, dtype=int)
    train_idx = [i for i in range(len(image_files)) if i not in test_idx]
    sparse_idx = np.linspace(0, len(train_idx) - 1, num=n_views, dtype=int)
    train_idx = [train_idx[i] for i in sparse_idx]
    if verbose:
        print(">> Splitting Train-Test Set: ")
        # print("   - sparse_idx: ", sparse_idx)
        print("   - train_set_indices: ", train_idx)
        print("   - test_set_indices: ", test_idx)
    train_img_files = [image_files[i] for i in train_idx]
    test_img_files = [image_files[i] for i in test_idx]
    return train_img_files, test_img_files
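# Example (illustrative; the frame list below is made up):
#   files = [f"frame_{i:03d}.jpg" for i in range(30)]
#   train, test = split_train_test(files, n_views=4)
#   # -> 4 evenly spaced training frames, 12 evenly spaced test frames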

def get_sorted_image_files(image_dir: str) -> Tuple[List[str], str]:
    """
    Get sorted image files from the given directory.

    Args:
        image_dir (str): Path to the directory containing images.

    Returns:
        Tuple[List[str], str]: A tuple containing:
            - List of image file paths, sorted by the first number in each filename
            - The file suffix of the first sorted image (all images are assumed
              to share the same extension)
    """
    # suffixes are compared lowercased below, so only lowercase entries are needed
    allowed_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff'}
    image_path = Path(image_dir)

    def extract_number(filename):
        match = re.search(r'\d+', filename.stem)
        return int(match.group()) if match else float('inf')

    image_files = [
        str(f) for f in image_path.iterdir()
        if f.is_file() and f.suffix.lower() in allowed_extensions
    ]
    sorted_files = sorted(image_files, key=lambda x: extract_number(Path(x)))
    suffixes = [Path(file).suffix for file in sorted_files]
    return sorted_files, suffixes[0]

def rigid_points_registration(pts1, pts2, conf=None):
    """Estimate the similarity transform (scale s, rotation R, translation T)
    mapping pts1 onto pts2, optionally weighted by per-point confidence."""
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf, compute_scaling=True)
    return s, R, T  # reorder roma's (R, T, s) output to (s, R, T)
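# Example (illustrative; synthetic tensors, and the last line assumes roma's
# convention that the fit satisfies pts2 ≈ s * pts1 @ R.T + T):
#   import torch
#   pts1 = torch.rand(100, 3)
#   pts2 = pts1 * 2.0 + 0.5
#   s, R, T = rigid_points_registration(pts1, pts2)
#   aligned = s * pts1 @ R.T + T  # should closely match pts2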

def init_filestructure(save_path, n_views=None):
    if n_views is not None and n_views != 0:
        sparse_0_path = save_path / f'sparse_{n_views}/0'
        sparse_1_path = save_path / f'sparse_{n_views}/1'
        print(f'>> Doing {n_views} views reconstruction!')
    else:
        sparse_0_path = save_path / 'sparse_0/0'
        sparse_1_path = save_path / 'sparse_0/1'
        print('>> Doing full views reconstruction!')
    save_path.mkdir(exist_ok=True, parents=True)
    sparse_0_path.mkdir(exist_ok=True, parents=True)
    sparse_1_path.mkdir(exist_ok=True, parents=True)
    return save_path, sparse_0_path, sparse_1_path
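# Example (illustrative; the output directory is hypothetical):
#   save_path, sparse_0, sparse_1 = init_filestructure(Path('./output/scene'), n_views=4)
#   # creates ./output/scene/sparse_4/0 and ./output/scene/sparse_4/1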

def load_images(folder_or_list, size=512, square_ok=False, verbose=True):
    """Open and convert all images in a list or folder to proper input format for DUSt3R."""
if isinstance(folder_or_list, str):
if verbose:
print(f'>> Loading images from {folder_or_list}')
root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
elif isinstance(folder_or_list, list):
if verbose:
print(f'>> Loading a list of {len(folder_or_list)} images')
root, folder_content = '', folder_or_list
else:
raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
    # `path.lower()` is compared below, so only lowercase extensions can match
    supported_images_extensions = ['.jpg', '.jpeg', '.png']
    if heif_support_enabled:
        supported_images_extensions += ['.heic', '.heif']
    supported_images_extensions = tuple(supported_images_extensions)
imgs = []
for path in folder_content:
if not path.lower().endswith(supported_images_extensions):
continue
img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
W1, H1 = img.size
if size == 224:
# resize short side to 224 (then crop)
img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
else:
# resize long side to 512
img = _resize_pil_image(img, size)
W, H = img.size
cx, cy = W//2, H//2
if size == 224:
half = min(cx, cy)
img = img.crop((cx-half, cy-half, cx+half, cy+half))
else:
halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
if not (square_ok) and W == H:
halfh = 3*halfw/4
img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
W2, H2 = img.size
if verbose:
print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
[img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
    assert imgs, 'no images found at ' + root
if verbose:
print(f' (Found {len(imgs)} images)')
    return imgs, (W1, H1)  # original size of the last image (all inputs are assumed to share one resolution)
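# Example (illustrative; the directory path is hypothetical):
#   imgs, (org_w, org_h) = load_images('/path/to/frames', size=512)
#   imgs[0]['img'].shape  # (1, 3, H, W) normalized tensor ready for DUSt3R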

# COLMAP data structures (mirroring COLMAP's scripts/python/read_write_model.py)
CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple("Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple("Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
CAMERA_MODELS = {
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
CameraModel(model_id=7, model_name="FOV", num_params=5),
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = {camera_model.model_id: camera_model
                    for camera_model in CAMERA_MODELS}
CAMERA_MODEL_NAMES = {camera_model.model_name: camera_model
                      for camera_model in CAMERA_MODELS}

def save_extrinsic(sparse_path, extrinsics_w2c, img_files, image_suffix):
    """Write world-to-camera extrinsics to COLMAP images.bin / images.txt."""
images_bin_file = sparse_path / 'images.bin'
images_txt_file = sparse_path / 'images.txt'
images = {}
for i, (w2c, img_file) in enumerate(zip(extrinsics_w2c, img_files), start=1): # Start enumeration from 1
name = Path(img_file).stem + image_suffix
rotation_matrix = w2c[:3, :3]
qvec = rotmat2qvec(rotation_matrix)
tvec = w2c[:3, 3]
images[i] = BaseImage(
id=i,
qvec=qvec,
tvec=tvec,
camera_id=i,
name=name,
xys=[], # Empty list as we don't have 2D point information
point3D_ids=[] # Empty list as we don't have 3D point IDs
)
write_images_binary(images, images_bin_file)
write_images_text(images, images_txt_file)

def save_intrinsics(sparse_path, focals, org_imgs_shape, imgs_shape, save_focals=False):
    """Rescale focal lengths from the processed resolution back to the original
    resolution and write COLMAP cameras.bin / cameras.txt (PINHOLE model)."""
org_width, org_height = org_imgs_shape
scale_factor_x = org_width / imgs_shape[2]
scale_factor_y = org_height / imgs_shape[1]
cameras_bin_file = sparse_path / 'cameras.bin'
cameras_txt_file = sparse_path / 'cameras.txt'
cameras = {}
for i, focal in enumerate(focals, start=1): # Start enumeration from 1
cameras[i] = Camera(
id=i,
model="PINHOLE",
width=org_width,
height=org_height,
params=[focal*scale_factor_x, focal*scale_factor_y, org_width/2, org_height/2]
)
print(f' - scaling focal: ({focal}, {focal}) --> ({focal*scale_factor_x}, {focal*scale_factor_y})' )
write_cameras_binary(cameras, cameras_bin_file)
write_cameras_text(cameras, cameras_txt_file)
if save_focals:
np.save(sparse_path / 'non_scaled_focals.npy', focals)
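# Example (illustrative; assumes imgs_shape follows the (N, H, W, 3) layout used above):
#   save_intrinsics(sparse_path, focals=[500.0], org_imgs_shape=(1920, 1080),
#                   imgs_shape=(1, 288, 512, 3))
#   # fx scales by 1920/512 and fy by 1080/288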

def save_points3D(sparse_path, imgs, pts3d, confs, masks=None, use_masks=True, save_all_pts=False, save_txt_path=None, depth_threshold=0.1, max_pts_num=150 * 10**10):
    """Fuse per-view point maps into a single colored point cloud, optionally
    masking out co-visible points and confidence-downsampling, then save it as
    points3D.ply together with bookkeeping arrays; returns the final point count."""
points3D_bin_file = sparse_path / 'points3D.bin'
points3D_txt_file = sparse_path / 'points3D.txt'
points3D_ply_file = sparse_path / 'points3D.ply'
# Convert inputs to numpy arrays
imgs = to_numpy(imgs)
pts3d = to_numpy(pts3d)
confs = to_numpy(confs)
if confs is not None:
np.save(sparse_path / 'confidence.npy', confs)
# Process points and colors
if use_masks:
masks = to_numpy(masks)
pts = np.concatenate([p[m] for p, m in zip(pts3d, masks)])
# pts = np.concatenate([p[m] for p, m in zip(pts3d, masks.reshape(masks.shape[0], -1))])
col = np.concatenate([p[m] for p, m in zip(imgs, masks)])
confs = np.concatenate([p[m] for p, m in zip(confs, masks.reshape(masks.shape[0], -1))])
else:
pts = np.array(pts3d)
col = np.array(imgs)
confs = np.array(confs)
pts = pts.reshape(-1, 3)
col = col.reshape(-1, 3) * 255.
confs = confs.reshape(-1, 1)
co_mask_dsp_pts_num = pts.shape[0]
if pts.shape[0] > max_pts_num:
print(f'Downsampling points from {pts.shape[0]} to {max_pts_num}')
        # Normalize confidences to (0, 1), then shift to (1, 2) so that even the
        # lowest-confidence points keep a nonzero sampling probability
        confs_min = np.min(confs)
        confs_max = np.max(confs)
        confs = (confs - confs_min) / (confs_max - confs_min)
        confs = confs + 1
weights = confs.reshape(-1) / np.sum(confs)
indices = np.random.choice(pts.shape[0], max_pts_num, replace=False, p=weights)
pts = pts[indices]
col = col[indices]
confs = confs[indices]
conf_dsp_pts_num = pts.shape[0]
if confs is not None:
np.save(sparse_path / 'confidence_dsp.npy', confs)
storePly(points3D_ply_file, pts, col)
if save_all_pts:
np.save(sparse_path / 'points3D_all.npy', pts3d)
np.save(sparse_path / 'pointsColor_all.npy', imgs)
    # Write pts_num.txt (skipped when no save_txt_path is given)
    if save_txt_path is not None:
        if isinstance(save_txt_path, str):
            save_txt_path = Path(save_txt_path)
        pts_num_file = save_txt_path / 'pts_num.txt'  # New file for pts_num
        with open(pts_num_file, 'a') as f:
            f.write(f"Depth threshold: {depth_threshold}\n")
            f.write(f"Vanilla points num: {pts3d.reshape(-1, 3).shape[0]}\n")
            f.write(f"Co_Mask DSP points num: {co_mask_dsp_pts_num}\n")
            f.write(f"Co_Mask DSP ratio: {co_mask_dsp_pts_num / pts3d.reshape(-1, 3).shape[0]}\n")
            if co_mask_dsp_pts_num > max_pts_num:
                f.write(f"Conf_Mask DSP points num: {conf_dsp_pts_num}\n")
                f.write(f"Conf_Mask DSP ratio: {conf_dsp_pts_num / pts3d.reshape(-1, 3).shape[0]}\n")
            f.write("\n")
    return pts.shape[0]
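# Example (illustrative; inverting the co-visibility mask is an assumption based
# on the note in compute_co_vis_masks that True marks redundant points to remove):
#   num_pts = save_points3D(sparse_0_path, imgs, pts3d, confs,
#                           masks=~co_vis_masks, use_masks=True,
#                           save_txt_path=str(save_path))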

# Save images and masks
def save_images_and_masks(sparse_0_path, n_views, imgs, overlapping_masks, image_files, image_suffix):
    images_path = sparse_0_path / f'imgs_{n_views}'
    overlapping_masks_path = sparse_0_path / f'overlapping_masks_{n_views}'
    images_path.mkdir(exist_ok=True, parents=True)
    overlapping_masks_path.mkdir(exist_ok=True, parents=True)
    for i, (image, name, overlapping_mask) in enumerate(zip(imgs, image_files, overlapping_masks)):
        imgname = Path(name).stem
        image_save_path = images_path / f"{imgname}{image_suffix}"
        overlapping_mask_save_path = overlapping_masks_path / f"{imgname}{image_suffix}"
        # Save overlapping masks (white/True = co-visible area to remove)
        overlapping_mask = np.repeat(np.expand_dims(overlapping_mask, -1), 3, axis=2) * 255
        PIL.Image.fromarray(overlapping_mask.astype(np.uint8)).save(overlapping_mask_save_path)
        # Save images (float RGB in [0, 1] -> uint8 BGR for cv2.imwrite)
        bgr_image = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(image_save_path), bgr_image)

def cal_co_vis_mask(points, depths, curr_depth_map, depth_threshold, camera_intrinsics, extrinsics_w2c):
    """Mark pixels of the current view whose depth agrees (within depth_threshold)
    with already-reconstructed 3D points, i.e. the co-visible area."""
    h, w = curr_depth_map.shape
    overlapping_mask = np.zeros((h, w), dtype=bool)
    # Project the 3D points into the current image
    points_2d, _ = project_points(points, camera_intrinsics, extrinsics_w2c)
# Check if points are within image bounds
valid_points = (points_2d[:, 0] >= 0) & (points_2d[:, 0] < w) & \
(points_2d[:, 1] >= 0) & (points_2d[:, 1] < h)
# Check depth consistency using vectorized operations
valid_points_2d = points_2d[valid_points].astype(int)
valid_depths = depths[valid_points]
# Extract x and y coordinates
x_coords, y_coords = valid_points_2d[:, 0], valid_points_2d[:, 1]
# Compute depth differences
depth_differences = np.abs(valid_depths - curr_depth_map[y_coords, x_coords])
# Create a mask for points where the depth difference is below the threshold
consistent_depth_mask = depth_differences < depth_threshold
# Update the overlapping masks using the consistent depth mask
overlapping_mask[y_coords[consistent_depth_mask], x_coords[consistent_depth_mask]] = True
return overlapping_mask

def normalize_depth(depth_map):
    """Normalize the depth map to a range between 0 and 1."""
    return (depth_map - np.min(depth_map)) / (np.max(depth_map) - np.min(depth_map))

def compute_co_vis_masks(sorted_conf_indices, depthmaps, pointmaps, camera_intrinsics, extrinsics_w2c, image_sizes, depth_threshold=0.1):
    """Visit views from most to least confident and, for each view, mark pixels
    whose content is already covered by points from more confident views."""
num_images, h, w, _ = image_sizes
pointmaps = pointmaps.reshape(num_images, h, w, 3)
overlapping_masks = np.zeros((num_images, h, w), dtype=bool)
for i, curr_map_idx in tqdm(enumerate(sorted_conf_indices), total=len(sorted_conf_indices)):
        # the most confident frame keeps all of its points (empty co-visibility mask)
        if i == 0:
            continue
# get before and after curr_frame's indices
idx_before = sorted_conf_indices[:i]
# idx_after = sorted_conf_indices[i+1:]
# get partial pointmaps and depthmaps
points_before = pointmaps[idx_before].reshape(-1, 3)
depths_before = depthmaps[idx_before].reshape(-1)
# points_after = pointmaps[idx_after].reshape(-1, 3)
# depths_after = depthmaps[idx_after].reshape(-1)
# get current frame's depth map
curr_depth_map = depthmaps[curr_map_idx].reshape(h, w)
# normalize depth for comparison
depths_before = normalize_depth(depths_before)
# depths_after = normalize_depth(depths_after)
curr_depth_map = normalize_depth(curr_depth_map)
# before_mask = overlapping_masks[idx_before]
# after_mask = overlapping_masks[idx_after]
# curr_mask = before_mask & after_mask
before_mask = cal_co_vis_mask(points_before, depths_before, curr_depth_map, depth_threshold, camera_intrinsics[curr_map_idx], extrinsics_w2c[curr_map_idx])
# after_mask = cal_co_vis_mask(points_after, depths_after, camera_intrinsics[i], extrinsics_w2c[i], curr_depth_map, depth_threshold)
        # white/True marks the co-visible redundant area that we want to remove
        overlapping_masks[curr_map_idx] = before_mask  # & after_mask
return overlapping_masks
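# Example (illustrative; N views at resolution H x W, most confident view first):
#   order = np.argsort(-per_view_conf)
#   co_vis_masks = compute_co_vis_masks(order, depthmaps, pointmaps,
#                                       intrinsics, extrinsics_w2c,
#                                       (N, H, W, 3), depth_threshold=0.1)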

def project_points(points_3d, intrinsics, extrinsics):
    """Project Nx3 world points through a 3x3 intrinsic matrix and a (3x4 or 4x4)
    world-to-camera extrinsic matrix; returns Nx2 pixel coordinates and depths."""
    # Convert to homogeneous coordinates
points_3d_homogeneous = np.hstack((points_3d, np.ones((points_3d.shape[0], 1))))
# Apply extrinsic matrix
points_camera = np.dot(extrinsics, points_3d_homogeneous.T).T
# Apply intrinsic matrix
points_2d_homogeneous = np.dot(intrinsics, points_camera[:, :3].T).T
# Convert to 2D coordinates
points_2d = points_2d_homogeneous[:, :2] / points_2d_homogeneous[:, 2:]
depths = points_camera[:, 2]
return points_2d, depths
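# Example (illustrative; a point on the optical axis of a simple pinhole camera):
#   K = np.array([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
#   uv, z = project_points(np.array([[0., 0., 2.]]), K, np.eye(4))
#   # uv == [[320., 240.]], z == [2.]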

def read_colmap_gt_pose(gt_pose_path, llffhold=8):
    # Note: `llffhold` is accepted for API compatibility but unused here
    colmap_cam_extrinsics = read_extrinsics_binary(gt_pose_path + '/sparse/0/images.bin')
    colmap_cam_extrinsics = {k: v for k, v in sorted(colmap_cam_extrinsics.items(), key=lambda item: item[1].name)}
    all_pose = []
    for idx, key in enumerate(colmap_cam_extrinsics):
        extr = colmap_cam_extrinsics[key]
        # print(idx, extr.name)
        # R is stored transposed (camera-to-world rotation) together with the raw
        # world-to-camera translation, matching the 3DGS COLMAP loader convention
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)
        pose = np.eye(4, 4)
        pose[:3, :3] = R
        pose[:3, 3] = T
all_pose.append(pose)
colmap_pose = np.array(all_pose)
return colmap_pose

def readImages(renders_dir, gt_dir):
    """Load matching render/GT image pairs (paired by filename) as CUDA tensors."""
    renders = []
    gts = []
    image_names = []
    for fname in sorted(os.listdir(renders_dir)):
render = PIL.Image.open(renders_dir / fname)
gt = PIL.Image.open(gt_dir / fname)
renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
image_names.append(fname)
return renders, gts, image_names

def align_pose(pose1, pose2):
    """Procrustes-align two point sets (e.g. camera centers): center and normalize
    both, then rotate/scale pose2 onto pose1; returns both sets and the rotation."""
mtx1 = np.array(pose1, dtype=np.double, copy=True)
mtx2 = np.array(pose2, dtype=np.double, copy=True)
if mtx1.ndim != 2 or mtx2.ndim != 2:
raise ValueError("Input matrices must be two-dimensional")
if mtx1.shape != mtx2.shape:
raise ValueError("Input matrices must be of same shape")
if mtx1.size == 0:
raise ValueError("Input matrices must be >0 rows and >0 cols")
# translate all the data to the origin
mtx1 -= np.mean(mtx1, 0)
mtx2 -= np.mean(mtx2, 0)
norm1 = np.linalg.norm(mtx1)
norm2 = np.linalg.norm(mtx2)
if norm1 == 0 or norm2 == 0:
raise ValueError("Input matrices must contain >1 unique points")
# change scaling of data (in rows) such that trace(mtx*mtx') = 1
mtx1 /= norm1
mtx2 /= norm2
# transform mtx2 to minimize disparity
R, s = scipy.linalg.orthogonal_procrustes(mtx1, mtx2)
mtx2 = mtx2 * s
return mtx1, mtx2, R
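# Example (illustrative; aligning estimated camera centers to ground truth):
#   gt_centers = gt_poses[:, :3, 3]
#   est_centers = est_poses[:, :3, 3]
#   gt_norm, est_aligned, R = align_pose(gt_centers, est_centers)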

def storePly(path, xyz, rgb):
    """Write an Nx3 point cloud with uint8 RGB colors (and zero normals) to a binary PLY file."""
    # Define the dtype for the structured array
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
normals = np.zeros_like(xyz)
elements = np.empty(xyz.shape[0], dtype=dtype)
attributes = np.concatenate((xyz, normals, rgb), axis=1)
elements[:] = list(map(tuple, attributes))
# Create the PlyData object and write to file
vertex_element = PlyElement.describe(elements, 'vertex')
ply_data = PlyData([vertex_element])
ply_data.write(path)
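

if __name__ == '__main__':
    # Minimal smoke test (illustrative; all values below are made up): project a
    # synthetic point, then write a tiny PLY to a temporary directory.
    import tempfile

    K = np.array([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
    uv, z = project_points(np.array([[0., 0., 2.]]), K, np.eye(4))
    print('projected:', uv, 'depth:', z)  # expected: [[320. 240.]] [2.]

    xyz = np.random.rand(10, 3)
    rgb = (np.random.rand(10, 3) * 255).astype(np.uint8)
    with tempfile.TemporaryDirectory() as tmp:
        storePly(os.path.join(tmp, 'demo.ply'), xyz, rgb)
        print('wrote demo point cloud to', os.path.join(tmp, 'demo.ply'))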