Spaces:
Running
Running
| import torch | |
| import cv2 | |
| import numpy as np | |
| from collections import OrderedDict | |
| from loguru import logger | |
| from kornia.geometry.epipolar import numeric | |
| from kornia.geometry.conversions import convert_points_to_homogeneous | |
| # --- METRICS --- | |
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
    """Angular errors (in degrees) of an estimated relative pose.

    Args:
        T_0to1: reference transform; only its top-left 3x3 rotation block and
            the first three entries of its fourth column (translation) are read.
        R: estimated 3x3 rotation matrix.
        t: estimated translation vector, length 3.
        ignore_gt_t_thr (float): when the reference translation norm is below
            this value, the translation error is reported as 0.

    Returns:
        (t_err, R_err): angle between translation directions (folded into
        [0, 90] because the sign of t is ambiguous) and the rotation angle of
        R^T @ R_gt.
    """
    gt_rot = T_0to1[:3, :3]
    gt_trans = T_0to1[:3, 3]
    gt_trans_norm = np.linalg.norm(gt_trans)

    # Translation: angle between the two direction vectors. An essential matrix
    # only fixes t up to sign, so theta and 180 - theta are treated as equal.
    cos_t = np.dot(t, gt_trans) / (np.linalg.norm(t) * gt_trans_norm)
    angle_t = np.rad2deg(np.arccos(np.clip(cos_t, -1.0, 1.0)))
    t_err = np.minimum(angle_t, 180 - angle_t)
    if gt_trans_norm < ignore_gt_t_thr:
        # near-pure rotation: the translation direction is ill-defined
        t_err = 0

    # Rotation: angle of the residual rotation, recovered from its trace.
    # Clipping guards against numerical drift just outside [-1, 1].
    cos_r = np.clip((np.trace(np.dot(R.T, gt_rot)) - 1) / 2, -1.0, 1.0)
    R_err = np.rad2deg(np.arccos(cos_r))
    return t_err, R_err
def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
    """Squared symmetric epipolar distance of point correspondences.

    Pixel coordinates are mapped to normalized camera coordinates via
    (p - principal_point) / focal_length. The squared algebraic residual
    (x1^T E x0)^2 is then divided by the squared norm of the first two
    components of each epipolar line, in both images, and the two terms summed.
    This can be seen as a biased estimation of the reprojection error.

    Args:
        pts0 (torch.Tensor): [N, 2] pixel coordinates in image 0
        pts1 (torch.Tensor): [N, 2] pixel coordinates in image 1
        E (torch.Tensor): [3, 3] matrix used as x1^T @ E @ x0
        K0 (torch.Tensor): intrinsic matrix of image 0 (3x3 assumed)
        K1 (torch.Tensor): intrinsic matrix of image 1 (3x3 assumed)
    Returns:
        torch.Tensor: [N] squared distances
    """

    def _normalized_homogeneous(pts, K):
        principal = K[[0, 1], [2, 2]][None]  # (cx, cy)
        focal = K[[0, 1], [0, 1]][None]  # (fx, fy)
        return convert_points_to_homogeneous((pts - principal) / focal)

    x0 = _normalized_homogeneous(pts0, K0)
    x1 = _normalized_homogeneous(pts1, K1)

    line_in_1 = x0 @ E.T  # epipolar lines of x0 in image 1, [N, 3]
    line_in_0 = x1 @ E  # epipolar lines of x1 in image 0, [N, 3]
    residual = torch.sum(x1 * line_in_1, -1)  # [N]

    inv_sq_norm_1 = 1.0 / (line_in_1[:, 0] ** 2 + line_in_1[:, 1] ** 2)
    inv_sq_norm_0 = 1.0 / (line_in_0[:, 0] ** 2 + line_in_0[:, 1] ** 2)
    return residual**2 * (inv_sq_norm_1 + inv_sq_norm_0)
