# AnySplat: src/dataset/view_sampler/three_view_hack.py
# (Header recovered from a Hugging Face file-viewer page: uploaded by alexnasa,
#  "Upload 243 files", commit 2568013, 270 bytes.)
import torch
from jaxtyping import Int
from torch import Tensor
def add_third_context_index(
    indices: Int[Tensor, "*batch 2"]
) -> Int[Tensor, "*batch 3"]:
    """Insert the floored midpoint between each pair of indices.

    Each length-2 vector ``(left, right)`` along the last dimension becomes
    ``(left, (left + right) // 2, right)``. Leading batch dimensions are kept.

    Args:
        indices: Integer tensor whose last dimension has size 2.

    Returns:
        Tensor with the same leading dimensions and a last dimension of size 3.

    Raises:
        ValueError: if the last dimension of ``indices`` is not 2 (the
            2-way unpacking of ``unbind`` fails).
    """
    # unbind + tuple unpacking enforces exactly two entries on the last axis.
    start, end = indices.unbind(dim=-1)
    # Same operation as the ``//`` operator on tensors (floor division).
    mid = torch.floor_divide(start + end, 2)
    return torch.stack([start, mid, end], dim=-1)