import torch.nn as nn

from .dpt_head import DPTHead
from .track_modules.base_track_predictor import BaseTrackerPredictor


class TrackHead(nn.Module):
    """
    Track head that uses a DPT head to process tokens and a BaseTrackerPredictor for tracking.

    Tracking is performed iteratively, refining the predictions over multiple iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone.
            patch_size (int): Size of image patches used in the vision transformer.
            features (int): Number of feature channels in the feature extractor output.
            iters (int): Number of refinement iterations for tracking predictions.
            predict_conf (bool): Whether to predict confidence scores for tracked points.
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels.
            corr_radius (int): Radius for correlation computation, controlling the search area.
            hidden_size (int): Size of hidden layers in the tracker network.
        """
        super().__init__()

        self.patch_size = patch_size
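
        # Feature extractor based on the DPT architecture: converts the aggregated
        # backbone tokens into dense feature maps for tracking (feature_only=True;
        # with down_ratio=2 the maps are at half the input resolution).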
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,
            down_ratio=2,
            pos_embed=False,
        )
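
        # Tracker that iteratively refines point trajectories from the feature maps,
        # predicting coordinates, visibility, and (optionally) confidence scores.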
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W), where
                B = batch size and S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Initial query points to track.
                If None, points are initialized by the tracker.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.

        Returns:
            tuple:
                - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
                - vis_scores (torch.Tensor): Visibility scores for tracked points.
                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
        """
        B, S, _, H, W = images.shape
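
        # Convert the aggregated tokens into dense feature maps for the tracker.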
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
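
        # Fall back to the default number of refinement iterations.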
        if iters is None:
            iters = self.iters
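
        # Run the iterative tracker on the extracted feature maps.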
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores