|
from typing import Protocol, runtime_checkable |
|
|
|
import cv2 |
|
import torch |
|
from einops import repeat, pack |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
from .camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics |
|
from .camera_trajectory.wobble import generate_wobble, generate_wobble_transformation |
|
from .layout import vcat |
|
from ..dataset.types import BatchedExample |
|
from ..misc.image_io import save_video |
|
from ..misc.utils import vis_depth_map |
|
from ..model.decoder import Decoder |
|
from ..model.types import Gaussians |
|
|
|
|
|
@runtime_checkable
class TrajectoryFn(Protocol):
    """Structural type for callables that map timesteps to camera parameters.

    An implementation receives a 1-D tensor ``t`` (one entry per frame) and
    returns a ``(extrinsics, intrinsics)`` pair, annotated as
    ``batch view 4 4`` and ``batch view 3 3`` tensors respectively.
    """

    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],
        Float[Tensor, "batch view 3 3"],
    ]:
        ...
|
|
|
|
|
def render_video_wobble(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 60,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    """Render a video whose camera wobbles around the first context view.

    The wobble amplitude handed to ``generate_wobble`` is a quarter of the
    distance between the translation columns of the first and last context
    extrinsics. The intrinsics of the first context view are reused for every
    frame. Rendering itself is delegated to ``render_video_generic``.

    Args:
        gaussians: Scene representation passed through to the decoder.
        decoder: Decoder used to render the frames.
        batch: Batch whose ``"context"`` entry supplies extrinsics/intrinsics.
        num_frames: Number of timesteps sampled along the trajectory.
        smooth: Forwarded to ``render_video_generic`` (cosine easing of t).
        loop_reverse: Forwarded to ``render_video_generic``.
        add_depth: Forwarded to ``render_video_generic``.

    Returns:
        The video tensor produced by ``render_video_generic``.
    """
    # Unpacking doubles as a check that the extrinsics tensor has four dims.
    _, _, _, _ = batch["context"]["extrinsics"].shape

    def trajectory_fn(t):
        context_extrinsics = batch["context"]["extrinsics"]
        first_origin = context_extrinsics[:, 0, :3, 3]
        last_origin = context_extrinsics[:, -1, :3, 3]
        baseline = (first_origin - last_origin).norm(dim=-1)

        wobble_extrinsics = generate_wobble(
            batch["context"]["extrinsics"][:, 0],
            baseline * 0.25,
            t,
        )

        # One copy of the first view's intrinsics per sampled timestep.
        first_intrinsics = batch["context"]["intrinsics"][:, 0]
        tiled_intrinsics = repeat(
            first_intrinsics,
            "b i j -> b v i j",
            v=t.shape[0],
        )
        return wobble_extrinsics, tiled_intrinsics

    return render_video_generic(
        gaussians,
        decoder,
        batch,
        trajectory_fn,
        num_frames=num_frames,
        smooth=smooth,
        loop_reverse=loop_reverse,
        add_depth=add_depth,
    )
|
|
|
|
|
def render_video_interpolation(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 60,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    """Render a video that moves from the first to the last context view.

    Both extrinsics and intrinsics are interpolated between the first and
    last context views of batch element 0 only; the results get a leading
    singleton batch dimension before being handed to ``render_video_generic``.

    Args:
        gaussians: Scene representation passed through to the decoder.
        decoder: Decoder used to render the frames.
        batch: Batch whose ``"context"`` entry supplies extrinsics/intrinsics.
        num_frames: Number of timesteps sampled along the trajectory.
        smooth: Forwarded to ``render_video_generic`` (cosine easing of t).
        loop_reverse: Forwarded to ``render_video_generic``.
        add_depth: Forwarded to ``render_video_generic``.

    Returns:
        The video tensor produced by ``render_video_generic``.
    """
    # Unpacking doubles as a check that the extrinsics tensor has four dims.
    _, _, _, _ = batch["context"]["extrinsics"].shape

    def trajectory_fn(t):
        context_extrinsics = batch["context"]["extrinsics"]
        interpolated_extrinsics = interpolate_extrinsics(
            context_extrinsics[0, 0],
            context_extrinsics[0, -1],
            t,
        )

        context_intrinsics = batch["context"]["intrinsics"]
        interpolated_intrinsics = interpolate_intrinsics(
            context_intrinsics[0, 0],
            context_intrinsics[0, -1],
            t,
        )

        # Re-introduce the batch dimension dropped by indexing element 0.
        return interpolated_extrinsics[None], interpolated_intrinsics[None]

    return render_video_generic(
        gaussians,
        decoder,
        batch,
        trajectory_fn,
        num_frames=num_frames,
        smooth=smooth,
        loop_reverse=loop_reverse,
        add_depth=add_depth,
    )
|
|
|
|
|
def render_video_interpolation_exaggerated(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    num_frames: int = 300,
    smooth: bool = False,
    loop_reverse: bool = False,
    add_depth: bool = False,
) -> Tensor:
    """Render an exaggerated interpolation video with a superimposed wobble.

    The interpolation parameter is remapped from ``t`` to ``t * 5 - 2``, so
    for ``t`` in [0, 1] the camera sweeps over [-2, 3]: it starts before the
    first context view and ends past the last one. A wobble transformation
    is right-multiplied onto the interpolated extrinsics. Only batch
    element 0 is used for the interpolation end points.

    Args:
        gaussians: Scene representation passed through to the decoder.
        decoder: Decoder used to render the frames.
        batch: Batch whose ``"context"`` entry supplies extrinsics/intrinsics.
        num_frames: Number of timesteps sampled along the trajectory.
        smooth: Forwarded to ``render_video_generic`` (cosine easing of t).
        loop_reverse: Forwarded to ``render_video_generic``.
        add_depth: Forwarded to ``render_video_generic``.

    Returns:
        The video tensor produced by ``render_video_generic``.
    """
    # Unpacking doubles as a check that the extrinsics tensor has four dims.
    _, _, _, _ = batch["context"]["extrinsics"].shape

    def trajectory_fn(t):
        origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
        origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
        delta = (origin_a - origin_b).norm(dim=-1)

        # NOTE(review): the positional ``5`` is presumably the number of
        # wobble rotations over the clip -- confirm against
        # generate_wobble_transformation's signature.
        tf = generate_wobble_transformation(
            delta * 0.5,
            t,
            5,
            scale_radius_with_t=False,
        )

        # Same remapped parameter for pose and intrinsics so they stay in sync.
        exaggerated_t = t * 5 - 2
        extrinsics = interpolate_extrinsics(
            batch["context"]["extrinsics"][0, 0],
            batch["context"]["extrinsics"][0, -1],
            exaggerated_t,
        )
        # Fix: the end point must be the last view's *intrinsics*. The
        # previous code passed the last view's extrinsics here, mixing a pose
        # matrix into the intrinsics interpolation.
        intrinsics = interpolate_intrinsics(
            batch["context"]["intrinsics"][0, 0],
            batch["context"]["intrinsics"][0, -1],
            exaggerated_t,
        )
        return extrinsics @ tf, intrinsics[None]

    return render_video_generic(
        gaussians,
        decoder,
        batch,
        trajectory_fn,
        num_frames,
        smooth,
        loop_reverse,
        add_depth,
    )
|
|
|
|
|
def render_video_generic(
    gaussians: Gaussians,
    decoder: Decoder,
    batch: BatchedExample,
    trajectory_fn: TrajectoryFn,
    num_frames: int = 30,
    smooth: bool = True,
    loop_reverse: bool = True,
    add_depth: bool = False,
) -> Tensor:
    """Render a video along the camera path described by ``trajectory_fn``.

    Args:
        gaussians: Scene representation; its ``means`` tensor determines the
            device on which the timesteps are created.
        decoder: Decoder whose ``forward`` renders color and depth.
        batch: Batch whose ``"context"`` entry supplies image size and the
            near/far planes (taken from the first context view).
        trajectory_fn: Maps the timestep vector to (extrinsics, intrinsics).
        num_frames: Number of timesteps sampled uniformly in [0, 1].
        smooth: If true, remap the timesteps with a cosine ease so motion
            starts and ends slowly (0 still maps to 0 and 1 to 1).
        loop_reverse: If true, append the frames in reverse order, minus the
            two endpoints of the reversed sequence, so the clip can loop.
        add_depth: If true, each frame is ``vcat(rgb, depth)`` instead of
            the color image alone.

    Returns:
        The stacked frames of batch element 0.
    """
    timesteps = torch.linspace(
        0,
        1,
        num_frames,
        dtype=torch.float32,
        device=gaussians.means.device,
    )
    if smooth:
        timesteps = (torch.cos(torch.pi * (timesteps + 1)) + 1) / 2

    extrinsics, intrinsics = trajectory_fn(timesteps)

    _, _, _, height, width = batch["context"]["image"].shape

    # Every rendered frame shares the first context view's near/far planes.
    near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
    far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)

    rendered = decoder.forward(
        gaussians,
        extrinsics,
        intrinsics,
        near,
        far,
        (height, width),
        "depth",
    )

    # Only the first batch element is turned into a video.
    color_frames = rendered.color[0]
    depth_frames = vis_depth_map(rendered.depth[0])

    frames = []
    for color_frame, depth_frame in zip(color_frames, depth_frames):
        if add_depth:
            frames.append(vcat(color_frame, depth_frame))
        else:
            frames.append(color_frame)

    video = torch.stack(frames)

    if not loop_reverse:
        return video

    # Drop both ends of the reversed clip so the turning points are not
    # shown twice when the result is played in a loop.
    backward = video.flip(dims=(0,))[1:-1]
    looped, _ = pack([video, backward], "* c h w")
    return looped
|
|