import pycolmap import numpy as np import torch import torch.nn.functional as F from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri from src.model.encoder.vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map from lightglue import ALIKED, SuperPoint, SIFT from src.utils.tensor_to_pycolmap import batch_matrix_to_pycolmap, pycolmap_to_batch_matrix _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] def generate_rank_by_dino( images, query_frame_num, image_size=518, model_name="dinov2_vitb14_reg", device="cuda", spatial_similarity=True ): """ Generate a ranking of frames using DINO ViT features. Args: images: Tensor of shape (S, 3, H, W) with values in range [0, 1] query_frame_num: Number of frames to select image_size: Size to resize images to before processing model_name: Name of the DINO model to use device: Device to run the model on spatial_similarity: Whether to use spatial token similarity or CLS token similarity Returns: List of frame indices ranked by their representativeness """ dino_v2_model = torch.hub.load('facebookresearch/dinov2', model_name) dino_v2_model.eval() dino_v2_model = dino_v2_model.to(device) resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1) resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1) images_resnet_norm = (images - resnet_mean) / resnet_std with torch.no_grad(): frame_feat = dino_v2_model(images_resnet_norm, is_training=True) if spatial_similarity: frame_feat = frame_feat["x_norm_patchtokens"] frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) # Compute the similarity matrix frame_feat_norm = frame_feat_norm.permute(1, 0, 2) similarity_matrix = torch.bmm( frame_feat_norm, frame_feat_norm.transpose(-1, -2) ) similarity_matrix = similarity_matrix.mean(dim=0) else: frame_feat = frame_feat["x_norm_clstoken"] frame_feat_norm = F.normalize(frame_feat, p=2, dim=1) similarity_matrix = torch.mm( frame_feat_norm, frame_feat_norm.transpose(-1, -2) ) distance_matrix = 100 - similarity_matrix.clone() # Ignore self-pairing similarity_matrix.fill_diagonal_(-100) similarity_sum = similarity_matrix.sum(dim=1) # Find the most common frame most_common_frame_index = torch.argmax(similarity_sum).item() # Conduct FPS sampling starting from the most common frame fps_idx = farthest_point_sampling( distance_matrix, query_frame_num, most_common_frame_index ) return fps_idx def farthest_point_sampling( distance_matrix, num_samples, most_common_frame_index=0 ): """ Farthest point sampling algorithm to select diverse frames. Args: distance_matrix: Matrix of distances between frames num_samples: Number of frames to select most_common_frame_index: Index of the first frame to select Returns: List of selected frame indices """ distance_matrix = distance_matrix.clamp(min=0) N = distance_matrix.size(0) # Initialize with the most common frame selected_indices = [most_common_frame_index] check_distances = distance_matrix[selected_indices] while len(selected_indices) < num_samples: # Find the farthest point from the current set of selected points farthest_point = torch.argmax(check_distances) selected_indices.append(farthest_point.item()) check_distances = distance_matrix[farthest_point] # Mark already selected points to avoid selecting them again check_distances[selected_indices] = 0 # Break if all points have been selected if len(selected_indices) == N: break return selected_indices def calculate_index_mappings(query_index, S, device=None): """ Construct an order that switches [query_index] and [0] so that the content of query_index would be placed at [0]. Args: query_index: Index to swap with 0 S: Total number of elements device: Device to place the tensor on Returns: Tensor of indices with the swapped order """ new_order = torch.arange(S) new_order[0] = query_index new_order[query_index] = 0 if device is not None: new_order = new_order.to(device) return new_order def switch_tensor_order(tensors, order, dim=1): """ Reorder tensors along a specific dimension according to the given order. Args: tensors: List of tensors to reorder order: Tensor of indices specifying the new order dim: Dimension along which to reorder Returns: List of reordered tensors """ return [ torch.index_select(tensor, dim, order) if tensor is not None else None for tensor in tensors ] def predict_track(model, images, query_points, dtype=torch.bfloat16, use_tf32_for_track=True, iters=4): """ Predict tracks for query points across frames. Args: model: VGGT model images: Tensor of images of shape (S, 3, H, W) query_points: Query points to track dtype: Data type for computation use_tf32_for_track: Whether to use TF32 precision for tracking iters: Number of iterations for tracking Returns: Predicted tracks, visibility scores, and confidence scores """ with torch.no_grad(): with torch.cuda.amp.autocast(dtype=dtype): images = images[None] # add batch dimension aggregated_tokens_list, ps_idx = model.aggregator(images) if not use_tf32_for_track: track_list, vis_score, conf_score = model.track_head( aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters ) if use_tf32_for_track: with torch.cuda.amp.autocast(enabled=False): track_list, vis_score, conf_score = model.track_head( aggregated_tokens_list, images, ps_idx, query_points=query_points, iters=iters ) pred_track = track_list[-1] return pred_track.squeeze(0), vis_score.squeeze(0), conf_score.squeeze(0) def initialize_feature_extractors(max_query_num, det_thres, extractor_method="aliked", device="cuda"): """ Initialize feature extractors that can be reused based on a method string. Args: max_query_num: Maximum number of keypoints to extract det_thres: Detection threshold for keypoint extraction extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift") device: Device to run extraction on Returns: Dictionary of initialized extractors """ extractors = {} methods = extractor_method.lower().split('+') active_extractors = len(methods) for method in methods: method = method.strip() if method == "aliked": aliked_max_points = max_query_num // active_extractors aliked_extractor = ALIKED(max_num_keypoints=aliked_max_points, detection_threshold=det_thres) extractors['aliked'] = aliked_extractor.to(device).eval() elif method == "sp": sp_max_points = max_query_num // active_extractors sp_extractor = SuperPoint(max_num_keypoints=sp_max_points, detection_threshold=det_thres) extractors['sp'] = sp_extractor.to(device).eval() elif method == "sift": sift_max_points = max_query_num // active_extractors sift_extractor = SIFT(max_num_keypoints=sift_max_points) extractors['sift'] = sift_extractor.to(device).eval() else: print(f"Warning: Unknown feature extractor '{method}', ignoring.") if not extractors: print(f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default.") aliked_extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres) extractors['aliked'] = aliked_extractor.to(device).eval() return extractors def extract_keypoints(query_image, extractors): """ Extract keypoints using pre-initialized feature extractors. Args: query_image: Input image tensor (3xHxW, range [0, 1]) extractors: Dictionary of initialized extractors Returns: Tensor of keypoint coordinates (1xNx2) """ query_points_round = None with torch.no_grad(): for extractor_name, extractor in extractors.items(): query_points_data = extractor.extract(query_image) extractor_points = query_points_data["keypoints"].round() if query_points_round is not None: query_points_round = torch.cat([query_points_round, extractor_points], dim=1) else: query_points_round = extractor_points return query_points_round def run_vggt_with_ba(model, images, image_names=None, dtype=torch.bfloat16, max_query_num=2048, det_thres=0.005, query_frame_num=3, extractor_method="aliked+sp+sift", max_reproj_error=4, shared_camera=True, ): """ Run VGGT with bundle adjustment for pose estimation. Args: model: VGGT model images: Tensor of images of shape (S, 3, H, W) image_names: Optional list of image names dtype: Data type for computation Returns: Predicted extrinsic camera parameters TODO: - [ ] Use VGGT's vit instead of dinov2 for rank generation """ device = images.device frame_num = images.shape[0] # TODO: use vggt's vit instead of dinov2 # Select representative frames for feature extraction query_frame_indexes = generate_rank_by_dino( images, query_frame_num, image_size=518, model_name="dinov2_vitb14_reg", device=device, spatial_similarity=False ) # Add the first image to the front if not already present if 0 in query_frame_indexes: query_frame_indexes.remove(0) query_frame_indexes = [0, *query_frame_indexes] # Get initial pose and depth predictions with torch.no_grad(), torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16): aggregated_tokens_list, patch_start_idx = model.aggregator(images, intermediate_layer_idx=model.cfg.intermediate_layer_idx) with torch.cuda.amp.autocast(enabled=False): fp32_tokens = [token.float() for token in aggregated_tokens_list] pred_all_pose_enc = model.camera_head(fp32_tokens)[-1] pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:]) pred_extrinsic = pred_all_extrinsic[0] pred_intrinsic = pred_all_intrinsic[0] depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx) world_points = unproject_depth_map_to_point_map(depth_map, pred_extrinsic, pred_intrinsic) world_points = torch.from_numpy(world_points).to(device) world_points_conf = depth_conf.to(device) torch.cuda.empty_cache() # Lists to store predictions pred_tracks = [] pred_vis_scores = [] pred_conf_scores = [] pred_world_points = [] pred_world_points_conf = [] # Initialize feature extractors extractors = initialize_feature_extractors(max_query_num, det_thres, extractor_method, device) # Process each query frame for query_index in query_frame_indexes: query_image = images[query_index] query_points_round = extract_keypoints(query_image, extractors) # Reorder images to put query image first reorder_index = calculate_index_mappings(query_index, frame_num, device=device) reorder_images = switch_tensor_order([images], reorder_index, dim=0)[0] # Track points across frames reorder_tracks, reorder_vis_score, reorder_conf_score = predict_track( model, reorder_images, query_points_round, dtype=dtype, use_tf32_for_track=True, iters=4 ) # Restore original order pred_track, pred_vis, pred_score = switch_tensor_order( [reorder_tracks, reorder_vis_score, reorder_conf_score], reorder_index, dim=0 ) pred_tracks.append(pred_track) pred_vis_scores.append(pred_vis) pred_conf_scores.append(pred_score) # Get corresponding 3D points query_points_round_long = query_points_round.squeeze(0).long() query_world_points = world_points[query_index][ query_points_round_long[:, 1], query_points_round_long[:, 0] ] query_world_points_conf = world_points_conf[query_index][ query_points_round_long[:, 1], query_points_round_long[:, 0] ] pred_world_points.append(query_world_points) pred_world_points_conf.append(query_world_points_conf) # Concatenate prediction lists pred_tracks = torch.cat(pred_tracks, dim=1) pred_vis_scores = torch.cat(pred_vis_scores, dim=1) pred_conf_scores = torch.cat(pred_conf_scores, dim=1) pred_world_points = torch.cat(pred_world_points, dim=0) pred_world_points_conf = torch.cat(pred_world_points_conf, dim=0) # Filter points by confidence filtered_flag = pred_world_points_conf > 1.5 if filtered_flag.sum() > 1024: # well if the number of points is too small, we will not filter pred_world_points = pred_world_points[filtered_flag] pred_world_points_conf = pred_world_points_conf[filtered_flag] pred_tracks = pred_tracks[:, filtered_flag] pred_vis_scores = pred_vis_scores[:, filtered_flag] pred_conf_scores = pred_conf_scores[:, filtered_flag] torch.cuda.empty_cache() # Bundle adjustment parameters S, _, H, W = images.shape image_size = torch.tensor([W, H], dtype=pred_tracks.dtype, device=device) # Run bundle adjustment reconstruction = batch_matrix_to_pycolmap( pred_world_points, pred_extrinsic, pred_intrinsic, pred_tracks, image_size, max_reproj_error=max_reproj_error, shared_camera=shared_camera ) ba_options = pycolmap.BundleAdjustmentOptions() pycolmap.bundle_adjustment(reconstruction, ba_options) _, updated_extrinsic, _, _ = pycolmap_to_batch_matrix( reconstruction, device=device, camera_type="SIMPLE_PINHOLE" ) return updated_extrinsic