Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # visloc script with support for coarse to fine | |
| # -------------------------------------------------------- | |
| import os | |
| import numpy as np | |
| import random | |
| import torch | |
| import torchvision.transforms as tvf | |
| import argparse | |
| from tqdm import tqdm | |
| from PIL import Image | |
| import math | |
| from mast3r.model import AsymmetricMASt3R | |
| from mast3r.fast_nn import fast_reciprocal_NNs | |
| from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice | |
| from mast3r.utils.collate import cat_collate, cat_collate_fn_map | |
| from mast3r.utils.misc import mkdir_for | |
| from mast3r.datasets.utils.cropping import crop_to_homography | |
| import mast3r.utils.path_to_dust3r # noqa | |
| from dust3r.inference import inference, loss_of_one_batch | |
| from dust3r.utils.geometry import geotrf, colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics | |
| from dust3r.datasets.utils.transforms import ImgNorm | |
| from dust3r_visloc.datasets import * | |
| from dust3r_visloc.localization import run_pnp | |
| from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results | |
| from dust3r_visloc.datasets.utils import get_HW_resolution, rescale_points3d | |
def get_args_parser():
    """Build the command-line parser of the visual localization evaluation script.

    Returns:
        argparse.ArgumentParser: parser for the dataset expression, the model
        weights (exactly one of ``--weights`` / ``--model_name``), matching
        options (pixel tolerance, coarse-to-fine cropping), PnP options
        (solver, reprojection threshold, point budget) and output options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval")

    # exactly one source of model weights is required
    parser_weights = parser.add_mutually_exclusive_group(required=True)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])

    # the script keeps matches whose confidence is >= this threshold, i.e. lower values are discarded
    parser.add_argument("--confidence_threshold", type=float, default=1.001,
                        help="confidence values lower than threshold are invalid")
    parser.add_argument('--pixel_tol', default=5, type=int)

    # coarse-to-fine: match on crops when the image is larger than the network resolution
    parser.add_argument("--coarse_to_fine", action='store_true', default=False,
                        help="do the matching from coarse to fine")
    parser.add_argument("--max_image_size", type=int, default=None,
                        help="max image size for the fine resolution")
    parser.add_argument("--c2f_crop_with_homography", action='store_true', default=False,
                        help="when using coarse to fine, crop with homographies to keep cx, cy centered")

    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'],
                        help="pnp lib to use")

    # the PnP threshold is given either in pixels or as a fraction of the image diagonal
    parser_reproj = parser.add_mutually_exclusive_group()
    parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error")
    parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None,
                               help="pnp reprojection error as a ratio of the diagonal of the image")

    parser.add_argument("--max_batch_size", type=int, default=48,
                        help="max batch size for inference on crops when using coarse to fine")
    parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept")
    parser.add_argument("--viz_matches", type=int, default=0, help="debug matches")
    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--output_label", type=str, default='', help="prefix for results files")
    return parser
def coarse_matching(query_view, map_view, model, device, pixel_tol, fast_nn_params):
    """Match a query view against a map view at their ``rgb_rescaled`` resolution.

    The two rescaled images are run through the model as a single pair. Reciprocal
    nearest neighbours are then searched between the two descriptor maps.

    Args:
        query_view, map_view: view dicts. Both provide ``rgb_rescaled`` and
            ``to_orig``. When ``pixel_tol > 0`` the map view must also provide
            ``valid_rescaled`` and ``pts3d_rescaled``.
        model, device: forwarded to ``inference``.
        pixel_tol: 0 means a dense search subsampled every 8 pixels, dropping
            matches within 3 px of any border. A positive value seeds the search
            from the valid map pixels with this pixel tolerance.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        ``(valid_pts3d, matches_im_query, matches_im_map, matches_confs)``.
        Match coordinates are mapped back through each view's ``to_orig``.
        With ``pixel_tol == 0`` the 3D points and confidences are empty lists.
        Four empty lists are returned if a descriptor map is empty.
    """
    pair = tuple(
        dict(img=image.unsqueeze(0), true_shape=np.int32([image.shape[1:]]), idx=i, instance=str(i))
        for i, image in enumerate([query_view['rgb_rescaled'], map_view['rgb_rescaled']])
    )
    result = inference([pair], model, device, batch_size=1, verbose=False)
    query_pred, map_pred = result['pred1'], result['pred2']
    query_conf = query_pred['desc_conf'].squeeze(0).cpu().numpy()
    map_conf = map_pred['desc_conf'].squeeze(0).cpu().numpy()
    query_desc = query_pred['desc'].squeeze(0).detach()
    map_desc = map_pred['desc'].squeeze(0).detach()

    if len(query_desc) == 0 or len(map_desc) == 0:
        return [], [], [], []

    if pixel_tol == 0:
        pts_map, pts_query = fast_reciprocal_NNs(map_desc, query_desc, subsample_or_initxy1=8, **fast_nn_params)
        map_h, map_w = map_view['rgb_rescaled'].shape[1:]
        query_h, query_w = query_view['rgb_rescaled'].shape[1:]

        def _away_from_border(pts, width, height, margin=3):
            # True for (x, y) points at least `margin` px inside the image
            return ((pts[:, 0] >= margin) & (pts[:, 0] < width - margin)
                    & (pts[:, 1] >= margin) & (pts[:, 1] < height - margin))

        keep = _away_from_border(pts_map, map_w, map_h) & _away_from_border(pts_query, query_w, query_h)
        pts_map = pts_map[keep]
        pts_query = pts_query[keep]
        pts3d, confs = [], []
    else:
        rows, cols = torch.where(map_view['valid_rescaled'])
        pts_map, pts_query = fast_reciprocal_NNs(map_desc, query_desc, (cols, rows),
                                                 pixel_tol=pixel_tol, **fast_nn_params)
        pts3d = map_view['pts3d_rescaled'].cpu().numpy()[pts_map[:, 1], pts_map[:, 0]]
        # a match is only as confident as its least confident end
        confs = np.minimum(map_conf[pts_map[:, 1], pts_map[:, 0]],
                           query_conf[pts_query[:, 1], pts_query[:, 0]])

    # shift cv2 pixel coordinates by +0.5 (colmap), map back with to_orig, then undo the shift
    pts_query = pts_query.astype(np.float64)
    pts_map = pts_map.astype(np.float64)
    pts_query[:, :2] += 0.5
    pts_map[:, :2] += 0.5
    pts_query = geotrf(query_view['to_orig'], pts_query, norm=True)
    pts_map = geotrf(map_view['to_orig'], pts_map, norm=True)
    pts_query[:, :2] -= 0.5
    pts_map[:, :2] -= 0.5
    return pts3d, pts_query, pts_map, confs
def crops_inference(pairs, model, device, batch_size=48, verbose=True):
    """Run the model on a batch of crop pairs, in chunks of at most ``batch_size``.

    Args:
        pairs: two-element sequence ``(view1, view2)`` of batched dicts. Every
            non-None value is sliced along its first dimension when chunking.
        model, device: forwarded to ``loss_of_one_batch``.
        batch_size: maximum number of pairs per forward pass.
        verbose: accepted for interface compatibility; not used.

    Returns:
        The ``loss_of_one_batch`` output. When the batch had to be split, the
        per-chunk outputs are merged with ``cat_collate``.
    """
    assert len(pairs) == 2, "Error, data should be a tuple of dicts containing the batch of image pairs"
    total = pairs[0]['img'].shape[0]
    if total < batch_size:
        # fits in one forward pass: no chunking needed
        return loss_of_one_batch(pairs, model, None, device=device, symmetrize_batch=False)

    chunk_outputs = []
    for start in range(0, total, batch_size):
        window = slice(start, min(start + batch_size, total))
        # one chunk of both views, skipping entries that are None
        chunk = [{key: value[window] for key, value in view.items() if value is not None}
                 for view in pairs]
        chunk_outputs.append(loss_of_one_batch(chunk, model, None, device=device, symmetrize_batch=False))
    return cat_collate(chunk_outputs, collate_fn_map=cat_collate_fn_map)
def fine_matching(query_views, map_views, model, device, max_batch_size, pixel_tol, fast_nn_params):
    """Match every (query crop, map crop) pair and gather the matches in un-cropped pixels.

    Args:
        query_views: batched query crops. ``crops_inference`` reads their images
            and this function reads ``to_orig``.
        map_views: batched map crops. Additionally provides ``valid`` (mask used
            to seed the search) and ``pts3d`` (per-pixel 3D points).
        model, device, max_batch_size: forwarded to ``crops_inference``.
        pixel_tol: pixel tolerance for ``fast_reciprocal_NNs``; must be > 0.
        fast_nn_params: extra keyword arguments for ``fast_reciprocal_NNs``.

    Returns:
        ``(valid_pts3d, matches_im_query, matches_im_map, matches_confs)``,
        each concatenated over all crop pairs. Four empty lists are returned
        when there is no crop pair.
    """
    assert pixel_tol > 0
    out = crops_inference([query_views, map_views], model, device, batch_size=max_batch_size, verbose=False)
    query_pred, map_pred = out['pred1'], out['pred2']
    per_pair = zip(query_pred['desc'].clone(), map_pred['desc'].clone(),
                   query_pred['desc_conf'].clone(), map_pred['desc_conf'].clone())

    pts3d_parts, query_parts, map_parts, conf_parts = [], [], [], []
    for k, (desc_q, desc_m, conf_q, conf_m) in enumerate(per_pair):
        rows, cols = torch.where(map_views['valid'][k])
        hits_m, hits_q = fast_reciprocal_NNs(desc_m, desc_q, (cols, rows),
                                             pixel_tol=pixel_tol, **fast_nn_params)
        pts3d_k = map_views['pts3d'][k].cpu().numpy()
        conf_m_np = conf_m.cpu().numpy()
        conf_q_np = conf_q.cpu().numpy()
        pts3d_parts.append(pts3d_k[hits_m[:, 1], hits_m[:, 0]])
        conf_parts.append(np.minimum(conf_m_np[hits_m[:, 1], hits_m[:, 0]],
                                     conf_q_np[hits_q[:, 1], hits_q[:, 0]]))
        # undo the crop transform to get pixel coordinates in the un-cropped image
        map_parts.append(geotrf(map_views['to_orig'][k].cpu().numpy(), hits_m.copy(), norm=True))
        query_parts.append(geotrf(query_views['to_orig'][k].cpu().numpy(), hits_q.copy(), norm=True))

    if not pts3d_parts:
        return [], [], [], []
    return (np.concatenate(pts3d_parts, axis=0),
            np.concatenate(query_parts, axis=0),
            np.concatenate(map_parts, axis=0),
            np.concatenate(conf_parts, axis=0))
def _warp_perspective(array, imsize, homo8, resample):
    """Warp a numpy array with PIL's PERSPECTIVE transform and return the result as numpy.

    Args:
        array: array accepted by ``PIL.Image.fromarray`` (here: uint8 RGB, uint8
            single-channel, or one float channel of a point map).
        imsize: output size passed to ``Image.transform`` (PIL expects (width, height)).
        homo8: first 8 row-major coefficients of a 3x3 homography normalized so H[2, 2] == 1.
        resample: PIL resampling filter.
    """
    warped = Image.fromarray(array).transform(imsize, Image.Transform.PERSPECTIVE, homo8, resample=resample)
    return np.array(warped)


def crop(img, mask, pts3d, crop, intrinsics=None):
    """Cut one crop out of an image and, optionally, out of its validity mask and 3D point map.

    Without ``intrinsics`` this is a plain axis-aligned crop (``crop_slice(crop)``).
    With ``intrinsics``, ``crop_to_homography`` supplies a rectifying homography H
    and every input is warped through it with PIL's PERSPECTIVE transform. That
    transform samples each output pixel at H(output pixel) in the input image.

    Args:
        img: image tensor with values in [-1, 1], indexed as H x W x C (the
            homography path converts it to uint8 via ``.numpy()``, so it must be on CPU).
        mask: optional per-pixel validity tensor, or None.
        pts3d: optional per-pixel 3D points (H x W x 3), or None.
        crop: crop box understood by ``crop_slice`` / ``crop_to_homography``.
            Its first two entries are the crop's top-left (x, y) offset.
        intrinsics: optional 3x3 camera matrix. When given, crop with a homography.

    Returns:
        ``(cropped_img, cropped_mask, cropped_pts3d, to_orig)``. The mask and
        point map are None when not provided. ``to_orig`` is the 3x3 transform
        from crop pixel coordinates to original-image pixel coordinates.
    """
    if intrinsics is None:
        # axis-aligned crop: crop -> original is a pure translation by the crop offset
        window = crop_slice(crop)
        out_cropped_img = img[window]
        out_cropped_mask = None if mask is None else mask[window].clone()
        out_cropped_pts3d = None if pts3d is None else pts3d[window].clone()
        to_orig = torch.eye(3, device=img.device)
        to_orig[:2, -1] = torch.tensor(crop[:2])
        return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig

    imsize, _K_new, _R, H = crop_to_homography(intrinsics, crop)
    H /= H[2, 2]
    homo8 = H.ravel().tolist()[:8]

    # float image in [-1, 1] -> uint8; clamp first so out-of-range values cannot wrap around
    img_u8 = (255 * (img + 1.) / 2).clamp(0, 255).to(torch.uint8).numpy()
    warped_img = _warp_perspective(img_u8, imsize, homo8, Image.Resampling.BICUBIC)
    out_cropped_img = 2. * torch.tensor(warped_img).to(img) / 255. - 1.

    out_cropped_mask = None
    if mask is not None:
        # nearest-neighbour resampling keeps the mask binary
        mask_u8 = (255 * mask).to(torch.uint8).numpy()
        warped_mask = _warp_perspective(mask_u8, imsize, homo8, Image.Resampling.NEAREST)
        out_cropped_mask = torch.from_numpy(warped_mask > 0).to(mask.dtype)

    out_cropped_pts3d = None
    if pts3d is not None:
        # warp X, Y and Z separately (single-channel float images), nearest-neighbour
        pts3d_np = pts3d.numpy()
        warped_xyz = [_warp_perspective(pts3d_np[:, :, c], imsize, homo8, Image.Resampling.NEAREST)
                      for c in range(3)]
        out_cropped_pts3d = torch.from_numpy(np.stack(warped_xyz, axis=-1))

    to_orig = torch.tensor(H, device=img.device)
    return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig
def resize_image_to_max(max_image_size, rgb, K):
    """Normalize a PIL image and downscale it if its longest side exceeds ``max_image_size``.

    Args:
        max_image_size: target length of the longest side. A falsy value
            (None or 0) disables resizing.
        rgb: PIL image; ``rgb.size`` is read as ``(W, H)``.
        K: 3x3 intrinsics of ``rgb``, as passed to ``opencv_to_colmap_intrinsics``.

    Returns:
        ``(rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax))``, where:
        - ``rgb_tensor`` is the normalized (and possibly resized) image, permuted to H x W x C.
        - ``new_K`` holds the matching intrinsics.
        - ``to_orig_max`` and ``to_resize_max`` are 3x3 numpy scalings from
          resized to original pixel coordinates and back.
        - ``(HMax, WMax)`` is the output size.
    """
    W, H = rgb.size
    if not (max_image_size and max(W, H) > max_image_size):
        # resizing disabled or unnecessary: normalize only, identity transforms
        return ImgNorm(rgb).permute(1, 2, 0), K, np.eye(3), np.eye(3), (H, W)

    if W >= H:
        # landscape (or square): the width becomes max_image_size
        WMax = max_image_size
        HMax = int(H * (WMax / W))
    else:
        # portrait: the height becomes max_image_size
        HMax = max_image_size
        WMax = int(W * (HMax / H))

    transform = tvf.Compose([ImgNorm, tvf.Resize(size=[HMax, WMax])])
    rgb_tensor = transform(rgb).permute(1, 2, 0)
    to_orig_max = np.diag([W / WMax, H / HMax, 1.0])
    to_resize_max = np.diag([WMax / W, HMax / H, 1.0])

    # rescale the intrinsics in the COLMAP convention, then convert back to OpenCV
    new_K = opencv_to_colmap_intrinsics(K)
    new_K[0, :] *= WMax / W
    new_K[1, :] *= HMax / H
    new_K = colmap_to_opencv_intrinsics(new_K)
    return rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)
def _visualize_matches(query_rgb, map_rgb, matches_im_query, matches_im_map, n_viz):
    """Debug helper: draw ``n_viz`` evenly spaced matches between two images, side by side.

    Blocks until the matplotlib window is closed. Prints the number of matches
    and returns without plotting when there is none.

    Args:
        query_rgb, map_rgb: images convertible with ``np.array`` (H x W x C).
        matches_im_query, matches_im_map: per-match (x, y) pixel coordinates in
            the query / map image. May be empty lists.
        n_viz: number of matches to draw (> 0).
    """
    num_matches = len(matches_im_map)
    print(f'found {num_matches} matches')
    if num_matches == 0:
        # linspace(0, -1, n) would yield negative indices into an empty array
        return
    from matplotlib import pyplot as pl
    matches_im_query = np.asarray(matches_im_query)
    matches_im_map = np.asarray(matches_im_map)
    viz_imgs = [np.array(query_rgb), np.array(map_rgb)]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im_query = matches_im_query[match_idx_to_viz]
    viz_matches_im_map = matches_im_map[match_idx_to_viz]
    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    # pad the shorter image at the bottom so both can be concatenated horizontally
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    color_den = max(n_viz - 1, 1)  # avoid a division by zero when n_viz == 1
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / color_den), scalex=False, scaley=False)
    pl.show(block=True)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    if args.pixel_tol <= 0:
        # matching against the 3D map requires a positive pixel tolerance
        parser.error('--pixel_tol must be > 0')
    conf_thr = args.confidence_threshold
    device = args.device
    pnp_mode = args.pnp_mode
    reprojection_error = args.reprojection_error
    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio
    pnp_max_points = args.pnp_max_points
    viz_matches = args.viz_matches

    weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    fast_nn_params = dict(device=device, dist='dot', block_size=2**13)
    # NOTE(review): --dataset is evaluated as a Python expression (e.g. a dataset
    # constructor call resolved via the star import); only pass trusted input.
    dataset = eval(args.dataset)
    dataset.set_resolution(model)
    # longest input side of the network; loop-invariant
    maxdim = max(model.patch_embed.img_size)

    query_names = []
    poses_pred = []
    pose_errors = []
    angular_errors = []
    # tag of the matching configuration, used in cache paths and result labels
    params_str = f'tol_{args.pixel_tol}' + ("_c2f" if args.coarse_to_fine else '')
    if args.max_image_size is not None:
        params_str = params_str + f'_{args.max_image_size}'
    if args.coarse_to_fine and args.c2f_crop_with_homography:
        params_str = params_str + '_with_homography'

    for idx in tqdm(range(len(dataset))):
        views = dataset[idx]  # 0 is the query
        query_view = views[0]
        map_views = views[1:]
        query_names.append(query_view['image_name'])

        query_pts2d = []
        query_pts3d = []
        query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max, (HQ, WQ) = resize_image_to_max(
            args.max_image_size, query_view['rgb'], query_view['intrinsics'])

        # pairs of crops have the same resolution
        query_resolution = get_HW_resolution(HQ, WQ, maxdim=maxdim, patchsize=model.patch_embed.patch_size)
        for map_view in map_views:
            if args.output_dir is not None:
                cache_file = os.path.join(args.output_dir, 'matches', params_str,
                                          query_view['image_name'], map_view['image_name'] + '.npz')
            else:
                cache_file = None

            if cache_file is not None and os.path.isfile(cache_file):
                matches = np.load(cache_file)
                valid_pts3d = matches['valid_pts3d']
                matches_im_query = matches['matches_im_query']
                matches_im_map = matches['matches_im_map']
                matches_conf = matches['matches_conf']
            else:
                if args.coarse_to_fine and (maxdim < max(WQ, HQ)):
                    # coarse matching, using all points
                    _, coarse_matches_im0, coarse_matches_im1, _ = coarse_matching(query_view, map_view, model,
                                                                                   device, 0, fast_nn_params)
                    if viz_matches > 0:
                        _visualize_matches(query_view['rgb'], map_view['rgb'],
                                           coarse_matches_im0, coarse_matches_im1, viz_matches)

                    valid_all = map_view['valid']
                    pts3d = map_view['pts3d']
                    WM_full, HM_full = map_view['rgb'].size
                    map_rgb_tensor, map_K, map_to_orig_max, map_to_resize_max, (HM, WM) = resize_image_to_max(
                        args.max_image_size, map_view['rgb'], map_view['intrinsics'])
                    if WM_full != WM or HM_full != HM:
                        # the map image was downscaled: resample its valid 3D points at the new size
                        y_full, x_full = torch.where(valid_all)
                        pos2d_cv2 = torch.stack([x_full, y_full], dim=-1).cpu().numpy().astype(np.float64)
                        sparse_pts3d = pts3d[y_full, x_full].cpu().numpy()
                        _, _, pts3d_max, valid_max = rescale_points3d(
                            pos2d_cv2, sparse_pts3d, map_to_resize_max, HM, WM)
                        pts3d = torch.from_numpy(pts3d_max)
                        valid_all = torch.from_numpy(valid_max)

                    # coarse matches into the max_image_size frame used for cropping
                    coarse_matches_im0 = geotrf(query_to_resize_max, coarse_matches_im0, norm=True)
                    coarse_matches_im1 = geotrf(map_to_resize_max, coarse_matches_im1, norm=True)

                    crops1, crops2 = [], []
                    crops_v1, crops_p1 = [], []
                    to_orig1, to_orig2 = [], []
                    map_resolution = get_HW_resolution(HM, WM, maxdim=maxdim,
                                                       patchsize=model.patch_embed.patch_size)
                    # intrinsics are passed to crop() only for homography-rectified crops
                    map_crop_K = map_K if args.c2f_crop_with_homography else None
                    query_crop_K = query_K if args.c2f_crop_with_homography else None
                    for crop_q, crop_b, _pair_tag in select_pairs_of_crops(map_rgb_tensor,
                                                                           query_rgb_tensor,
                                                                           coarse_matches_im1,
                                                                           coarse_matches_im0,
                                                                           maxdim=maxdim,
                                                                           overlap=.5,
                                                                           forced_resolution=[map_resolution,
                                                                                              query_resolution]):
                        # per crop processing
                        c1, v1, p1, trf1 = crop(map_rgb_tensor, valid_all, pts3d, crop_q, map_crop_K)
                        c2, _, _, trf2 = crop(query_rgb_tensor, None, None, crop_b, query_crop_K)
                        crops1.append(c1)
                        crops2.append(c2)
                        crops_v1.append(v1)
                        crops_p1.append(p1)
                        to_orig1.append(trf1)
                        to_orig2.append(trf2)

                    if len(crops1) == 0 or len(crops2) == 0:
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = [], [], [], []
                    else:
                        crops1, crops2 = torch.stack(crops1), torch.stack(crops2)
                        if len(crops1.shape) == 3:
                            crops1, crops2 = crops1[None], crops2[None]
                        crops_v1 = torch.stack(crops_v1)
                        crops_p1 = torch.stack(crops_p1)
                        to_orig1, to_orig2 = torch.stack(to_orig1), torch.stack(to_orig2)
                        map_crop_view = dict(img=crops1.permute(0, 3, 1, 2),
                                             instance=['1'] * crops1.shape[0],
                                             valid=crops_v1, pts3d=crops_p1,
                                             to_orig=to_orig1)
                        query_crop_view = dict(img=crops2.permute(0, 3, 1, 2),
                                               instance=['2'] * crops2.shape[0],
                                               to_orig=to_orig2)

                        # inference and matching on every crop pair
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = fine_matching(
                            query_crop_view, map_crop_view, model, device,
                            args.max_batch_size, args.pixel_tol, fast_nn_params)
                        # from the max_image_size frame back to the original image frame
                        matches_im_query = geotrf(query_to_orig_max, matches_im_query, norm=True)
                        matches_im_map = geotrf(map_to_orig_max, matches_im_map, norm=True)
                else:
                    # single-scale matching, using only valid 2d points
                    valid_pts3d, matches_im_query, matches_im_map, matches_conf = coarse_matching(
                        query_view, map_view, model, device, args.pixel_tol, fast_nn_params)

                if cache_file is not None:
                    mkdir_for(cache_file)
                    np.savez(cache_file, valid_pts3d=valid_pts3d, matches_im_query=matches_im_query,
                             matches_im_map=matches_im_map, matches_conf=matches_conf)

            # apply conf: keep matches whose confidence reaches the threshold
            if len(matches_conf) > 0:
                mask = matches_conf >= conf_thr
                valid_pts3d = valid_pts3d[mask]
                matches_im_query = matches_im_query[mask]
                matches_im_map = matches_im_map[mask]
                matches_conf = matches_conf[mask]

            if viz_matches > 0:
                _visualize_matches(query_view['rgb'], map_view['rgb'],
                                   matches_im_query, matches_im_map, viz_matches)

            if len(valid_pts3d) > 0:
                query_pts3d.append(valid_pts3d)
                query_pts2d.append(matches_im_query)

        if len(query_pts2d) == 0:
            success = False
            pr_querycam_to_world = None
        else:
            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)
            query_pts3d = np.concatenate(query_pts3d, axis=0)
            if len(query_pts2d) > pnp_max_points:
                # random subset to bound the PnP cost
                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)
                query_pts3d = query_pts3d[idxs]
                query_pts2d = query_pts2d[idxs]

            W, H = query_view['rgb'].size
            if reprojection_error_diag_ratio is not None:
                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)
            else:
                reprojection_error_img = reprojection_error
            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,
                                                    query_view['intrinsics'], query_view['distortion'],
                                                    pnp_mode, reprojection_error_img, img_size=[W, H])

        if not success:
            abs_transl_error = float('inf')
            abs_angular_error = float('inf')
        else:
            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])

        pose_errors.append(abs_transl_error)
        angular_errors.append(abs_angular_error)
        poses_pred.append(pr_querycam_to_world)

    xp_label = params_str + f'_conf_{conf_thr}'
    if args.output_label:
        xp_label = args.output_label + "_" + xp_label
    if reprojection_error_diag_ratio is not None:
        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'
    else:
        xp_label = xp_label + f'_reproj_err_{reprojection_error}'
    export_results(args.output_dir, xp_label, query_names, poses_pred)
    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)
    print(out_string)