import sys
sys.path.insert(0, 'thirdparty/DROID-SLAM/droid_slam')
sys.path.insert(0, 'thirdparty/DROID-SLAM')

from tqdm import tqdm
import numpy as np
import torch
import os
import argparse
from PIL import Image
import cv2
from glob import glob

from droid import Droid
from torch.multiprocessing import Process

import evo
from evo.core.trajectory import PoseTrajectory3D
from evo.tools import file_interface
from evo.core import sync
import evo.main_ape as main_ape
from evo.core.metrics import PoseRelation

from pycocotools import mask as masktool
from torchvision.transforms import Resize
# Some default settings for DROID-SLAM
parser = argparse.ArgumentParser()
parser.add_argument("--imagedir", type=str, help="path to image directory")
parser.add_argument("--calib", type=str, help="path to calibration file")
parser.add_argument("--t0", default=0, type=int, help="starting frame")
parser.add_argument("--stride", default=1, type=int, help="frame stride")
parser.add_argument("--weights", default="weights/external/droid.pth")
parser.add_argument("--buffer", type=int, default=512)
parser.add_argument("--image_size", default=[240, 320])
parser.add_argument("--disable_vis", action="store_true")
parser.add_argument("--beta", type=float, default=0.3, help="weight for translation / rotation components of flow")
parser.add_argument("--filter_thresh", type=float, default=2.4, help="how much motion before considering new keyframe")
parser.add_argument("--warmup", type=int, default=8, help="number of warmup frames")
parser.add_argument("--keyframe_thresh", type=float, default=4.0, help="threshold to create a new keyframe")
parser.add_argument("--frontend_thresh", type=float, default=16.0, help="add edges between frames within this distance")
parser.add_argument("--frontend_window", type=int, default=25, help="frontend optimization window")
parser.add_argument("--frontend_radius", type=int, default=2, help="force edges between frames within radius")
parser.add_argument("--frontend_nms", type=int, default=1, help="non-maximal suppression of edges")
parser.add_argument("--backend_thresh", type=float, default=22.0)
parser.add_argument("--backend_radius", type=int, default=2)
parser.add_argument("--backend_nms", type=int, default=3)
parser.add_argument("--upsample", action="store_true")
parser.add_argument("--reconstruction_path", help="path to saved reconstruction")

args = parser.parse_args([])
args.stereo = False
args.upsample = True
args.disable_vis = True

torch.multiprocessing.set_start_method('spawn')
def est_calib(imagedir):
    """ Roughly estimate intrinsics by image dimensions """
    if isinstance(imagedir, list):
        imgfiles = imagedir
    else:
        imgfiles = sorted(glob(f'{imagedir}/*.jpg'))
    image = cv2.imread(imgfiles[0])

    h0, w0, _ = image.shape
    focal = np.max([h0, w0])
    cx, cy = w0 / 2., h0 / 2.
    calib = [focal, focal, cx, cy]
    return calib
def get_dimention(imagedir):
    """ Get proper image dimension for DROID """
    if isinstance(imagedir, list):
        imgfiles = imagedir
    else:
        imgfiles = sorted(glob(f'{imagedir}/*.jpg'))
    image = cv2.imread(imgfiles[0])

    # Rescale to roughly 384x512 total area, then crop to multiples of 8
    h0, w0, _ = image.shape
    h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
    w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))
    image = cv2.resize(image, (w1, h1))
    image = image[:h1 - h1 % 8, :w1 - w1 % 8]
    H, W, _ = image.shape
    return H, W
def image_stream(imagedir, calib, stride, max_frame=None):
    """ Image generator for DROID """
    fx, fy, cx, cy = calib[:4]
    K = np.eye(3)
    K[0, 0] = fx
    K[0, 2] = cx
    K[1, 1] = fy
    K[1, 2] = cy

    if isinstance(imagedir, list):
        image_list = imagedir
    else:
        image_list = sorted(glob(f'{imagedir}/*.jpg'))
    image_list = image_list[::stride]
    if max_frame is not None:
        image_list = image_list[:max_frame]

    for t, imfile in enumerate(image_list):
        image = cv2.imread(imfile)
        if len(calib) > 4:
            # extra calibration entries are distortion coefficients
            image = cv2.undistort(image, K, calib[4:])

        h0, w0, _ = image.shape
        h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0)))
        w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0)))
        image = cv2.resize(image, (w1, h1))
        image = image[:h1 - h1 % 8, :w1 - w1 % 8]
        image = torch.as_tensor(image).permute(2, 0, 1)

        intrinsics = torch.as_tensor([fx, fy, cx, cy])
        intrinsics[0::2] *= (w1 / w0)
        intrinsics[1::2] *= (h1 / h0)

        yield t, image[None], intrinsics
