# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F


def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Activate pose parameters with specified activation functions.

    The last dimension of the input is split positionally: the first three
    entries are treated as translation, the next four as a quaternion, and
    everything from index 7 onward as focal length (or field of view).

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters
            [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor, with the same last-dimension
        layout as the input.

    Raises:
        ValueError: If any activation type is not recognised by
            ``base_pose_act``.
    """
    T = pred_pose_enc[..., :3]
    quat = pred_pose_enc[..., 3:7]
    fl = pred_pose_enc[..., 7:]  # or fov

    T = base_pose_act(T, trans_act)
    quat = base_pose_act(quat, quat_act)
    fl = base_pose_act(fl, fl_act)  # or fov

    # Re-assemble the components in their original order.
    pred_pose_enc = torch.cat([T, quat, fl], dim=-1)

    return pred_pose_enc


def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply basic activation function to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters. For "linear" the input tensor itself is
        returned unchanged.

    Raises:
        ValueError: If ``act_type`` is not one of the supported names.
    """
    if act_type == "linear":
        return pose_enc
    elif act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    elif act_type == "exp":
        return torch.exp(pose_enc)
    elif act_type == "relu":
        return F.relu(pose_enc)
    else:
        raise ValueError(f"Unknown act_type: {act_type}")


def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    Args:
        out: Network output tensor (B, C, H, W). The last channel is the
            confidence logit; the remaining C-1 channels are the point
            encoding.
        activation: Activation type for 3D points. One of "norm_exp",
            "norm", "exp", "relu", "inv_log", "xy_inv_log", "sigmoid",
            "linear". "xy_inv_log" splits the point encoding into 2 + 1
            channels and therefore needs exactly three of them.
        conf_activation: Activation type for confidence values. One of
            "expp1", "expp0", "sigmoid".

    Returns:
        Tuple of (3D points tensor of shape (B, H, W, C-1),
        confidence tensor of shape (B, H, W)).

    Raises:
        ValueError: If ``activation`` or ``conf_activation`` is unknown.
    """
    # Move channels from dim 1 to the last dimension => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if activation == "norm_exp":
        # Clamp the norm so an all-zero vector does not divide by zero.
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        # Same guard as "norm_exp": without the clamp a zero vector yields
        # 0 / 0 = NaN, which would poison every downstream computation.
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        # xy is scaled by the activated z before being re-joined with it.
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out


def inverse_log_transform(y):
    """
    Apply inverse log transform: sign(y) * (exp(|y|) - 1)

    Args:
        y: Input tensor

    Returns:
        Transformed tensor
    """
    return torch.sign(y) * (torch.expm1(torch.abs(y)))