# AnySplat: src/dataset/shims/bounds_shim.py
# Uploaded by alexnasa ("Upload 243 files", commit 2568013, verified).
import torch
from einops import einsum, reduce, repeat
from jaxtyping import Float
from torch import Tensor
from ..types import BatchedExample
def compute_depth_for_disparity(
    extrinsics: Float[Tensor, "batch view 4 4"],
    intrinsics: Float[Tensor, "batch view 3 3"],
    image_shape: tuple[int, int],
    disparity: float,
    delta_min: float = 1e-6,  # This prevents motionless scenes from lacking depth.
) -> Float[Tensor, " batch"]:
    """Compute the depth at which moving the maximum distance between cameras
    corresponds to the specified disparity (in pixels).

    Args:
        extrinsics: Per-view 4x4 camera matrices. The translation column
            (``[..., :3, 3]``) is used as the camera position, so these are
            presumably camera-to-world transforms (TODO confirm with callers).
        intrinsics: Per-view 3x3 intrinsics. Only the upper-left 2x2 block is
            used. It is combined with a pixel size of ``(1/w, 1/h)``, so the
            intrinsics are presumably normalized by image size (TODO confirm).
        image_shape: ``(height, width)`` of the images, in pixels.
        disparity: Target disparity in pixels.
        delta_min: Lower bound applied to every pairwise camera distance.
            When all cameras coincide, this gives a small nonzero depth
            instead of zero.

    Returns:
        One depth per batch element:
        ``max_baseline / (disparity * mean_pixel_size_at_depth_1)``.
    """
    # Use the furthest distance between any two cameras as the baseline.
    origins = extrinsics[:, :, :3, 3]
    deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1)
    deltas = deltas.clip(min=delta_min)
    baselines = reduce(deltas, "b v ov -> b", "max")

    # Compute a single pixel's size at depth 1. Build the (1/w, 1/h) vector in
    # the intrinsics' own dtype and device. It is contracted with them below,
    # and einsum needs matching dtypes and devices. A hard-coded float32 vector
    # would fail for, e.g., float64 intrinsics.
    h, w = image_shape
    pixel_size = 1 / torch.tensor(
        (w, h), dtype=intrinsics.dtype, device=intrinsics.device
    )
    pixel_size = einsum(
        intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i"
    )

    # This wouldn't make sense with non-square pixels, but then again, non-square pixels
    # don't make much sense anyway.
    mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean")
    return baselines / (disparity * mean_pixel_size)
def apply_bounds_shim(
    batch: BatchedExample,
    near_disparity: float,
    far_disparity: float,
) -> BatchedExample:
    """Attach near/far depth planes to every context and target view.

    Both planes are derived only from the context cameras: the near plane is
    the depth that yields ``near_disparity`` pixels of disparity across the
    widest context baseline, and likewise for the far plane. The same per-example
    values are then broadcast to all views of both the context and the target,
    which assumes every view of an example looks at roughly the same content.

    Returns a new batch dict; the input ``batch`` and its sub-dicts are not
    modified.
    """
    context = batch["context"]
    _, num_context_views, _, height, width = context["image"].shape

    def depth_at(disparity: float):
        # Same cameras and image size for both planes; only the disparity varies.
        return compute_depth_for_disparity(
            context["extrinsics"],
            context["intrinsics"],
            (height, width),
            disparity,
        )

    near = depth_at(near_disparity)
    far = depth_at(far_disparity)

    target = batch["target"]
    _, num_target_views, _, _, _ = target["image"].shape

    def with_bounds(views, view_count):
        # Copy the view dict and add per-view near/far tensors of shape (b, v).
        return {
            **views,
            "near": repeat(near, "b -> b v", v=view_count),
            "far": repeat(far, "b -> b v", v=view_count),
        }

    bounded_context = with_bounds(context, num_context_views)
    bounded_target = with_bounds(target, num_target_views)
    return {**batch, "context": bounded_context, "target": bounded_target}