# AnySplat/src/visualization/video_render.py
from typing import Protocol, runtime_checkable

import cv2
import torch
from einops import pack, repeat
from jaxtyping import Float
from torch import Tensor

from .camera_trajectory.interpolation import (
    interpolate_extrinsics,
    interpolate_intrinsics,
)
from .camera_trajectory.wobble import generate_wobble, generate_wobble_transformation
from .layout import vcat
from ..dataset.types import BatchedExample
from ..misc.image_io import save_video
from ..misc.utils import vis_depth_map
from ..model.decoder import Decoder
from ..model.types import Gaussians

@runtime_checkable
class TrajectoryFn(Protocol):
    """Maps normalized frame times t in [0, 1] to per-frame camera poses."""

    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],  # extrinsics
        Float[Tensor, "batch view 3 3"],  # intrinsics
    ]:
        ...
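
# A minimal sketch of a conforming trajectory (hypothetical, for illustration
# only): hold the camera at the identity pose for every frame.
#
#     def static_trajectory(t: Float[Tensor, " t"]):
#         extrinsics = repeat(torch.eye(4, device=t.device), "i j -> 1 v i j", v=t.shape[0])
#         intrinsics = repeat(torch.eye(3, device=t.device), "i j -> 1 v i j", v=t.shape[0])
#         return extrinsics, intrinsics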

def render_video_wobble(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 60,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    # Two views are needed to get the wobble radius; use the first and the
    # last context views.
    _, v, _, _ = batch["context"]["extrinsics"].shape
    def trajectory_fn(t):
        # Wobble around the first context camera; the radius is a quarter of
        # the distance between the first and last context camera origins.
        origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
        origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
        delta = (origin_a - origin_b).norm(dim=-1)
        extrinsics = generate_wobble(
            batch["context"]["extrinsics"][:, 0],
            delta * 0.25,
            t,
        )
        # Reuse the first context view's intrinsics for every frame.
        intrinsics = repeat(
            batch["context"]["intrinsics"][:, 0],
            "b i j -> b v i j",
            v=t.shape[0],
        )
        return extrinsics, intrinsics

    return render_video_generic(
        gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth
    )
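
# A hedged usage sketch (the save_video call signature is an assumption, not
# confirmed by this file):
#
#     video = render_video_wobble(gaussians, decoder, batch, num_frames=60)
#     save_video(video, "wobble.mp4")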

def render_video_interpolation(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 60,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    _, v, _, _ = batch["context"]["extrinsics"].shape

    def trajectory_fn(t):
        # Interpolate pose and intrinsics from the first context view to the last.
        extrinsics = interpolate_extrinsics(
            batch["context"]["extrinsics"][0, 0],
            batch["context"]["extrinsics"][0, -1],
            t,
        )
        intrinsics = interpolate_intrinsics(
            batch["context"]["intrinsics"][0, 0],
            batch["context"]["intrinsics"][0, -1],
            t,
        )
        # Re-add the batch axis dropped by the [0, ...] indexing above.
        return extrinsics[None], intrinsics[None]

    return render_video_generic(
        gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth
    )
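
# A hedged variant sketch (hypothetical, not part of this module's API):
# inside a trajectory_fn, the same helpers could sweep through every context
# view by chaining pairwise interpolations, splitting the frame times across
# the v - 1 segments:
#
#     chunks = t.chunk(v - 1)
#     segments = [
#         interpolate_extrinsics(
#             batch["context"]["extrinsics"][0, i],
#             batch["context"]["extrinsics"][0, i + 1],
#             chunk,
#         )
#         for i, chunk in enumerate(chunks)
#     ]
#     extrinsics = torch.cat(segments, dim=0)[None]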

def render_video_interpolation_exaggerated(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 300,
    smooth: bool = False,
    loop_reverse: bool = False,
    add_depth: bool = False,
) -> Tensor:
    # Two views are needed to get the wobble radius.
    _, v, _, _ = batch["context"]["extrinsics"].shape

    def trajectory_fn(t):
        origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
        origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
        delta = (origin_a - origin_b).norm(dim=-1)
        tf = generate_wobble_transformation(
            delta * 0.5,
            t,
            5,
            scale_radius_with_t=False,
        )
        # t * 5 - 2 extrapolates well past both endpoint poses, which is what
        # makes this trajectory "exaggerated".
        extrinsics = interpolate_extrinsics(
            batch["context"]["extrinsics"][0, 0],
            batch["context"]["extrinsics"][0, -1],
            t * 5 - 2,
        )
        intrinsics = interpolate_intrinsics(
            batch["context"]["intrinsics"][0, 0],
            batch["context"]["intrinsics"][0, -1],
            t * 5 - 2,
        )
        return extrinsics @ tf, intrinsics[None]

    return render_video_generic(
        gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth
    )
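
# Composition note (a reading of the code above; the wobble helper's exact
# convention is an assumption): extrinsics @ tf right-multiplies each
# interpolated pose E(t) by a wobble transform T(t), so the wobble offset is
# applied in the local frame of the moving camera, consistent with the
# plain-wobble trajectory that orbits a single context pose.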

def render_video_generic(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    trajectory_fn: TrajectoryFn,
    num_frames: int = 30,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    device = gaussians.means.device
    t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=device)
    if smooth:
        # Cosine ease-in-out: remap linear time so the camera accelerates
        # from rest and decelerates back to rest.
        t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
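        # Worked endpoints of the remap t' = (cos(pi * (t + 1)) + 1) / 2:
        # t = 0 -> t' = 0, t = 0.5 -> t' = 0.5, t = 1 -> t' = 1, with zero
        # slope at both ends, so motion starts and stops smoothly.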
    extrinsics, intrinsics = trajectory_fn(t)

    _, _, _, h, w = batch["context"]["image"].shape
    # Broadcast the first context view's near/far planes to every frame.
    near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
    far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
    output = decoder.forward(
        gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
    )
    images = [
        vcat(rgb, depth) if add_depth else rgb
        for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
    ]

    video = torch.stack(images)
    if loop_reverse:
        # Append the reversed frames (excluding both endpoints) so the clip
        # plays forward then backward as a seamless loop. Tensor.flip is used
        # because PyTorch tensors do not support negative-step slicing.
        video = pack([video, video.flip(dims=(0,))[1:-1]], "* c h w")[0]
    return video
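
# A hedged consumer-side sketch (assumption: a caller wants uint8 HWC frames,
# e.g. for cv2.VideoWriter; the returned video is float RGB in [0, 1] with
# shape [frame, 3, h, w]):
#
#     frames = (video.clip(min=0, max=1) * 255).to(torch.uint8)   # [n, 3, h, w]
#     frames = frames.permute(0, 2, 3, 1).cpu().numpy()           # [n, h, w, 3]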