def compute_symmetrical_epipolar_errors(data):
    """Per-match symmetric epipolar errors for every pair in a batch.

    Builds E = [t]_x @ R from data["T_0to1"] (presumably the ground-truth
    relative pose; confirm with the data pipeline). It then evaluates
    symmetric_epipolar_distance on the matches assigned to each batch index
    via data["m_bids"], and concatenates the results in batch order.

    Update:
        data (dict):{"epi_errs": [M]}
    """
    T_0to1 = data["T_0to1"]
    t_skew = numeric.cross_product_matrix(T_0to1[:, :3, 3])
    E_batch = t_skew @ T_0to1[:, :3, :3]

    batch_ids = data["m_bids"]
    kpts0, kpts1 = data["mkpts0_f"], data["mkpts1_f"]
    K0, K1 = data["K0"], data["K1"]

    def _pair_errors(b):
        # select the matches that belong to pair b
        sel = batch_ids == b
        return symmetric_epipolar_distance(
            kpts0[sel], kpts1[sel], E_batch[b], K0[b], K1[b]
        )

    epi_errs = torch.cat([_pair_errors(b) for b in range(t_skew.size(0))], dim=0)
    data.update({"epi_errs": epi_errs})
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
    """Estimate the relative pose between two views from matched keypoints.

    Keypoints are normalized with the intrinsics. An essential matrix is then
    fitted with OpenCV RANSAC, and among the returned candidate(s) the one with
    the largest inlier count from cv2.recoverPose is kept.

    Args:
        kpts0 (np.ndarray): [N, 2] pixel coordinates in image 0
        kpts1 (np.ndarray): [N, 2] pixel coordinates in image 1
        K0 (np.ndarray): intrinsic matrix of image 0 (3x3 assumed)
        K1 (np.ndarray): intrinsic matrix of image 1 (3x3 assumed)
        thresh (float): RANSAC threshold in pixels
        conf (float): RANSAC confidence
    Returns:
        (R, t, inliers): R [3, 3], t [3] as returned by cv2.recoverPose, and a
        boolean inlier mask [N]. Returns None when fewer than 5 matches are given
        or no essential matrix is found.
    """
    if len(kpts0) < 5:
        return None
    # normalize keypoints to camera coordinates: (p - c) / f
    kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
    # Convert the pixel threshold to normalized units using the mean of all four
    # focal lengths (fx0, fy0, fx1, fy1). The old list repeated (fx0, fy1) twice.
    ransac_thr = thresh / np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])
    # compute pose with cv2
    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
    )
    if E is None:
        logger.warning("E is None while trying to recover pose.")
        return None
    # recover pose from E; E may stack several 3x3 candidates vertically
    best_num_inliers = 0
    ret = None
    for _E in np.split(E, len(E) // 3):
        # recoverPose treats `mask` as input/output, so give each candidate its
        # own copy of the RANSAC mask instead of the previous candidate's result.
        cand_mask = mask.copy()
        # NOTE(review): 1e9 appears intended as recoverPose's distanceThresh;
        # confirm which overload the binding selects (4 outputs are unpacked).
        n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=cand_mask)
        if n > best_num_inliers:
            ret = (R, t[:, 0], cand_mask.ravel() > 0)
            best_num_inliers = n
    return ret
def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999):
    """Estimate the relative pose of every pair in a batch and record its errors.

    For each batch index, the matches selected by data["m_bids"] are passed to
    estimate_pose, and the result is compared with data["T_0to1"] using
    relative_pose_error. Pairs whose pose cannot be estimated get infinite
    errors and an empty boolean inlier mask.

    Args:
        data (dict): batch with "m_bids", "mkpts0_f", "mkpts1_f", "K0", "K1" and
            "T_0to1" tensors (moved to CPU / NumPy here).
        config: optional config; when given, TRAINER.RANSAC_PIXEL_THR and
            TRAINER.RANSAC_CONF override ransac_thr and ransac_conf.
        ransac_thr (float): RANSAC threshold in pixels used when config is None.
        ransac_conf (float): RANSAC confidence used when config is None.

    Update:
        data (dict):{
            "R_errs" List[float]: [N]
            "t_errs" List[float]: [N]
            "inliers" List[np.ndarray]: [N]
        }
    """
    pixel_thr = (
        config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr
    )  # 0.5
    conf = config.TRAINER.RANSAC_CONF if config is not None else ransac_conf  # 0.99999
    data.update({"R_errs": [], "t_errs": [], "inliers": []})

    m_bids = data["m_bids"].cpu().numpy()
    pts0 = data["mkpts0_f"].cpu().numpy()
    pts1 = data["mkpts1_f"].cpu().numpy()
    K0 = data["K0"].cpu().numpy()
    K1 = data["K1"].cpu().numpy()
    T_0to1 = data["T_0to1"].cpu().numpy()

    for bs in range(K0.shape[0]):
        mask = m_bids == bs
        ret = estimate_pose(
            pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf
        )
        if ret is None:
            # failed estimation is recorded as an infinite error
            data["R_errs"].append(np.inf)
            data["t_errs"].append(np.inf)
            # `np.bool` was removed in NumPy 1.24 (AttributeError); use the builtin.
            data["inliers"].append(np.array([], dtype=bool))
            continue

        R, t, inliers = ret
        t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
        data["R_errs"].append(R_err)
        data["t_errs"].append(t_err)
        data["inliers"].append(inliers)
| # --- METRIC AGGREGATION --- | |
def error_auc(errors, thresholds=(5, 10, 20)):
    """Normalized area under the cumulative error (recall) curve.

    For each threshold ``thr``, integrates recall(error) over [0, thr] with the
    trapezoidal rule and divides by ``thr``. For non-negative errors each value
    lies in [0, 1].

    Args:
        errors (list): [N,] per-item errors (e.g. angular pose errors)
        thresholds (list): positive integration limits. Previously this argument
            was silently overwritten with [5, 10, 20]; it is now honored, and
            (5, 10, 20) is the default.
    Returns:
        dict: {"auc@<thr>": auc} for each threshold, in the given order.
    """
    thresholds = list(thresholds)  # iterated twice below
    # Prepend 0 so the recall curve starts at the origin.
    errors = [0] + sorted(list(errors))
    recall = list(np.linspace(0, 1, len(errors)))
    # np.trapz is deprecated in NumPy 2.0 in favor of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    aucs = []
    for thr in thresholds:
        last_index = np.searchsorted(errors, thr)
        # extend the curve horizontally up to thr
        y = recall[:last_index] + [recall[last_index - 1]]
        x = errors[:last_index] + [thr]
        aucs.append(trapezoid(y, x) / thr)
    return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)}
def epidist_prec(errors, thresholds, ret_dict=False):
    """Mean per-pair fraction of matches whose error is below each threshold.

    Args:
        errors: iterable of per-pair error arrays.
        thresholds (list): error thresholds.
        ret_dict (bool): return {"prec@<thr>": prec} instead of a list.
    Returns:
        list or dict of precisions, one per threshold. A pair with no matches
        contributes 0, and an empty ``errors`` gives 0 for every threshold.
    """

    def _precision_at(thr):
        per_pair = [
            np.mean(hit) if len(hit) > 0 else 0
            for hit in (errs < thr for errs in errors)
        ]
        return np.mean(per_pair) if len(per_pair) > 0 else 0

    precs = [_precision_at(thr) for thr in thresholds]
    if not ret_dict:
        return precs
    return {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
def aggregate_metrics(metrics, epi_err_thr=5e-4):
    """Aggregate metrics for the whole dataset:
    (This method should be called once per dataset)
    1. AUC of the pose error (angular) at the thresholds [5, 10, 20], where each
       item's pose error is max(R_err, t_err).
    2. Mean matching precision at the threshold epi_err_thr
       (e.g. 5e-4 for ScanNet, 1e-4 for MegaDepth).

    Args:
        metrics (dict): per-item lists "identifiers", "R_errs", "t_errs" and
            "epi_errs", aligned by index.
        epi_err_thr (float): epipolar-error threshold for matching precision.
    Returns:
        dict: merged output of error_auc and epidist_prec(..., ret_dict=True).
    """
    # Filter duplicates. For a repeated identifier the index of its LAST
    # occurrence is kept; order follows each identifier's first appearance.
    unq_ids = OrderedDict(
        (iden, idx) for idx, iden in enumerate(metrics["identifiers"])
    )
    unq_ids = list(unq_ids.values())
    logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...")

    # pose auc
    angular_thresholds = [5, 10, 20]
    pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[
        unq_ids
    ]
    aucs = error_auc(pose_errors, angular_thresholds)  # (auc@5, auc@10, auc@20)

    # Matching precision. Select items directly rather than coercing to an
    # object ndarray: equal-length error arrays would otherwise become a 2-D
    # object matrix, and some shape mixes raise ValueError.
    dist_thresholds = [epi_err_thr]
    epi_errs = [metrics["epi_errs"][idx] for idx in unq_ids]
    precs = epidist_prec(epi_errs, dist_thresholds, True)  # (prec@err_thr)
    return {**aucs, **precs}