Spaces:
Runtime error
Runtime error
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # Base class for the global alignment procedure | |
| # -------------------------------------------------------- | |
| from copy import deepcopy | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import roma | |
| from copy import deepcopy | |
| import tqdm | |
| from dust3r.utils.geometry import inv, geotrf | |
| from dust3r.utils.device import to_numpy | |
| from dust3r.utils.image import rgb | |
| from dust3r.viz import SceneViz, segment_sky, auto_cam_size | |
| from dust3r.optim_factory import adjust_learning_rate_by_lr | |
| from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p, | |
| cosine_schedule, linear_schedule, get_conf_trf) | |
| import dust3r.cloud_opt.init_im_poses as init_fun | |
class BasePCOptimizer(nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    This base class stores the pairwise predictions, their confidences and the
    per-edge pose/scale parameters. Everything that depends on per-image
    parameters (depth maps, focals, image poses, ...) raises
    NotImplementedError here and has to be supplied by a subclass.
    """

    def __init__(self, *args, **kwargs):
        """ Build the optimizer either from a single source object (copy branch)
        or from `(view1, view2, pred1, pred2, **options)`, see `_init_from_views`.
        """
        if len(args) == 1 and len(kwargs) == 0:
            # NOTE(review): this copy branch indexes the copied object with
            # `other[k]` and never calls `super().__init__()`, so it can only
            # work with a subscriptable source -- confirm the intended source
            # type. Also, 'conf_thr' is never assigned by `_init_from_views`
            # and 'pw_adaptors' is listed twice (harmless in the dict below).
            other = deepcopy(args[0])
            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                        min_conf_thr conf_thr conf_i conf_j im_conf
                        base_scale norm_pw_scale POSE_DIM pw_poses
                        pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
            self.__dict__.update({k: other[k] for k in attrs})
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_views(self, view1, view2, pred1, pred2,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None,
                         same_focals=False,
                         verbose=True):
        """ Initialize all state from pairwise views and predictions.

        Args:
            view1, view2: mappings holding 'idx' (a list, or anything with
                `.tolist()`; converted to a list in place) and optionally 'img'.
            pred1: mapping with 'pts3d' and 'conf', indexed per edge.
            pred2: mapping with 'pts3d_in_other_view' and 'conf', indexed per edge.
            dist: key into ALL_DISTS selecting the distance used in `forward`.
            conf: mode handed to `get_conf_trf` to build the confidence transform.
            min_conf_thr: threshold used by `get_masks`.
            base_scale: target scale used by `get_pw_norm_scale_factor`.
            allow_pw_adaptors: whether `pw_adaptors` requires gradients.
            pw_break: divisor applied to the adaptors in `get_adaptors`.
            rand_pose: callable taking a shape tuple, used to initialize `pw_poses`.
            iterationsCount: accepted but not used by this method.
            same_focals: stored; `clean_pointcloud` then uses the first
                intrinsics matrix for every view.
            verbose: stored; read by `global_alignment_loop`.
        """
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        # symmetrized == every edge (i, j) also appears as (j, i)
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]
        self.verbose = verbose
        self.same_focals = same_focals

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            for v, (idx1, idx2) in enumerate(zip(view1['idx'], view2['idx'])):
                imgs[idx1] = view1['img'][v]
                imgs[idx2] = view2['img'][v]
            self.imgs = rgb(imgs)

        # TODO for vis pose
        self.vis_poses = None
        self.vis_pts3d = None

    # The four accessors below are read as plain attributes throughout this
    # class (e.g. `self.n_edges` inside a shape tuple), hence the properties.
    @property
    def n_edges(self):
        """ Number of pairwise observations. """
        return len(self.edges)

    @property
    def str_edges(self):
        """ String key of every edge, in the same order as `self.edges`. """
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        """ Image sizes as (w, h), i.e. `self.imshapes` with each pair swapped. """
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        """ Device of the first registered parameter. """
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        """ Return a filtered state dict.

        Keys starting with '_', 'pred_i.', 'pred_j.', 'conf_i.' or 'conf_j.'
        are returned when `trainable` is False; all other keys are returned
        when `trainable` is True.
        """
        all_params = super().state_dict()
        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        """ Load `data` on top of the current `state_dict(trainable=False)` entries. """
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        """ Check that image indices form the range 0..N-1 and return N. """
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
        return len(indices)

    def _compute_img_conf(self, pred1_conf, pred2_conf):
        """ Per-image confidence = element-wise max over all edges touching the image. """
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        """ Return per-edge multiplicative adaptors of shape (n_edges, 3). """
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        """ Convert raw pose rows (4 quaternion values, then 3 translation values)
        into homogeneous matrices. """
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        """ Overwrite the raw parameters of `poses[idx]` in place and return that row.

        Args:
            poses: raw pose tensor; its last dimension must be 8 or 13 when
                `scale` is given.
            idx: row to update.
            R: rotation matrix, a full (4, 4) matrix (then `T` must be None),
                or None to leave the rotation untouched.
            T: translation, or None to leave it untouched.
            scale: if given, divides `T` and its log is stored in the last slot.
            force: update even if the row does not require gradients.
        """
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        # R may be None (update translation/scale only): guard before reading .shape
        if R is not None and R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        """ Factor applied to all pairwise scales (1 when normalization is off). """
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        """ Per-edge scale: exp of the last pose slot times the normalization factor. """
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        """ Per-edge homogeneous poses with their first three rows multiplied by the edge scale. """
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
        return scaled_RT

    def get_masks(self):
        """ Boolean mask per image: confidence strictly above `min_conf_thr`. """
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        """ Return `depth_to_pts3d()`; unless `raw`, each entry is cut to h*w
        points and viewed as (h, w, 3). """
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        """ Transformed per-image confidences; `mode` overrides the stored transform. """
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    @torch.no_grad()  # confidence maps are edited in place; nothing here needs autograd
    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
        """ Method:
        1) express all 3d points in each camera coordinate frame
        2) if they're in front of a depthmap --> then lower their confidence

        Args:
            tol: relative depth tolerance, must satisfy 0 <= tol < 1.
            max_bad_conf: upper bound applied to the confidence of bad points.

        Returns:
            A deep copy of `self` with updated `im_conf`; `self` is not modified.
        """
        assert 0 <= tol < 1
        cams = inv(self.get_im_poses())
        # NOTE(review): get_intrinsics is not defined in this class -- presumably
        # provided by subclasses.
        Ks = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        res = deepcopy(self)

        for i, pts3d in enumerate(self.depth_to_pts3d()):
            Hi, Wi = self.imshapes[i]
            for j in range(self.n_imgs):
                if i == j:
                    continue
                K = Ks[0] if self.same_focals else Ks[j]

                # project 3dpts in other view
                Hj, Wj = self.imshapes[j]
                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
                proj_depth = proj[:, :, 2]
                u, v = geotrf(K, proj, norm=1, ncol=2).round().long().unbind(-1)

                # check which points are actually in the visible cone
                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
                msk_j = v[msk_i], u[msk_i]

                # find bad points = those in front but less confident
                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])

                bad_msk_i = msk_i.clone()
                bad_msk_i[msk_i] = bad_points
                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)

        return res

    def forward(self, ret_details=False):
        """ Compute the alignment loss averaged over all edges.

        Returns:
            The loss, or `(loss, details)` when `ret_details` is set, where
            `details[i, j]` holds the loss of edge (i, j) and -1 elsewhere.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        """ Optionally initialize the scene, then run `global_alignment_loop`.

        Args:
            init: None (no initialization), 'msp'/'mst' or 'known_poses'.
            niter_PnP: forwarded to the initialization routine.
            **kw: forwarded to `global_alignment_loop`.

        Raises:
            ValueError: for any other value of `init`.
        """
        if init is None:
            pass
        elif init == 'msp' or init == 'mst':
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
                                           niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        return global_alignment_loop(self, **kw)

    @torch.no_grad()  # confidence maps are edited in place
    def mask_sky(self):
        """ Return a deep copy whose confidence is zeroed where `segment_sky` fires. """
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    @torch.no_grad()  # visualization only
    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        """ Display the point clouds and cameras in a SceneViz and return it.

        Args:
            show_pw_cams: also draw the pairwise cameras.
            show_pw_pts3d: also draw the pairwise-aligned `pred_i` points.
            cam_size: camera size; derived with `auto_cam_size` when None.
            **kw: forwarded to `viz.show`.
        """
        viz = SceneViz()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            # computed once instead of once per image
            pts3d = self.get_pts3d()
            masks = self.get_masks()
            for n in range(self.n_imgs):
                viz.add_pointcloud(pts3d[n], colors[n], masks[n])
        else:
            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)

        # the pairwise poses are needed by both optional overlays
        if show_pw_cams or show_pw_pts3d:
            pw_poses = self.get_pw_poses()
            if show_pw_cams:
                viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
            if show_pw_pts3d:
                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
                viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz
def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
    """ Run `niter` Adam steps on the trainable parameters of `net`.

    Args:
        net: object exposing `parameters()`, `named_parameters()` and `verbose`.
        lr: base learning rate handed to the schedule.
        niter: number of optimization steps.
        schedule: schedule name forwarded to `global_alignment_iter`.
        lr_min: minimum learning rate handed to the schedule.

    Returns:
        `net` itself when nothing is trainable; otherwise the loss of the last
        step as a float (`inf` if no step was run).
    """
    trainable = [param for param in net.parameters() if param.requires_grad]
    if len(trainable) == 0:
        # nothing to optimize: hand the scene back untouched
        return net

    verbose = net.verbose
    if verbose:
        print('Global alignement - optimizing for:')
        trainable_names = [name for name, value in net.named_parameters() if value.requires_grad]
        print(trainable_names)

    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))
    loss = float('inf')

    if not verbose:
        for step in range(niter):
            loss = global_alignment_iter(net, step, niter, lr, lr_min, optimizer, schedule)
        return loss

    # verbose mode: same steps, driven by the progress bar counter
    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            loss = global_alignment_iter(net, bar.n, niter, lr, lr_min, optimizer, schedule)
            # `lr` is never reassigned here, so the bar shows the base learning rate
            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
            bar.update()
    return loss
def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    """ Perform one optimization step and return its loss as a float.

    Args:
        net: callable returning the loss to minimize.
        cur_iter: index of the current step.
        niter: total number of steps; `cur_iter / niter` is the schedule progress.
        lr_base, lr_min: learning-rate bounds handed to the schedule.
        optimizer: optimizer whose learning rate is updated before the step.
        schedule: 'cosine' or 'linear'.

    Raises:
        ValueError: if `schedule` is neither 'cosine' nor 'linear'.
    """
    progress = cur_iter / niter

    # pick the schedule function, rejecting unknown names up front
    if schedule != 'cosine' and schedule != 'linear':
        raise ValueError(f'bad lr {schedule=}')
    schedule_fn = cosine_schedule if schedule == 'cosine' else linear_schedule

    step_lr = schedule_fn(progress, lr_base, lr_min)
    adjust_learning_rate_by_lr(optimizer, step_lr)

    optimizer.zero_grad()
    step_loss = net()
    step_loss.backward()
    optimizer.step()
    return float(step_loss)