| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import torch.nn.functional as F | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Activate pose parameters with specified activation functions. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] | 
					
					
						
						| 
							 | 
						        trans_act: Activation type for translation component | 
					
					
						
						| 
							 | 
						        quat_act: Activation type for quaternion component | 
					
					
						
						| 
							 | 
						        fl_act: Activation type for focal length component | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        Activated pose parameters tensor | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    T = pred_pose_enc[..., :3] | 
					
					
						
						| 
							 | 
						    quat = pred_pose_enc[..., 3:7] | 
					
					
						
						| 
							 | 
						    fl = pred_pose_enc[..., 7:]   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    T = base_pose_act(T, trans_act) | 
					
					
						
						| 
							 | 
						    quat = base_pose_act(quat, quat_act) | 
					
					
						
						| 
							 | 
						    fl = base_pose_act(fl, fl_act)   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    pred_pose_enc = torch.cat([T, quat, fl], dim=-1) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return pred_pose_enc | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def base_pose_act(pose_enc, act_type="linear"): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Apply basic activation function to pose parameters. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        pose_enc: Tensor containing encoded pose parameters | 
					
					
						
						| 
							 | 
						        act_type: Activation type ("linear", "inv_log", "exp", "relu") | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        Activated pose parameters | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    if act_type == "linear": | 
					
					
						
						| 
							 | 
						        return pose_enc | 
					
					
						
						| 
							 | 
						    elif act_type == "inv_log": | 
					
					
						
						| 
							 | 
						        return inverse_log_transform(pose_enc) | 
					
					
						
						| 
							 | 
						    elif act_type == "exp": | 
					
					
						
						| 
							 | 
						        return torch.exp(pose_enc) | 
					
					
						
						| 
							 | 
						    elif act_type == "relu": | 
					
					
						
						| 
							 | 
						        return F.relu(pose_enc) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise ValueError(f"Unknown act_type: {act_type}") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def activate_head(out, activation="norm_exp", conf_activation="expp1"): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Process network output to extract 3D points and confidence values. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        out: Network output tensor (B, C, H, W) | 
					
					
						
						| 
							 | 
						        activation: Activation type for 3D points | 
					
					
						
						| 
							 | 
						        conf_activation: Activation type for confidence values | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        Tuple of (3D points tensor, confidence tensor) | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    fmap = out.permute(0, 2, 3, 1)   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    xyz = fmap[:, :, :, :-1] | 
					
					
						
						| 
							 | 
						    conf = fmap[:, :, :, -1] | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    if activation == "norm_exp": | 
					
					
						
						| 
							 | 
						        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) | 
					
					
						
						| 
							 | 
						        xyz_normed = xyz / d | 
					
					
						
						| 
							 | 
						        pts3d = xyz_normed * torch.expm1(d) | 
					
					
						
						| 
							 | 
						    elif activation == "norm": | 
					
					
						
						| 
							 | 
						        pts3d = xyz / xyz.norm(dim=-1, keepdim=True) | 
					
					
						
						| 
							 | 
						    elif activation == "exp": | 
					
					
						
						| 
							 | 
						        pts3d = torch.exp(xyz) | 
					
					
						
						| 
							 | 
						    elif activation == "relu": | 
					
					
						
						| 
							 | 
						        pts3d = F.relu(xyz) | 
					
					
						
						| 
							 | 
						    elif activation == "inv_log": | 
					
					
						
						| 
							 | 
						        pts3d = inverse_log_transform(xyz) | 
					
					
						
						| 
							 | 
						    elif activation == "xy_inv_log": | 
					
					
						
						| 
							 | 
						        xy, z = xyz.split([2, 1], dim=-1) | 
					
					
						
						| 
							 | 
						        z = inverse_log_transform(z) | 
					
					
						
						| 
							 | 
						        pts3d = torch.cat([xy * z, z], dim=-1) | 
					
					
						
						| 
							 | 
						    elif activation == "sigmoid": | 
					
					
						
						| 
							 | 
						        pts3d = torch.sigmoid(xyz) | 
					
					
						
						| 
							 | 
						    elif activation == "linear": | 
					
					
						
						| 
							 | 
						        pts3d = xyz | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise ValueError(f"Unknown activation: {activation}") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if conf_activation == "expp1": | 
					
					
						
						| 
							 | 
						        conf_out = 1 + conf.exp() | 
					
					
						
						| 
							 | 
						    elif conf_activation == "expp0": | 
					
					
						
						| 
							 | 
						        conf_out = conf.exp() | 
					
					
						
						| 
							 | 
						    elif conf_activation == "sigmoid": | 
					
					
						
						| 
							 | 
						        conf_out = torch.sigmoid(conf) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        raise ValueError(f"Unknown conf_activation: {conf_activation}") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return pts3d, conf_out | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def inverse_log_transform(y): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Apply inverse log transform: sign(y) * (exp(|y|) - 1) | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        y: Input tensor | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        Transformed tensor | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    return torch.sign(y) * (torch.expm1(torch.abs(y))) | 
					
					
						
						| 
							 | 
						
 |