import os
from copy import deepcopy

import numpy as np
import open3d as o3d
import torch


def storePly(path, xyz, rgb):
    """Write an (N, 3) point array with matching per-point colors to a PLY file."""
    # Open3D expects float colors in [0, 1].
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    pcd.colors = o3d.utility.Vector3dVector(rgb)
    o3d.io.write_point_cloud(path, pcd)
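
# Example usage (illustrative sketch; the output path and the random data below
# are placeholders, not part of this module):
#
#     xyz = np.random.rand(1000, 3)   # 1000 synthetic points
#     rgb = np.random.rand(1000, 3)   # per-point colors in [0, 1]
#     storePly("/tmp/example.ply", xyz, rgb)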


def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
""" | |
Prepare input views for inference from a list of image paths. | |
Args: | |
img_paths (list): List of image file paths. | |
img_mask (list of bool): Flags indicating valid images. | |
size (int): Target image size. | |
raymaps (list, optional): List of ray maps. | |
raymap_mask (list, optional): Flags indicating valid ray maps. | |
revisit (int): How many times to revisit each view. | |
update (bool): Whether to update the state on revisits. | |
Returns: | |
list: A list of view dictionaries. | |
""" | |
    # Import image loader (delayed import needed after adding ckpt path).
    from cut3r.dust3r.utils.image import load_images

    images, orig_shape = load_images(img_paths, size=size)
    views = []
    if raymaps is None and raymap_mask is None:
        # Only images are provided.
        for i in range(len(images)):
            view = {
                "img": images[i]["img"],
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4, dtype=np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            views.append(view)
    else:
        # Combine images and raymaps.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

        j = 0
        k = 0
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4, dtype=np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
        assert j == len(images) and k == len(raymaps)

    if revisit > 1:
        # Repeat the whole sequence `revisit` times, re-indexing each copy.
        new_views = []
        for r in range(revisit):
            for i, view in enumerate(views):
                new_view = deepcopy(view)
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views, orig_shape
    return views, orig_shape
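
# Example: build views from a folder of frames (illustrative; the glob pattern
# and the 512-pixel target size are placeholders):
#
#     import glob
#     img_paths = sorted(glob.glob("examples/scene/*.png"))
#     views, orig_shape = prepare_input(img_paths, [True] * len(img_paths), size=512)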


def prepare_output(outputs, orig_shape, outdir, revisit=1, use_pose=True):
""" | |
Process inference outputs to generate point clouds and camera parameters for visualization. | |
Args: | |
outputs (dict): Inference outputs. | |
revisit (int): Number of revisits per view. | |
use_pose (bool): Whether to transform points using camera pose. | |
Returns: | |
tuple: (points, colors, confidence, camera parameters dictionary) | |
""" | |
    from cut3r.dust3r.post_process import estimate_focal_knowing_depth
    from cut3r.dust3r.utils.camera import pose_encoding_to_camera
    from cut3r.dust3r.utils.geometry import geotrf

    # Only keep the outputs corresponding to one full pass.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]

    pts3ds_self_ls = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [output["pts3d_in_other_view"].cpu() for output in outputs["pred"]]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self_ls, 0)
    # Recover camera poses.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]
    R_c2w = torch.cat([pr_pose[:, :3, :3] for pr_pose in pr_poses], 0)
    t_c2w = torch.cat([pr_pose[:, :3, 3] for pr_pose in pr_poses], 0)

    if use_pose:
        # Express each self-view point map in world coordinates via its camera pose.
        transformed_pts3ds_other = []
        for pose, pself in zip(pr_poses, pts3ds_self):
            transformed_pts3ds_other.append(geotrf(pose, pself.unsqueeze(0)))
        pts3ds_other = transformed_pts3ds_other
        conf_other = conf_self
    # Estimate focal length based on depth.
    B, H, W, _ = pts3ds_self.shape
    orig_H, orig_W = orig_shape
    pp = (
        torch.tensor([orig_W // 2, orig_H // 2], device=pts3ds_self.device)
        .float()
        .repeat(B, 1)
    )
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
    # focal = focal.mean().repeat(len(focal))
    # Rescale focal from the processed resolution (W, H) to the original resolution.
    focal_x = focal * orig_W / W
    focal_y = focal * orig_H / H

    # Convert images from [-1, 1] back to [0, 1] for display.
    colors = [
        0.5 * (output["img"].permute(0, 2, 3, 1) + 1.0) for output in outputs["views"]
    ]
    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }

    cam2world_tosave = torch.cat(pr_poses)  # (B, 4, 4)
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )  # (B, 3, 3)
    intrinsics_tosave[:, 0, 0] = focal_x.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal_y.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp[:, 0]
    intrinsics_tosave[:, 1, 2] = pp[:, 1]
    os.makedirs(os.path.join(outdir, "camera"), exist_ok=True)
    for f_id in range(len(cam2world_tosave)):
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.savez(
            os.path.join(outdir, "camera", f"{f_id + 1:04d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, conf_other, cam_dict
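

# End-to-end sketch (illustrative; `model` and `inference` are stand-ins for the
# CUT3R checkpoint and its inference entry point, neither of which is defined in
# this module):
#
#     views, orig_shape = prepare_input(img_paths, [True] * len(img_paths), size=512)
#     outputs = inference(views, model)            # hypothetical call
#     pts, colors, conf, cams = prepare_output(outputs, orig_shape, outdir="out")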