# AnySplat / src / dataset / shims / bounds_shim.py
# (upload metadata: alexnasa — "Upload 243 files", 2568013 verified, raw/history/blame, 2.52 kB)
import torch
from einops import einsum, reduce, repeat
from jaxtyping import Float
from torch import Tensor
from ..types import BatchedExample
def compute_depth_for_disparity(
    extrinsics: Float[Tensor, "batch view 4 4"],
    intrinsics: Float[Tensor, "batch view 3 3"],
    image_shape: tuple[int, int],
    disparity: float,
    delta_min: float = 1e-6,  # This prevents motionless scenes from lacking depth.
) -> Float[Tensor, " batch"]:
    """Find the depth at which the largest camera-to-camera baseline projects to
    the requested disparity (measured in pixels).
    """
    # The baseline is the greatest pairwise distance between camera centers.
    centers = extrinsics[:, :, :3, 3]
    pairwise = (centers[:, None, :, :] - centers[:, :, None, :]).norm(dim=-1)
    pairwise = pairwise.clip(min=delta_min)
    baselines = pairwise.flatten(1).max(dim=-1).values

    # Size of a single pixel at unit depth: invert the 2x2 focal block of the
    # (presumably normalized) intrinsics and apply it to one pixel's extent.
    h, w = image_shape
    extent = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device)
    per_pixel = intrinsics[..., :2, :2].inverse() @ extent

    # Averaging x/y only makes sense for square pixels — but non-square pixels
    # don't make much sense anyway.
    mean_pixel_size = per_pixel.mean(dim=(1, 2))
    return baselines / (disparity * mean_pixel_size)
def apply_bounds_shim(
    batch: BatchedExample,
    near_disparity: float,
    far_disparity: float,
) -> BatchedExample:
    """Attach sensible near/far depth bounds (lower and upper bounds on depth)
    to the batch. Assumes all of an example's views show roughly the same thing.
    """
    context = batch["context"]
    _, num_context, _, h, w = context["image"].shape

    # Both planes are derived from the context cameras only.
    near, far = (
        compute_depth_for_disparity(
            context["extrinsics"],
            context["intrinsics"],
            (h, w),
            d,
        )
        for d in (near_disparity, far_disparity)
    )

    target = batch["target"]
    _, num_target, _, _, _ = target["image"].shape

    def tile(planes, views):
        # Broadcast each example's scalar bound across all of its views.
        return repeat(planes, "b -> b v", v=views)

    return {
        **batch,
        "context": {
            **context,
            "near": tile(near, num_context),
            "far": tile(far, num_context),
        },
        "target": {
            **target,
            "near": tile(near, num_target),
            "far": tile(far, num_target),
        },
    }