from dataclasses import dataclass
from typing import Literal
import torch
from jaxtyping import Float, Int64
from torch import Tensor
from .view_sampler import ViewSampler


@dataclass
class ViewSamplerBoundedCfg:
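    """Configuration for bounded view sampling; all distances and gaps are
    measured in frame indices.
    """
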
name: Literal["bounded"]
num_context_views: int
num_target_views: int
min_distance_between_context_views: int
max_distance_between_context_views: int
min_distance_to_context_views: int
warm_up_steps: int
initial_min_distance_between_context_views: int
initial_max_distance_between_context_views: int
max_img_per_gpu: int
min_gap_multiplier: int
max_gap_multiplier: int


class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]):
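    """Samples a left and a right ("bounding") context view separated by a
    bounded index gap, optional extra context views strictly between them,
    and target views drawn from the enclosed interval.
    """
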
def schedule(self, initial: int, final: int) -> int:
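        """Linearly interpolate from `initial` to `final` over
        cfg.warm_up_steps, clamping at `final` once warm-up is complete.
        """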
fraction = self.global_step / self.cfg.warm_up_steps
return min(initial + int((final - initial) * fraction), final)

    def sample(
self,
scene: str,
num_context_views: int,
extrinsics: Float[Tensor, "view 4 4"],
intrinsics: Float[Tensor, "view 3 3"],
device: torch.device = torch.device("cpu"),
) -> tuple[
Int64[Tensor, " context_view"], # indices for context views
Int64[Tensor, " target_view"], # indices for target views
Float[Tensor, " overlap"], # overlap
]:
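        """Sample context and target view indices for a single scene.

        Returns the context indices (left, extras, right), the target indices,
        and a dummy overlap value.
        """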
num_views, _, _ = extrinsics.shape
        # Compute the context view spacing.
        if self.stage == "test":
            # When testing, always use the full gap.
            max_gap = self.cfg.max_distance_between_context_views
            min_gap = self.cfg.max_distance_between_context_views
        else:
            # Otherwise, look up the gap range that corresponds to the
            # requested number of context views.
            min_gap, max_gap = self.num_ctxt_gap_mapping[num_context_views]
        # Clamp the gap to the number of available frames (applied even when
        # the cameras are circular; circular indices wrap via the modulo
        # below) and make room for the targets' minimum distance to both
        # context views.
        max_gap = min(max_gap, num_views - 1)
        min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap)
        if max_gap < min_gap:
            raise ValueError("Example does not have enough frames!")

        # Pick the gap between the context views.
context_gap = torch.randint(
min_gap,
max_gap + 1,
size=tuple(),
device=device,
).item()
# Pick the left and right context indices.
index_context_left = torch.randint(
num_views if self.cameras_are_circular else num_views - context_gap,
size=tuple(),
device=device,
).item()
if self.stage == "test":
index_context_left = index_context_left * 0
index_context_right = index_context_left + context_gap
        if self.is_overfitting:
            # When overfitting, always use the same, widest context pair.
            index_context_left = 0
            index_context_right = max_gap
# Pick the target view indices.
if self.stage == "test":
# When testing, pick all.
index_target = torch.arange(
index_context_left,
index_context_right + 1,
device=device,
)
else:
# When training or validating (visualizing), pick at random.
index_target = torch.randint(
index_context_left + self.cfg.min_distance_to_context_views,
index_context_right + 1 - self.cfg.min_distance_to_context_views,
size=(self.cfg.num_target_views,),
device=device,
)
# Apply modulo for circular datasets.
if self.cameras_are_circular:
index_target %= num_views
index_context_right %= num_views
        # If more than two context views are desired, rejection-sample distinct
        # extra context views strictly between the left and right ones.
        if num_context_views > 2:
            num_extra_views = num_context_views - 2
            extra_views = []
            while len(set(extra_views)) != num_extra_views:
                extra_views = torch.randint(
                    index_context_left + 1,
                    index_context_right,
                    size=(num_extra_views,),
                ).tolist()
        else:
            extra_views = []
overlap = torch.tensor([0.5], dtype=torch.float32, device=device) # dummy
        return (
            torch.tensor(
                (index_context_left, *extra_views, index_context_right),
                device=device,
            ),
            index_target,
            overlap,
        )

    @property
def num_context_views(self) -> int:
return self.cfg.num_context_views

    @property
def num_target_views(self) -> int:
return self.cfg.num_target_views

    @property
def num_ctxt_gap_mapping(self) -> dict:
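        """Map each context-view count (2..num_context_views) to a
        [min_gap, max_gap] pair: the gaps scale with the number of context
        views via the gap multipliers and are capped by the configured
        distance bounds.
        """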
        mapping = {}
        for num_ctxt in range(2, self.cfg.num_context_views + 1):
            mapping[num_ctxt] = [
                min(
                    num_ctxt * self.cfg.min_gap_multiplier,
                    self.cfg.min_distance_between_context_views,
                ),
                min(
                    max(num_ctxt * self.cfg.max_gap_multiplier, num_ctxt**2),
                    self.cfg.max_distance_between_context_views,
                ),
            ]
        return mapping
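

# A minimal usage sketch (hypothetical values; the ViewSampler base class in
# .view_sampler is assumed to provide `stage`, `global_step`, `is_overfitting`,
# and `cameras_are_circular`, and its constructor may take further arguments):
#
#     cfg = ViewSamplerBoundedCfg(
#         name="bounded",
#         num_context_views=2,
#         num_target_views=4,
#         min_distance_between_context_views=25,
#         max_distance_between_context_views=45,
#         min_distance_to_context_views=1,
#         warm_up_steps=0,
#         initial_min_distance_between_context_views=25,
#         initial_max_distance_between_context_views=45,
#         max_img_per_gpu=12,
#         min_gap_multiplier=3,
#         max_gap_multiplier=6,
#     )
#     sampler = ViewSamplerBounded(cfg, ...)  # remaining args depend on the base class
#     context_idx, target_idx, overlap = sampler.sample(
#         "example_scene", 2, extrinsics, intrinsics
#     )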