import os
from copy import deepcopy
import numpy as np
import open3d as o3d
import torch


def storePly(path, xyz, rgb):
    """Write an (N, 3) point array with (N, 3) float colors in [0, 1] to a .ply file."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    pcd.colors = o3d.utility.Vector3dVector(rgb)
    o3d.io.write_point_cloud(path, pcd)
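

# A minimal usage sketch for storePly (illustrative only): the output path is
# hypothetical, and the random points/colors merely show the expected shapes
# and value ranges.
def _example_store_ply(path="./example_cloud.ply"):
    """Write a random colored point cloud to disk with storePly (sketch)."""
    xyz = np.random.rand(1000, 3)  # (N, 3) positions
    rgb = np.random.rand(1000, 3)  # (N, 3) float colors in [0, 1]
    storePly(path, xyz, rgb)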


def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
    """
    Prepare input views for inference from a list of image paths.

    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Flags indicating valid images.
        size (int): Target image size.
        raymaps (list, optional): List of ray maps.
        raymap_mask (list, optional): Flags indicating valid ray maps.
        revisit (int): How many times to revisit each view.
        update (bool): Whether to update the state on revisits.

    Returns:
        tuple: (list of view dictionaries, original image shape).
    """
    # Import image loader (delayed import needed after adding ckpt path).
    from cut3r.dust3r.utils.image import load_images

    images, orig_shape = load_images(img_paths, size=size)

    views = []
    if raymaps is None and raymap_mask is None:
        # Only images are provided.
        for i in range(len(images)):
            view = {
                "img": images[i]["img"],
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4, dtype=np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            views.append(view)
    else:
        # Combine images and raymaps.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

        j = 0
        k = 0
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4, dtype=np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
        assert j == len(images) and k == len(raymaps)

    if revisit > 1:
        # Repeat the sequence `revisit` times, optionally freezing state updates
        # after the first pass.
        new_views = []
        for r in range(revisit):
            for i, view in enumerate(views):
                new_view = deepcopy(view)
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views, orig_shape

    return views, orig_shape
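

# Usage sketch for prepare_input (illustrative only): the frame paths are
# hypothetical and `size=512` is just an example target resolution.
def _example_prepare_input():
    """Build image-only inference views for a short sequence (sketch)."""
    img_paths = ["./frames/000001.png", "./frames/000002.png"]  # hypothetical paths
    img_mask = [True] * len(img_paths)
    views, orig_shape = prepare_input(img_paths, img_mask, size=512)
    # Each view holds an image, a NaN placeholder ray map, an identity camera
    # pose, and the boolean flags ("img_mask", "ray_mask", "update", "reset")
    # consumed by the model.
    return views, orig_shape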


def prepare_output(outputs, orig_shape, outdir, revisit=1, use_pose=True):
    """
    Process inference outputs into point clouds and camera parameters for visualization.

    Args:
        outputs (dict): Inference outputs with "pred" and "views" lists.
        orig_shape (tuple): Original image size as (height, width).
        outdir (str): Directory where per-frame camera files are written.
        revisit (int): Number of revisits per view.
        use_pose (bool): Whether to transform points using the camera pose.

    Returns:
        tuple: (points, colors, confidences, camera parameters dictionary).
    """
    from cut3r.dust3r.post_process import estimate_focal_knowing_depth
    from cut3r.dust3r.utils.camera import pose_encoding_to_camera
    from cut3r.dust3r.utils.geometry import geotrf

    # Only keep the outputs corresponding to one full pass.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]

    pts3ds_self_ls = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [output["pts3d_in_other_view"].cpu() for output in outputs["pred"]]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self_ls, 0)

    # Recover camera poses.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]
    R_c2w = torch.cat([pr_pose[:, :3, :3] for pr_pose in pr_poses], 0)
    t_c2w = torch.cat([pr_pose[:, :3, 3] for pr_pose in pr_poses], 0)

    if use_pose:
        # Express the per-view points in the world frame via the predicted poses.
        transformed_pts3ds_other = []
        for pose, pself in zip(pr_poses, pts3ds_self):
            transformed_pts3ds_other.append(geotrf(pose, pself.unsqueeze(0)))
        pts3ds_other = transformed_pts3ds_other
        conf_other = conf_self

    # Estimate focal length based on depth, then rescale it to the original resolution.
    B, H, W, _ = pts3ds_self.shape
    orig_H, orig_W = orig_shape
    pp = (
        torch.tensor([orig_W // 2, orig_H // 2], device=pts3ds_self.device)
        .float()
        .repeat(B, 1)
    )
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
    # focal = focal.mean().repeat(len(focal))
    focal_x = focal * orig_W / W
    focal_y = focal * orig_H / H

    # Convert images from [-1, 1] back to [0, 1] colors.
    colors = [
        0.5 * (output["img"].permute(0, 2, 3, 1) + 1.0) for output in outputs["views"]
    ]

    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }

    # Save per-frame camera-to-world poses and intrinsics.
    cam2world_tosave = torch.cat(pr_poses)  # B, 4, 4
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )  # B, 3, 3
    intrinsics_tosave[:, 0, 0] = focal_x.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal_y.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp[:, 0]
    intrinsics_tosave[:, 1, 2] = pp[:, 1]

    os.makedirs(os.path.join(outdir, "camera"), exist_ok=True)
    for f_id in range(len(cam2world_tosave)):
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.savez(
            os.path.join(outdir, "camera", f"{f_id+1:04d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, conf_other, cam_dict
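

# End-to-end sketch (assumptions): `outputs` is the inference dict consumed by
# prepare_output ({"pred": [...], "views": [...]}); how it is produced is outside
# this file. The output directory and file name are arbitrary examples.
def _example_save_fused_cloud(outputs, orig_shape, outdir="./output"):
    """Fuse per-view predictions into one colored .ply via prepare_output + storePly (sketch)."""
    pts3ds, colors, _confs, cam_dict = prepare_output(outputs, orig_shape, outdir)
    # Flatten every (1, H, W, 3) map to (N, 3) and stack all views together.
    xyz = torch.cat([p.reshape(-1, 3) for p in pts3ds], 0).detach().cpu().numpy()
    rgb = torch.cat([c.reshape(-1, 3) for c in colors], 0).detach().cpu().numpy()
    storePly(os.path.join(outdir, "fused.ply"), xyz, rgb)
    return cam_dict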