import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from src.model.encoder.vggt.layers import Mlp
from src.model.encoder.vggt.layers.block import Block
from src.model.encoder.vggt.heads.head_act import activate_pose


class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            # Absolute translation (3) + quaternion rotation (4) + field of view (2).
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # The trunk: a stack of transformer blocks applied to the camera tokens.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalization for the input camera tokens and the trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable "empty" pose encoding used to seed the first refinement iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Produces per-token shift, scale, and gate for AdaLN-style modulation.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Layer norm without affine parameters; the modulation network supplies them.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)

        # Maps refined tokens to a residual update of the pose encoding.
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use the tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the dedicated camera token (index 0) for each frame: [B, S, C].
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Embed the current pose estimate; on the first iteration, fall back
            # to the learnable empty pose token.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction so gradients do not flow
                # across refinement iterations.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # AdaLN-style modulation conditioned on the current pose estimate.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            # Run the trunk with gradient checkpointing to reduce memory usage.
            pose_tokens_modulated = torch.utils.checkpoint.checkpoint(
                self.trunk,
                pose_tokens_modulated,
                use_reentrant=False,
            )

            # Predict a residual update to the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply the configured activations to translation, quaternion, and FoV.
            activated_pose = activate_pose(
                pred_pose_enc,
                trans_act=self.trans_act,
                quat_act=self.quat_act,
                fl_act=self.fl_act,
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    return x * (1 + scale) + shift
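

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. It assumes
    # aggregated tokens of shape [B, S, P, C] with C == dim_in and the camera
    # token at patch index 0; the batch size, frame count, and patch count
    # below are made up for illustration.
    head = CameraHead(dim_in=2048)
    dummy_tokens = torch.randn(2, 4, 5, 2048)  # [B=2, S=4, P=5, C=2048]
    with torch.no_grad():
        preds = head([dummy_tokens], num_iterations=4)
    # One activated 9-D pose encoding (absT + quaR + FoV) per iteration.
    assert len(preds) == 4
    assert preds[-1].shape == (2, 4, 9)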