import os
from copy import deepcopy

import numpy as np
import open3d as o3d
import torch


def storePly(path, xyz, rgb):
    # Open3D expects float colors in [0, 1].
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    pcd.colors = o3d.utility.Vector3dVector(rgb)
    o3d.io.write_point_cloud(path, pcd)


def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
    """
    Prepare input views for inference from a list of image paths.

    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Flags indicating valid images.
        size (int): Target image size.
        raymaps (list, optional): List of ray maps.
        raymap_mask (list, optional): Flags indicating valid ray maps.
        revisit (int): How many times to revisit each view.
        update (bool): Whether to update the state on revisits.

    Returns:
        tuple: (views, orig_shape) — the list of view dictionaries and the
        original image shape.
    """
    # Import image loader (delayed import needed after adding ckpt path).
    from cut3r.dust3r.utils.image import load_images

    images, orig_shape = load_images(img_paths, size=size)

    views = []
    if raymaps is None and raymap_mask is None:
        # Only images are provided.
        for i in range(len(images)):
            view = {
                "img": images[i]["img"],
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4, dtype=np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            views.append(view)
    else:
        # Combine images and raymaps.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

        j = 0
        k = 0
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(
                    np.eye(4, dtype=np.float32)
                ).unsqueeze(0),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
        assert j == len(images) and k == len(raymaps)

    if revisit > 1:
        # Repeat the sequence `revisit` times, optionally freezing state
        # updates after the first pass.
        new_views = []
        for r in range(revisit):
            for i, view in enumerate(views):
                new_view = deepcopy(view)
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views, orig_shape

    return views, orig_shape
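
# Illustrative sketch, not part of the pipeline: build image-only views from a
# folder of frames. The "*.png" pattern and the 512-pixel target size are
# assumptions for this example, not values fixed by the loader.
def _example_build_views(data_dir, size=512):
    import glob

    img_paths = sorted(glob.glob(os.path.join(data_dir, "*.png")))
    # With images only, every view carries a real image and a NaN ray map.
    views, orig_shape = prepare_input(
        img_paths, img_mask=[True] * len(img_paths), size=size
    )
    return views, orig_shape
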
def prepare_output(outputs, orig_shape, outdir, revisit=1, use_pose=True):
    """
    Process inference outputs to generate point clouds and camera parameters
    for visualization.

    Args:
        outputs (dict): Inference outputs.
        orig_shape (tuple): Original image shape (H, W).
        outdir (str): Directory where camera parameters are saved.
        revisit (int): Number of revisits per view.
        use_pose (bool): Whether to transform points using camera pose.

    Returns:
        tuple: (points, colors, confidence, camera parameters dictionary).
    """
    from cut3r.dust3r.post_process import estimate_focal_knowing_depth
    from cut3r.dust3r.utils.camera import pose_encoding_to_camera
    from cut3r.dust3r.utils.geometry import geotrf

    # Only keep the outputs corresponding to one full pass.
    valid_length = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-valid_length:]
    outputs["views"] = outputs["views"][-valid_length:]

    pts3ds_self_ls = [output["pts3d_in_self_view"].cpu() for output in outputs["pred"]]
    pts3ds_other = [output["pts3d_in_other_view"].cpu() for output in outputs["pred"]]
    conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
    conf_other = [output["conf"].cpu() for output in outputs["pred"]]
    pts3ds_self = torch.cat(pts3ds_self_ls, 0)

    # Recover camera poses.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
        for pred in outputs["pred"]
    ]
    R_c2w = torch.cat([pr_pose[:, :3, :3] for pr_pose in pr_poses], 0)
    t_c2w = torch.cat([pr_pose[:, :3, 3] for pr_pose in pr_poses], 0)

    if use_pose:
        # Express the per-view (self-frame) points in the world frame.
        transformed_pts3ds_other = []
        for pose, pself in zip(pr_poses, pts3ds_self):
            transformed_pts3ds_other.append(geotrf(pose, pself.unsqueeze(0)))
        pts3ds_other = transformed_pts3ds_other
        conf_other = conf_self

    # Estimate the focal length from depth at the model resolution (the
    # principal point must live on the same pixel grid as pts3ds_self), then
    # rescale focal and principal point to the original image resolution.
    B, H, W, _ = pts3ds_self.shape
    orig_H, orig_W = orig_shape
    pp = torch.tensor([W // 2, H // 2], device=pts3ds_self.device).float().repeat(B, 1)
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
    # focal = focal.mean().repeat(len(focal))
    focal_x = focal * orig_W / W
    focal_y = focal * orig_H / H
    pp_orig = (
        torch.tensor([orig_W // 2, orig_H // 2], device=pts3ds_self.device)
        .float()
        .repeat(B, 1)
    )

    # Map images from [-1, 1] back to [0, 1] for use as point colors.
    colors = [
        0.5 * (output["img"].permute(0, 2, 3, 1) + 1.0) for output in outputs["views"]
    ]

    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
    }

    # Save per-frame camera poses and intrinsics (at the original resolution).
    cam2world_tosave = torch.cat(pr_poses)  # B, 4, 4
    intrinsics_tosave = (
        torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)
    )  # B, 3, 3
    intrinsics_tosave[:, 0, 0] = focal_x.detach().cpu()
    intrinsics_tosave[:, 1, 1] = focal_y.detach().cpu()
    intrinsics_tosave[:, 0, 2] = pp_orig[:, 0]
    intrinsics_tosave[:, 1, 2] = pp_orig[:, 1]

    os.makedirs(os.path.join(outdir, "camera"), exist_ok=True)
    for f_id in range(len(cam2world_tosave)):
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()
        np.savez(
            os.path.join(outdir, "camera", f"{f_id+1:04d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, conf_other, cam_dict
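

# Illustrative sketch, not part of the pipeline: fuse the per-view outputs of
# prepare_output into a single PLY file via storePly. The confidence
# threshold of 1.5 is an arbitrary choice for this example.
def _example_export_point_cloud(outputs, orig_shape, outdir):
    pts3ds, colors, confs, _cam = prepare_output(outputs, orig_shape, outdir)
    xyz = torch.cat([p.reshape(-1, 3) for p in pts3ds], 0)
    rgb = torch.cat([c.reshape(-1, 3) for c in colors], 0)  # already in [0, 1]
    conf = torch.cat([c.reshape(-1) for c in confs], 0)
    keep = conf > 1.5  # drop low-confidence points
    storePly(
        os.path.join(outdir, "points.ply"),
        xyz[keep].detach().numpy(),
        rgb[keep].detach().numpy(),
    )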