from dataclasses import dataclass
from typing import Literal
import torch
import copy
from jaxtyping import Float, Int64
from torch import Tensor
import random
from .view_sampler import ViewSampler
@dataclass
class ViewSamplerRankCfg:
    """Configuration for the rank-based view sampler (ViewSamplerRank)."""

    name: Literal["rank"]
    num_context_views: int # max number of context views
    num_target_views: int  # number of target views drawn between the two outer context views
    min_distance_between_context_views: int  # lower bound on the ranking gap between outer context views
    max_distance_between_context_views: int  # upper bound on the ranking gap between outer context views
    min_distance_to_context_views: int  # NOTE(review): not read by the visible sampler code — confirm usage
    warm_up_steps: int  # schedule length; only referenced by the commented-out legacy sampler below
    initial_min_distance_between_context_views: int  # warm-up start value (legacy schedule)
    initial_max_distance_between_context_views: int  # warm-up start value (legacy schedule)
    max_img_per_gpu: int  # NOTE(review): not read in this file — presumably consumed by the data loader
def rotation_angle(R1, R2):
    """Return the geodesic angle, in degrees, between two 3x3 rotation matrices."""
    relative = R1.T @ R2
    # trace(R) = 1 + 2*cos(theta); clamp into [-1, 1] before acos for
    # numerical stability.
    cos_theta = torch.clamp((torch.trace(relative) - 1) / 2, -1.0, 1.0)
    theta = torch.acos(cos_theta)
    return theta * 180 / torch.pi
def extrinsic_distance(extrinsic1, extrinsic2, lambda_t=1.0):
    """Pose distance between two 4x4 extrinsics.

    Sum of the rotation difference normalized to [0, 1] (180 deg -> 1) and
    the lambda_t-weighted Euclidean distance between the translation columns.
    """
    rot_a, trans_a = extrinsic1[:3, :3], extrinsic1[:3, 3]
    rot_b, trans_b = extrinsic2[:3, :3], extrinsic2[:3, 3]
    rotation_term = rotation_angle(rot_a, rot_b) / 180
    translation_term = torch.norm(trans_a - trans_b)
    return rotation_term + lambda_t * translation_term
def rotation_angle_batch(R1, R2):
    """All-pairs NORMALIZED rotation difference between two stacks of rotations.

    R1, R2: (N, 3, 3). Returns an (N, N) tensor whose (i, j) entry is the
    geodesic angle between R1[i] and R2[j] divided by 180, i.e. in [0, 1]
    (unlike the scalar rotation_angle, which returns degrees).
    """
    # Broadcast R1[i]^T @ R2[j] over all pairs:
    # (N,1,3,3) @ (1,N,3,3) -> (N,N,3,3)
    pairwise = torch.matmul(R1.transpose(-2, -1).unsqueeze(1), R2.unsqueeze(0))
    # Batched trace: sum of each 3x3 product's diagonal -> (N, N).
    traces = pairwise.diagonal(dim1=-2, dim2=-1).sum(-1)
    cos_theta = torch.clamp((traces - 1) / 2, -1.0, 1.0)
    degrees = torch.acos(cos_theta) * 180 / torch.pi
    return degrees / 180.0
def extrinsic_distance_batch(extrinsics, lambda_t=1.0):
    """All-pairs pose distances for a stack of (N, 4, 4) extrinsics.

    Entry (i, j) is the normalized rotation difference between views i and j
    plus lambda_t times the distance between their translation columns.
    Returns an (N, N) tensor.
    """
    rotations = extrinsics[:, :3, :3]   # (N, 3, 3)
    translations = extrinsics[:, :3, 3]  # (N, 3)
    rot_term = rotation_angle_batch(rotations, rotations)  # (N, N)
    # Pairwise translation distances via broadcasting:
    # (N,1,3) - (1,N,3) -> (N,N,3), then norm over the last axis.
    trans_term = torch.norm(translations.unsqueeze(1) - translations.unsqueeze(0), dim=2)
    return rot_term + lambda_t * trans_term
def compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True):
    """Rank all views by pose distance to every other view.

    Args:
        extrinsics: (N, 4, 4) rigid camera transforms. NOTE(review): assumed
            camera-to-world so [:, :3, 3] are camera centers — confirm with caller.
        lambda_t: weight of the translation term relative to rotation.
        normalize: rescale translations by the mean camera-center norm so the
            metric is scale-invariant across scenes; the input is not modified.
        batched: use the vectorized implementation (same result, much faster).

    Returns:
        (ranking, dists): ranking[i] lists view indices sorted by increasing
        distance to view i (so ranking[i][0] == i); dists is the (N, N) matrix.
    """
    if normalize:
        # .clone() instead of copy.deepcopy: the tensor-native copy, and it
        # keeps the caller's tensor untouched by the in-place rescale below.
        extrinsics = extrinsics.clone()
        camera_center = extrinsics[:, :3, 3]
        avg_scale = torch.mean(torch.norm(camera_center, dim=1))
        # Guard against a degenerate rig where every center sits at the
        # origin, which would otherwise divide by zero and yield NaN dists.
        avg_scale = torch.clamp(avg_scale, min=1e-8)
        extrinsics[:, :3, 3] = extrinsics[:, :3, 3] / avg_scale
    if batched:
        dists = extrinsic_distance_batch(extrinsics, lambda_t=lambda_t)
    else:
        N = extrinsics.shape[0]
        dists = torch.zeros((N, N), device=extrinsics.device)
        for i in range(N):
            for j in range(N):
                dists[i, j] = extrinsic_distance(extrinsics[i], extrinsics[j], lambda_t=lambda_t)
    ranking = torch.argsort(dists, dim=1)
    return ranking, dists
# class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
# def schedule(self, initial: int, final: int) -> int:
# fraction = self.global_step / self.cfg.warm_up_steps
# return min(initial + int((final - initial) * fraction), final)
# def sample(
# self,
# scene: str,
# extrinsics: Float[Tensor, "view 4 4"],
# intrinsics: Float[Tensor, "view 3 3"],
# device: torch.device = torch.device("cpu"),
# ) -> tuple[
# Int64[Tensor, " context_view"], # indices for context views
# Int64[Tensor, " target_view"], # indices for target views
# Float[Tensor, " overlap"], # overlap
# ]:
# num_views, _, _ = extrinsics.shape
# # breakpoint()
# # Compute the context view spacing based on the current global step.
# ranking, dists = compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True)
# reference_view = random.sample(range(num_views), 1)[0]
# refview_ranking = ranking[reference_view]
# # if self.cfg.warm_up_steps > 0:
# # max_gap = self.schedule(
# # self.cfg.initial_max_distance_between_context_views,
# # self.cfg.max_distance_between_context_views,
# # )
# # min_gap = self.schedule(
# # self.cfg.initial_min_distance_between_context_views,
# # self.cfg.min_distance_between_context_views,
# # )
# # else:
# max_gap = self.cfg.max_distance_between_context_views
# min_gap = self.cfg.min_distance_between_context_views
# index_context_left = reference_view
# rightmost_index = random.sample(range(min_gap, max_gap + 1), 1)[0] + 1
# index_context_right = refview_ranking[rightmost_index].item()
# middle_indices = refview_ranking[1: rightmost_index].tolist()
# index_target = random.sample(middle_indices, self.num_target_views)
# remaining_indices = [idx for idx in middle_indices if idx not in index_target]
# # Sample extra context views if needed
# extra_views = []
# num_extra_views = self.num_context_views - 2 # subtract left and right context views
# if num_extra_views > 0 and remaining_indices:
# extra_views = random.sample(remaining_indices, min(num_extra_views, len(remaining_indices)))
# else:
# extra_views = []
# overlap = torch.zeros(1)
# return (
# torch.tensor((index_context_left, *extra_views, index_context_right)),
# torch.tensor(index_target),
# overlap
# )
# @property
# def num_context_views(self) -> int:
# return self.cfg.num_context_views
# @property
# def num_target_views(self) -> int:
# return self.cfg.num_target_views
class ViewSamplerRank(ViewSampler[ViewSamplerRankCfg]):
    """Samples context/target views by ranking all views of a scene by pose
    distance (rotation + translation) to a randomly chosen reference view."""

    def sample(
        self,
        scene: str,
        num_context_views: int,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
        Float[Tensor, " overlap"],  # overlap
    ]:
        """Pick two outer context views plus optional extra context views, and
        target views lying between them in the pose-distance ranking.

        Returns (context_indices, target_indices, overlap); `overlap` is a
        placeholder zero tensor (this sampler does not compute overlap).
        """
        num_views, _, _ = extrinsics.shape
        extrinsics = extrinsics.clone()
        # ranking[i] lists all views sorted by pose distance to view i.
        ranking, dists = compute_ranking(extrinsics, lambda_t=1.0, normalize=True, batched=True)
        reference_view = random.sample(range(num_views), 1)[0]
        refview_ranking = ranking[reference_view]
        # Gap (in ranking positions) between the two outer context views is
        # drawn from a per-num_context_views range.
        min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
        max_gap = min(max_gap, num_views - 1)  # cannot index past the last view
        # BUGFIX: with few views the clamp above can push max_gap below
        # min_gap, making random.sample(range(min_gap, max_gap + 1), 1) raise
        # on an empty range; clamp min_gap so the range is never empty.
        min_gap = min(min_gap, max_gap)
        index_context_left = reference_view
        rightmost_index = random.sample(range(min_gap, max_gap + 1), 1)[0]
        index_context_right = refview_ranking[rightmost_index].item()
        # Views strictly between the two outer context views in the ranking
        # (position 0 is the reference view itself).
        middle_indices = refview_ranking[1:rightmost_index].tolist()
        # NOTE(review): raises ValueError if fewer than num_target_views
        # middle views exist — assumes the gap mapping guarantees enough.
        index_target = random.sample(middle_indices, self.num_target_views)
        remaining_indices = [idx for idx in middle_indices if idx not in index_target]
        # Sample extra context views from the middle views not used as targets.
        extra_views = []
        num_extra_views = num_context_views - 2  # outer left/right already chosen
        if num_extra_views > 0 and remaining_indices:
            extra_views = random.sample(remaining_indices, min(num_extra_views, len(remaining_indices)))
        overlap = torch.zeros(1)
        return (
            torch.tensor((index_context_left, *extra_views, index_context_right)),
            torch.tensor(index_target),
            overlap,
        )

    @property
    def num_context_views(self) -> int:
        # Maximum number of context views this sampler may return.
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        # Number of target views drawn per sample.
        return self.cfg.num_target_views

    @property
    def num_ctxt_gap_mapping_target(self) -> dict:
        """Alternative num_context_views -> [min_gap, max_gap] mapping that
        also reserves room for the target views (not used by sample())."""
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                max(num_ctxt * 2, self.cfg.num_target_views + num_ctxt),
                max(
                    self.cfg.num_target_views + num_ctxt,
                    min(num_ctxt ** 2, self.cfg.max_distance_between_context_views),
                ),
            ]
        return mapping

    @property
    def num_ctxt_gap_mapping(self) -> dict:
        """Maps num_context_views -> [min_gap, max_gap] used by sample();
        both bounds grow with the number of context views, capped by cfg."""
        mapping = dict()
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                min(num_ctxt * 3, self.cfg.min_distance_between_context_views),
                min(max(num_ctxt * 5, num_ctxt ** 2), self.cfg.max_distance_between_context_views),
            ]
        return mapping