seawolf2357's picture
Upload folder using huggingface_hub
684943d verified
import os
from copy import deepcopy
import numpy as np
import open3d as o3d
import torch
def storePly(path, xyz, rgb):
    """Save a colored point cloud to ``path`` with Open3D.

    Args:
        path (str): Output file path handed to ``o3d.io.write_point_cloud``.
        xyz: Point coordinates, in any form ``o3d.utility.Vector3dVector``
            accepts (presumably an (N, 3) float array — confirm with callers).
        rgb: Per-point colors in the same layout as ``xyz``.
    """
    as_vectors = o3d.utility.Vector3dVector
    cloud = o3d.geometry.PointCloud()
    cloud.points = as_vectors(xyz)
    cloud.colors = as_vectors(rgb)
    # NOTE: whatever write_point_cloud returns is discarded, so a failed write
    # is not reported to the caller.
    o3d.io.write_point_cloud(path, cloud)
def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
    """
    Prepare input views for inference from a list of image paths.

    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Per-view flags marking views backed by an
            image. Only consulted when ray maps are supplied.
        size (int): Target image size forwarded to ``load_images``.
        raymaps (list, optional): Ray-map tensors, one per view flagged in
            ``raymap_mask``.
        raymap_mask (list of bool, optional): Per-view flags marking views
            backed by a ray map.
        revisit (int): How many times the whole sequence of views is repeated.
        update (bool): When False, every repeated pass after the first gets
            ``update=False``.

    Returns:
        tuple: ``(views, orig_shape)``. ``views`` is a list of view
        dictionaries (``revisit`` consecutive copies of the base sequence when
        ``revisit > 1``, with ``idx``/``instance`` renumbered). ``orig_shape``
        is the second value returned by ``load_images`` (presumably the
        original image (H, W) — confirm).
    """
    # Import image loader (delayed import needed after adding ckpt path).
    from cut3r.dust3r.utils.image import load_images

    images, orig_shape = load_images(img_paths, size=size)

    def identity_pose():
        # Fresh (1, 4, 4) identity camera pose for each view.
        return torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(0)

    views = []
    if raymaps is None and raymap_mask is None:
        # Only images are provided: each view gets an all-NaN ray map placeholder.
        for i, image in enumerate(images):
            img = image["img"]
            views.append(
                {
                    "img": img,
                    "ray_map": torch.full(
                        (img.shape[0], 6, img.shape[-2], img.shape[-1]), torch.nan
                    ),
                    "true_shape": torch.from_numpy(image["true_shape"]),
                    "idx": i,
                    "instance": str(i),
                    "camera_pose": identity_pose(),
                    "img_mask": torch.tensor(True).unsqueeze(0),
                    "ray_mask": torch.tensor(False).unsqueeze(0),
                    "update": torch.tensor(True).unsqueeze(0),
                    "reset": torch.tensor(False).unsqueeze(0),
                }
            )
    else:
        # Combine images and raymaps: view i consumes the next image when
        # img_mask[i] is set and the next ray map when raymap_mask[i] is set;
        # absent inputs are replaced by NaN tensors shaped like the first one.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)
        j = 0  # index of the next unused image
        k = 0  # index of the next unused ray map
        for i in range(num_views):
            has_img = img_mask[i]
            has_ray = raymap_mask[i]
            # NOTE(review): for ray-map-only views, true_shape is taken from
            # raymaps[k].shape[1:-1] reversed; the ray map layout is not visible
            # here — confirm this yields the expected (H, W) ordering.
            view = {
                "img": (
                    images[j]["img"]
                    if has_img
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k] if has_ray else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if has_img
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": identity_pose(),
                "img_mask": torch.tensor(has_img).unsqueeze(0),
                "ray_mask": torch.tensor(has_ray).unsqueeze(0),
                "update": torch.tensor(has_img).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            if has_img:
                j += 1
            if has_ray:
                k += 1
            views.append(view)
        assert j == len(images) and k == len(raymaps)

    if revisit > 1:
        # Repeat the whole sequence `revisit` times with globally unique indices.
        base_views = views
        n_base = len(base_views)
        views = []
        for r in range(revisit):
            for i, view in enumerate(base_views):
                new_view = deepcopy(view)
                new_view["idx"] = r * n_base + i
                new_view["instance"] = str(r * n_base + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                views.append(new_view)

    # Always return the same (views, orig_shape) pair, regardless of `revisit`.
    return views, orig_shape
def prepare_output(outputs, orig_shape, outdir, revisit=1, use_pose=True):
    """
    Process inference outputs to generate point clouds and camera parameters for visualization.

    Side effects:
        * ``outputs["pred"]`` and ``outputs["views"]`` are trimmed in place to
          the last full pass.
        * One ``<outdir>/camera/NNNN.npz`` file (1-based) is written per frame,
          holding the camera-to-world ``pose`` and 3x3 ``intrinsics`` expressed
          at the original image resolution.

    Args:
        outputs (dict): Inference outputs with ``"pred"`` and ``"views"`` lists.
        orig_shape (tuple): ``(H, W)`` of the original images, used to rescale
            the saved intrinsics.
        outdir (str): Directory under which the ``camera`` folder is created.
        revisit (int): Number of revisits per view.
        use_pose (bool): Whether to transform points using camera pose.

    Returns:
        tuple: (points, colors, confidence, camera parameters dictionary).
        The ``cam_dict`` ``focal`` and ``pp`` entries are both at the
        prediction resolution, matching the returned points and colors.
    """
    from cut3r.dust3r.post_process import estimate_focal_knowing_depth
    from cut3r.dust3r.utils.camera import pose_encoding_to_camera
    from cut3r.dust3r.utils.geometry import geotrf

    # Only keep the outputs corresponding to one full pass.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]
    preds = outputs["pred"]

    pts3ds_self = torch.cat([p["pts3d_in_self_view"].cpu() for p in preds], 0)
    pts3ds_other = [p["pts3d_in_other_view"].cpu() for p in preds]
    conf_self = [p["conf_self"].cpu() for p in preds]
    conf_other = [p["conf"].cpu() for p in preds]

    # Recover camera-to-world poses (each presumably (1, 4, 4) — see slicing below).
    pr_poses = [pose_encoding_to_camera(p["camera_pose"].clone()).cpu() for p in preds]
    R_c2w = torch.cat([pose[:, :3, :3] for pose in pr_poses], 0)
    t_c2w = torch.cat([pose[:, :3, 3] for pose in pr_poses], 0)

    if use_pose:
        # Map each frame's self-view points through its predicted pose and
        # pair them with the self-view confidences.
        pts3ds_other = [
            geotrf(pose, pself.unsqueeze(0)) for pose, pself in zip(pr_poses, pts3ds_self)
        ]
        conf_other = conf_self

    # Estimate the focal length at the prediction resolution. The estimator
    # centres a pixel grid of pts3ds_self's (H, W) on `pp`, so `pp` must be in
    # that same frame; the original-resolution principal point is only used
    # for the intrinsics written to disk.
    B, H, W, _ = pts3ds_self.shape
    orig_H, orig_W = orig_shape
    pp = torch.tensor([W // 2, H // 2], device=pts3ds_self.device).float().repeat(B, 1)
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
    focal_x = focal * orig_W / W
    focal_y = focal * orig_H / H
    pp_orig = torch.tensor([orig_W // 2, orig_H // 2]).float().repeat(B, 1)

    # Predicted images are in [-1, 1]; map to [0, 1] channel-last colors.
    colors = [
        0.5 * (view["img"].permute(0, 2, 3, 1) + 1.0) for view in outputs["views"]
    ]

    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }

    cam2world_tosave = torch.cat(pr_poses)  # B, 4, 4
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )  # B, 3, 3
    intrinsics_tosave[:, 0, 0] = focal_x.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal_y.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp_orig[:, 0]
    intrinsics_tosave[:, 1, 2] = pp_orig[:, 1]
    _save_camera_files(outdir, cam2world_tosave, intrinsics_tosave)

    return pts3ds_other, colors, conf_other, cam_dict


def _save_camera_files(outdir, cam2world, intrinsics):
    """Write ``<outdir>/camera/NNNN.npz`` (1-based) with per-frame ``pose`` and ``intrinsics``."""
    camera_dir = os.path.join(outdir, "camera")
    os.makedirs(camera_dir, exist_ok=True)
    for f_id, (c2w, intrins) in enumerate(zip(cam2world, intrinsics)):
        np.savez(
            os.path.join(camera_dir, f"{f_id+1:04d}.npz"),
            pose=c2w.cpu().numpy(),
            intrinsics=intrins.cpu().numpy(),
        )