# AnySplat — src/dataset/view_sampler/view_sampler_rank.py
from dataclasses import dataclass
from typing import Literal
import torch
import copy
from jaxtyping import Float, Int64
from torch import Tensor
import random
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerRankCfg:
    """Config for the rank-based view sampler (ViewSamplerRank below).

    Views are ordered by pose distance to a random reference view; the
    gap fields bound how far apart (in that ranking) the sampled context
    views may be.
    """

    name: Literal["rank"]
    num_context_views: int # max number of context views
    num_target_views: int  # number of target views sampled between the context views
    min_distance_between_context_views: int  # lower bound on the ranking gap
    max_distance_between_context_views: int  # upper bound on the ranking gap
    min_distance_to_context_views: int  # not read in this file — presumably used by other samplers; confirm
    warm_up_steps: int  # steps of the gap warm-up schedule (schedule code is commented out below)
    initial_min_distance_between_context_views: int  # warm-up start value for the min gap
    initial_max_distance_between_context_views: int  # warm-up start value for the max gap
    max_img_per_gpu: int  # not read in this file — presumably a batching limit elsewhere; confirm
def rotation_angle(R1, R2):
    """Return the geodesic angle, in degrees, between two 3x3 rotation matrices.

    Uses the identity trace(R1^T R2) = 1 + 2*cos(theta). The cosine is
    clamped to [-1, 1] so floating-point drift cannot push acos out of domain.
    """
    relative = R1.T @ R2
    cos_theta = torch.clamp((torch.trace(relative) - 1.0) / 2.0, -1.0, 1.0)
    # Convert radians to degrees.
    return torch.acos(cos_theta) * (180.0 / torch.pi)
def extrinsic_distance(extrinsic1, extrinsic2, lambda_t=1.0):
    """Pose distance between two 4x4 extrinsics.

    Combines the rotation gap (geodesic angle normalized to [0, 1]) with the
    Euclidean gap between the translation columns, weighted by lambda_t.
    """
    rot_a, trans_a = extrinsic1[:3, :3], extrinsic1[:3, 3]
    rot_b, trans_b = extrinsic2[:3, :3], extrinsic2[:3, 3]
    normalized_rot_gap = rotation_angle(rot_a, rot_b) / 180
    trans_gap = torch.norm(trans_a - trans_b)
    return normalized_rot_gap + lambda_t * trans_gap
def rotation_angle_batch(R1, R2):
    """All-pairs normalized rotation difference between two stacks of rotations.

    Args:
        R1, R2: tensors of shape (N, 3, 3).

    Returns:
        (N, N) tensor whose (i, j) entry is the geodesic angle between
        R1[i] and R2[j] divided by 180, i.e. normalized into [0, 1].
        (Note: unlike `rotation_angle`, this returns the normalized value.)
    """
    # Pairwise relative rotations via broadcasting:
    # (N, 1, 3, 3) @ (1, N, 3, 3) -> (N, N, 3, 3).
    relative = torch.matmul(
        R1.transpose(-2, -1).unsqueeze(1),
        R2.unsqueeze(0),
    )
    # Per-pair trace without materializing anything beyond the diagonal view.
    traces = relative.diagonal(dim1=-2, dim2=-1).sum(-1)  # (N, N)
    # Clamp for numerical stability before acos.
    cos_theta = torch.clamp((traces - 1.0) / 2.0, -1.0, 1.0)
    angles_deg = torch.acos(cos_theta) * 180.0 / torch.pi
    return angles_deg / 180.0
