# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import logging import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Union, List, Dict, Any from src.model.encoder.vggt.layers import PatchEmbed from src.model.encoder.vggt.layers.block import Block from src.model.encoder.vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter from src.model.encoder.vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 logger = logging.getLogger(__name__) _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] class Aggregator(nn.Module): """ The Aggregator applies alternating-attention over input frames, as described in VGGT: Visual Geometry Grounded Transformer. Args: img_size (int): Image size in pixels. patch_size (int): Size of each patch for PatchEmbed. embed_dim (int): Dimension of the token embeddings. depth (int): Number of blocks. num_heads (int): Number of attention heads. mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. num_register_tokens (int): Number of register tokens. block_fn (nn.Module): The block type used for attention (Block by default). qkv_bias (bool): Whether to include bias in QKV projections. proj_bias (bool): Whether to include bias in the output projection. ffn_bias (bool): Whether to include bias in MLP layers. patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. qk_norm (bool): Whether to apply QK normalization. rope_freq (int): Base frequency for rotary embedding. -1 to disable. init_values (float): Init scale for layer scale. """ def __init__( self, img_size=518, patch_size=14, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, num_register_tokens=4, block_fn=Block, qkv_bias=True, proj_bias=True, ffn_bias=True, patch_embed="dinov2_vitl14_reg", aa_order=["frame", "global"], aa_block_size=1, qk_norm=True, rope_freq=100, init_values=0.01, ): super().__init__() self.use_checkpoint = True self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim) # Initialize rotary position embedding if frequency > 0 self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None self.position_getter = PositionGetter() if self.rope is not None else None self.frame_blocks = nn.ModuleList( [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, qk_norm=qk_norm, rope=self.rope, ) for _ in range(depth) ] ) self.global_blocks = nn.ModuleList( [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, init_values=init_values, qk_norm=qk_norm, rope=self.rope, ) for _ in range(depth) ] ) self.depth = depth self.aa_order = aa_order self.patch_size = patch_size self.aa_block_size = aa_block_size # Validate that depth is divisible by aa_block_size if self.depth % self.aa_block_size != 0: raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") self.aa_block_num = self.depth // self.aa_block_size # Note: We have two camera tokens, one for the first frame and one for the rest # The same applies for register tokens self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim)) # The patch tokens start after the camera and register tokens self.patch_start_idx = 1 + num_register_tokens # Initialize parameters with small values nn.init.normal_(self.camera_token, std=1e-6) nn.init.normal_(self.register_token, std=1e-6) # Register normalization constants as buffers for name, value in ( ("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD), ): self.register_buffer( name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False, ) def __build_patch_embed__( self, patch_embed, img_size, patch_size, num_register_tokens, interpolate_antialias=True, interpolate_offset=0.0, block_chunks=0, init_values=1.0, embed_dim=1024, ): """ Build the patch embed layer. If 'conv', we use a simple PatchEmbed conv layer. Otherwise, we use a vision transformer. """ if "conv" in patch_embed: self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim) else: vit_models = { "dinov2_vitl14_reg": vit_large, "dinov2_vitb14_reg": vit_base, "dinov2_vits14_reg": vit_small, "dinov2_vitg2_reg": vit_giant2, } self.patch_embed = vit_models[patch_embed]( img_size=img_size, patch_size=patch_size, num_register_tokens=num_register_tokens, interpolate_antialias=interpolate_antialias, interpolate_offset=interpolate_offset, block_chunks=block_chunks, init_values=init_values, ) # Disable gradient updates for mask token if hasattr(self.patch_embed, "mask_token"): self.patch_embed.mask_token.requires_grad_(False) def forward( self, images: torch.Tensor, intermediate_layer_idx: Optional[List[int]] = None ) -> Tuple[List[torch.Tensor], int]: """ Args: images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. B: batch size, S: sequence length, 3: RGB channels, H: height, W: width Returns: (list[torch.Tensor], int): The list of outputs from the attention blocks, and the patch_start_idx indicating where patch tokens begin. """ B, S, C_in, H, W = images.shape if C_in != 3: raise ValueError(f"Expected 3 input channels, got {C_in}") # Normalize images and reshape for patch embed images = (images - self._resnet_mean) / self._resnet_std # Reshape to [B*S, C, H, W] for patch embedding images = images.view(B * S, C_in, H, W) patch_tokens = self.patch_embed(images) if isinstance(patch_tokens, dict): patch_tokens = patch_tokens["x_norm_patchtokens"] _, P, C = patch_tokens.shape # Expand camera and register tokens to match batch size and sequence length camera_token = slice_expand_and_flatten(self.camera_token, B, S) register_token = slice_expand_and_flatten(self.register_token, B, S) # Concatenate special tokens with patch tokens tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) pos = None if self.rope is not None: pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device) if self.patch_start_idx > 0: # do not use position embedding for special tokens (camera and register tokens) # so set pos to 0 for the special tokens pos = pos + 1 pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype) pos = torch.cat([pos_special, pos], dim=1) # update P because we added special tokens _, P, C = tokens.shape frame_idx = 0 global_idx = 0 output_list = [] layer_idx = 0 # Convert intermediate_layer_idx to a set for O(1) lookup if intermediate_layer_idx is not None: required_layers = set(intermediate_layer_idx) # Always include the last layer for camera_head required_layers.add(self.depth - 1) for _ in range(self.aa_block_num): for attn_type in self.aa_order: if attn_type == "frame": tokens, frame_idx, frame_intermediates = self._process_frame_attention( tokens, B, S, P, C, frame_idx, pos=pos ) elif attn_type == "global": tokens, global_idx, global_intermediates = self._process_global_attention( tokens, B, S, P, C, global_idx, pos=pos ) else: raise ValueError(f"Unknown attention type: {attn_type}") if intermediate_layer_idx is not None: for i in range(len(frame_intermediates)): current_layer = layer_idx + i if current_layer in required_layers: # concat frame and global intermediates, [B x S x P x 2C] concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) output_list.append(concat_inter) layer_idx += self.aa_block_size else: for i in range(len(frame_intermediates)): # concat frame and global intermediates, [B x S x P x 2C] concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) output_list.append(concat_inter) del concat_inter del frame_intermediates del global_intermediates return output_list, self.patch_start_idx def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None): """ Process frame attention blocks. We keep tokens in shape (B*S, P, C). """ # If needed, reshape tokens or positions: if tokens.shape != (B * S, P, C): tokens = tokens.view(B, S, P, C).view(B * S, P, C) if pos is not None and pos.shape != (B * S, P, 2): pos = pos.view(B, S, P, 2).view(B * S, P, 2) intermediates = [] # by default, self.aa_block_size=1, which processes one block at a time for _ in range(self.aa_block_size): if self.use_checkpoint: tokens = torch.utils.checkpoint.checkpoint( self.frame_blocks[frame_idx], tokens, pos, use_reentrant=False, ) else: tokens = self.frame_blocks[frame_idx](tokens, pos=pos) frame_idx += 1 intermediates.append(tokens.view(B, S, P, C)) return tokens, frame_idx, intermediates def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None): """ Process global attention blocks. We keep tokens in shape (B, S*P, C). """ if tokens.shape != (B, S * P, C): tokens = tokens.view(B, S, P, C).view(B, S * P, C) if pos is not None and pos.shape != (B, S * P, 2): pos = pos.view(B, S, P, 2).view(B, S * P, 2) intermediates = [] # by default, self.aa_block_size=1, which processes one block at a time for _ in range(self.aa_block_size): if self.use_checkpoint: tokens = torch.utils.checkpoint.checkpoint( self.global_blocks[global_idx], tokens, pos, use_reentrant=False, ) else: tokens = self.global_blocks[global_idx](tokens, pos=pos) global_idx += 1 intermediates.append(tokens.view(B, S, P, C)) return tokens, global_idx, intermediates def slice_expand_and_flatten(token_tensor, B, S): """ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: 1) Uses the first position (index=0) for the first frame only 2) Uses the second position (index=1) for all remaining frames (S-1 frames) 3) Expands both to match batch size B 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token followed by (S-1) second-position tokens 5) Flattens to (B*S, X, C) for processing Returns: torch.Tensor: Processed tokens with shape (B*S, X, C) """ # Slice out the "query" tokens => shape (1, 1, ...) query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) # Slice out the "other" tokens => shape (1, S-1, ...) others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) # Concatenate => shape (B, S, ...) combined = torch.cat([query, others], dim=1) # Finally flatten => shape (B*S, ...) combined = combined.view(B * S, *combined.shape[2:]) return combined