def run_slam(imagedir, masks, calib=None, depth=None, stride=1,
             filter_thresh=2.4, disable_vis=True):
    """ Masked DROID-SLAM """
    droid = None
    depth = None  # depth input is currently not used
    args.filter_thresh = filter_thresh
    args.disable_vis = disable_vis

    masks = masks[::stride]
    img_msks, conf_msks = preprocess_masks(imagedir, masks)

    if calib is None:
        calib = est_calib(imagedir)

    for (t, image, intrinsics) in tqdm(image_stream(imagedir, calib, stride)):
        if droid is None:
            args.image_size = [image.shape[2], image.shape[3]]
            droid = Droid(args)

        # Zero out masked pixels before tracking
        img_msk = img_msks[t]
        conf_msk = conf_msks[t]
        image = image * (img_msk < 0.5)
        # cv2.imwrite('debug.png', image[0].permute(1, 2, 0).numpy())

        droid.track(t, image, intrinsics=intrinsics, depth=depth, mask=conf_msk)

    traj = droid.terminate(image_stream(imagedir, calib, stride))

    return droid, traj
def run_droid_slam(imagedir, calib=None, depth=None, stride=1,
                   filter_thresh=2.4, disable_vis=True):
    """ Standard DROID-SLAM (no masking) """
    droid = None
    depth = None  # depth input is currently not used
    args.filter_thresh = filter_thresh
    args.disable_vis = disable_vis

    if calib is None:
        calib = est_calib(imagedir)

    for (t, image, intrinsics) in tqdm(image_stream(imagedir, calib, stride)):
        if droid is None:
            args.image_size = [image.shape[2], image.shape[3]]
            droid = Droid(args)
        droid.track(t, image, intrinsics=intrinsics, depth=depth)

    traj = droid.terminate(image_stream(imagedir, calib, stride))

    return droid, traj
def eval_slam(traj_est, cam_t, cam_q, return_traj=True, correct_scale=False,
              align=True, align_origin=False):
    """ Evaluation for SLAM """
    tstamps = np.array([i for i in range(len(traj_est))], dtype=np.float32)

    traj_est = PoseTrajectory3D(
        positions_xyz=traj_est[:, :3],
        orientations_quat_wxyz=traj_est[:, 3:],
        timestamps=tstamps)

    traj_ref = PoseTrajectory3D(
        positions_xyz=cam_t.copy(),
        orientations_quat_wxyz=cam_q.copy(),
        timestamps=tstamps)

    traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est)

    result = main_ape.ape(traj_ref, traj_est, est_name='traj',
                          pose_relation=PoseRelation.translation_part,
                          align=align, align_origin=align_origin,
                          correct_scale=correct_scale)
    stats = result.stats

    if return_traj:
        return stats, traj_ref, traj_est
    return stats
def test_slam(imagedir, img_msks, conf_msks, calib, stride=10, max_frame=50):
    """ Shorter SLAM step to test reprojection error """
    args = parser.parse_args([])
    args.stereo = False
    args.upsample = False
    args.disable_vis = True
    args.frontend_window = 10
    args.frontend_thresh = 10

    droid = None
    for (t, image, intrinsics) in image_stream(imagedir, calib, stride, max_frame):
        if droid is None:
            args.image_size = [image.shape[2], image.shape[3]]
            droid = Droid(args)

        img_msk = img_msks[t]
        conf_msk = conf_msks[t]
        image = image * (img_msk < 0.5)
        droid.track(t, image, intrinsics=intrinsics, mask=conf_msk)

    reprojection_error = droid.compute_error()
    del droid

    return reprojection_error
def search_focal_length(img_folder, masks, stride=10, max_frame=50,
                        low=500, high=1500, step=100):
    """ Search for a good focal length by SLAM reprojection error """
    masks = masks[::stride]
    masks = masks[:max_frame]
    img_msks, conf_msks = preprocess_masks(img_folder, masks)

    # default estimate
    calib = np.array(est_calib(img_folder))
    best_focal = calib[0]
    best_err = test_slam(img_folder, img_msks, conf_msks,
                         stride=stride, calib=calib, max_frame=max_frame)

    # search based on slam reprojection error
    for focal in range(low, high, step):
        calib[:2] = focal
        err = test_slam(img_folder, img_msks, conf_msks,
                        stride=stride, calib=calib, max_frame=max_frame)
        if err < best_err:
            best_err = err
            best_focal = focal

    print('Best focal length:', best_focal)
    return best_focal
def preprocess_masks(img_folder, masks):
    """ Resize masks for masked droid """
    H, W = get_dimention(img_folder)
    resize_1 = Resize((H, W), antialias=True)
    resize_2 = Resize((H // 8, W // 8), antialias=True)

    # image-resolution masks, resized in chunks to limit memory
    img_msks = []
    for i in range(0, len(masks), 500):
        m = resize_1(masks[i:i+500])
        img_msks.append(m)
    img_msks = torch.cat(img_msks)

    # 1/8-resolution masks for the confidence weights
    conf_msks = []
    for i in range(0, len(masks), 500):
        m = resize_2(masks[i:i+500])
        conf_msks.append(m)
    conf_msks = torch.cat(conf_msks)

    return img_msks, conf_msks