def extrinsic_distance_batch(extrinsics, lambda_t=1.0):
    """All-pairs pose distance for a stack of camera extrinsics.

    Args:
        extrinsics: (N, 4, 4) matrices.
        lambda_t: weight of the translation term relative to rotation.

    Returns:
        (N, N) tensor: normalized rotation difference plus lambda_t times
        the Euclidean gap between translation columns, for every view pair.
    """
    rotations = extrinsics[:, :3, :3]    # (N, 3, 3)
    translations = extrinsics[:, :3, 3]  # (N, 3)
    # Rotation term is already normalized to [0, 1] by rotation_angle_batch.
    rot_term = rotation_angle_batch(rotations, rotations)  # (N, N)
    # Broadcast (N, 1, 3) - (1, N, 3) to get every pairwise delta at once.
    deltas = translations.unsqueeze(1) - translations.unsqueeze(0)
    trans_term = torch.norm(deltas, dim=-1)  # (N, N)
    return rot_term + lambda_t * trans_term
def compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True):
    """Rank every view by extrinsic-space proximity to every other view.

    Args:
        extrinsics: (N, 4, 4) camera extrinsic matrices.
        lambda_t: weight of the translation term relative to rotation.
        normalize: if True, divide translations by the mean translation norm
            so the translation term is scale-invariant across scenes.
        batched: use the vectorized all-pairs implementation instead of the
            O(N^2) Python double loop; both produce the same distances.

    Returns:
        ranking: (N, N) int64 tensor; row i lists view indices sorted by
            increasing distance from view i (row i starts with i itself,
            since the self-distance is zero).
        dists: (N, N) pairwise distance matrix.
    """
    if normalize:
        # Work on a copy so the caller's tensor is never mutated.
        # tensor.clone() is the idiomatic (and cheaper) replacement for
        # copy.deepcopy on a torch.Tensor.
        extrinsics = extrinsics.clone()
        translations = extrinsics[:, :3, 3]
        avg_scale = torch.mean(torch.norm(translations, dim=1))
        extrinsics[:, :3, 3] = translations / avg_scale
    if batched:
        dists = extrinsic_distance_batch(extrinsics, lambda_t=lambda_t)
    else:
        num_views = extrinsics.shape[0]
        dists = torch.zeros((num_views, num_views), device=extrinsics.device)
        for i in range(num_views):
            for j in range(num_views):
                dists[i, j] = extrinsic_distance(
                    extrinsics[i], extrinsics[j], lambda_t=lambda_t
                )
    ranking = torch.argsort(dists, dim=1)
    return ranking, dists
# class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
# def schedule(self, initial: int, final: int) -> int:
# fraction = self.global_step / self.cfg.warm_up_steps
# return min(initial + int((final - initial) * fraction), final)
# def sample(
# self,
# scene: str,
# extrinsics: Float[Tensor, "view 4 4"],
# intrinsics: Float[Tensor, "view 3 3"],
# device: torch.device = torch.device("cpu"),
# ) -> tuple[
# Int64[Tensor, " context_view"], # indices for context views
# Int64[Tensor, " target_view"], # indices for target views
# Float[Tensor, " overlap"], # overlap
# ]:
# num_views, _, _ = extrinsics.shape
# # breakpoint()
# # Compute the context view spacing based on the current global step.
# ranking, dists = compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True)
# reference_view = random.sample(range(num_views), 1)[0]
# refview_ranking = ranking[reference_view]
# # if self.cfg.warm_up_steps > 0:
# # max_gap = self.schedule(
# # self.cfg.initial_max_distance_between_context_views,
# # self.cfg.max_distance_between_context_views,
# # )
# # min_gap = self.schedule(
# # self.cfg.initial_min_distance_between_context_views,
# # self.cfg.min_distance_between_context_views,
# # )
# # else:
# max_gap = self.cfg.max_distance_between_context_views
# min_gap = self.cfg.min_distance_between_context_views
# index_context_left = reference_view
# rightmost_index = random.sample(range(min_gap, max_gap + 1), 1)[0] + 1
# index_context_right = refview_ranking[rightmost_index].item()
# middle_indices = refview_ranking[1: rightmost_index].tolist()
# index_target = random.sample(middle_indices, self.num_target_views)
# remaining_indices = [idx for idx in middle_indices if idx not in index_target]
# # Sample extra context views if needed
# extra_views = []
# num_extra_views = self.num_context_views - 2 # subtract left and right context views
# if num_extra_views > 0 and remaining_indices:
# extra_views = random.sample(remaining_indices, min(num_extra_views, len(remaining_indices)))
# else:
# extra_views = []
# overlap = torch.zeros(1)
# return (
# torch.tensor((index_context_left, *extra_views, index_context_right)),
# torch.tensor(index_target),
# overlap
# )
# @property
# def num_context_views(self) -> int:
# return self.cfg.num_context_views
# @property
# def num_target_views(self) -> int:
# return self.cfg.num_target_views
class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
    """Samples context/target views by ranking all views with a pose-distance
    metric (normalized rotation + translation) rather than raw frame index."""

    def sample(
        self,
        scene: str,
        num_context_views: int,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Pick `num_context_views` context indices and `self.num_target_views`
        target indices for one scene.

        Strategy: choose a random reference view, rank all views by pose
        distance to it, draw a "rightmost" context view at a ranking gap in
        [min_gap, max_gap], sample targets from the views ranked strictly
        between reference and rightmost, then fill any remaining context
        slots from the leftover in-between views.
        """
        num_views, _, _ = extrinsics.shape
        # compute_ranking(normalize=True) copies internally, so the caller's
        # extrinsics are never mutated.
        ranking, _ = compute_ranking(
            extrinsics, lambda_t=1.0, normalize=True, batched=True
        )
        reference_view = random.randrange(num_views)
        refview_ranking = ranking[reference_view]

        min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
        # The gap indexes into the ranking, so it must stay below num_views.
        # Clamp min_gap as well: otherwise a scene with fewer views than
        # min_gap made range(min_gap, max_gap + 1) empty and crashed with a
        # ValueError from random sampling.
        max_gap = min(max_gap, num_views - 1)
        min_gap = min(min_gap, max_gap)

        index_context_left = reference_view
        rightmost_index = random.randint(min_gap, max_gap)
        index_context_right = refview_ranking[rightmost_index].item()

        # Views ranked strictly between the reference (rank 0) and the
        # rightmost context view. Assumes at least num_target_views of them
        # exist — the gap mapping is expected to guarantee this; confirm
        # against the configured min distances.
        middle_indices = refview_ranking[1:rightmost_index].tolist()
        index_target = random.sample(middle_indices, self.num_target_views)
        remaining_indices = [idx for idx in middle_indices if idx not in index_target]

        # Fill context slots beyond the left/right pair from what is left
        # between them (may yield fewer than requested if the scene is small).
        extra_views = []
        num_extra_views = num_context_views - 2
        if num_extra_views > 0 and remaining_indices:
            extra_views = random.sample(
                remaining_indices, min(num_extra_views, len(remaining_indices))
            )

        overlap = torch.zeros(1)  # overlap is not computed by this sampler
        return (
            torch.tensor((index_context_left, *extra_views, index_context_right)),
            torch.tensor(index_target),
            overlap,
        )

    @property
    def num_context_views(self) -> int:
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views

    @property
    def num_ctxt_gap_mapping_target(self) -> dict:
        """Alternative gap mapping that also accounts for the number of
        target views; not used by sample() in this file."""
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                max(num_ctxt * 2, self.cfg.num_target_views + num_ctxt),
                max(
                    self.cfg.num_target_views + num_ctxt,
                    min(num_ctxt ** 2, self.cfg.max_distance_between_context_views),
                ),
            ]
        return mapping

    @property
    def num_ctxt_gap_mapping(self) -> dict:
        """Map a context-view count to its [min_gap, max_gap] ranking gap,
        bounded by the configured min/max context distances."""
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                min(num_ctxt * 3, self.cfg.min_distance_between_context_views),
                min(
                    max(num_ctxt * 5, num_ctxt ** 2),
                    self.cfg.max_distance_between_context_views,
                ),
            ]
        return mapping