# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from .dpt_head import DPTHead
from .track_modules.base_track_predictor import BaseTrackerPredictor


class TrackHead(nn.Module):
    """
    Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.

    A ``DPTHead`` (in feature-only mode) turns backbone tokens into feature
    maps, which are then handed to a ``BaseTrackerPredictor`` together with
    the query points and the number of refinement iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone
                (forwarded to ``DPTHead``).
            patch_size (int): Size of image patches used in the vision
                transformer (stored on the module and forwarded to ``DPTHead``).
            features (int): Number of feature channels produced by the feature
                extractor; also used as the tracker's ``latent_dim``.
            iters (int): Default number of refinement iterations used by
                ``forward`` when the caller does not pass ``iters``.
            predict_conf (bool): Whether the tracker should predict confidence
                scores (forwarded to ``BaseTrackerPredictor``).
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels
                (forwarded to ``BaseTrackerPredictor``).
            corr_radius (int): Radius for correlation computation
                (forwarded to ``BaseTrackerPredictor``).
            hidden_size (int): Size of hidden layers in the tracker network
                (forwarded to ``BaseTrackerPredictor``).
        """
        super().__init__()

        self.patch_size = patch_size

        # Feature extractor based on the DPT architecture: converts backbone
        # tokens into feature maps consumed by the tracker.
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,  # Only output features, no activation
            down_ratio=2,  # Reduces spatial dimensions by factor of 2
            pos_embed=False,
        )

        # Tracker module that consumes the feature maps and produces the
        # coordinate / visibility / confidence predictions.
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,  # Must match the feature extractor's channel count
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where
                B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Query points to track. The
                value is forwarded unchanged to the tracker; how ``None`` is
                handled is decided by ``BaseTrackerPredictor``
                (NOTE(review): confirm that the tracker accepts ``None``).
            iters (int, optional): Number of refinement iterations. If None,
                uses ``self.iters``.

        Returns:
            tuple:
                - coord_preds: Predicted coordinates for tracked points.
                - vis_scores: Visibility scores for tracked points.
                - conf_scores: Confidence scores for tracked points
                  (presumably only meaningful when ``predict_conf=True``).

        Raises:
            ValueError: If ``images`` does not have exactly five dimensions.
        """
        # Explicit rank check. The previous implementation relied on tuple
        # unpacking of ``images.shape`` into otherwise unused locals, which
        # raised the same exception type but with an opaque message.
        if len(images.shape) != 5:
            raise ValueError(
                f"images must have shape (B, S, C, H, W); got {tuple(images.shape)}"
            )

        # Extract features from tokens. With down_ratio=2 the maps are expected
        # to be (B, S, C, H//2, W//2) -- TODO confirm against DPTHead.
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        # Fall back to the iteration count configured at construction time.
        if iters is None:
            iters = self.iters

        # Perform tracking using the extracted features.
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores