from typing import Protocol, runtime_checkable
import cv2
import torch
from einops import repeat, pack
from jaxtyping import Float
from torch import Tensor
from .camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics
from .camera_trajectory.wobble import generate_wobble, generate_wobble_transformation
from .layout import vcat
from ..dataset.types import BatchedExample
from ..misc.image_io import save_video
from ..misc.utils import vis_depth_map
from ..model.decoder import Decoder
from ..model.types import Gaussians


@runtime_checkable
class TrajectoryFn(Protocol):
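    """Maps normalized times t in [0, 1] to batched camera extrinsics and intrinsics."""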
def __call__(
self,
t: Float[Tensor, " t"],
) -> tuple[
Float[Tensor, "batch view 4 4"], # extrinsics
Float[Tensor, "batch view 3 3"], # intrinsics
]:
        pass


def render_video_wobble(
gaussians: Gaussians,
decoder: Decoder,
batch: BatchedExample,
num_frames: int = 60,
smooth: bool = True,
loop_reverse: bool = True,
add_depth: bool = False,
) -> Tensor:
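    """Render a video that wobbles around the first context camera.

    The wobble radius is a quarter of the distance between the first and
    last context camera origins.
    """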
    # Two views are needed to compute the wobble radius; use the first and the last view.
_, v, _, _ = batch["context"]["extrinsics"].shape
def trajectory_fn(t):
origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
delta = (origin_a - origin_b).norm(dim=-1)
extrinsics = generate_wobble(
batch["context"]["extrinsics"][:, 0],
delta * 0.25,
t,
)
intrinsics = repeat(
batch["context"]["intrinsics"][:, 0],
"b i j -> b v i j",
v=t.shape[0],
)
return extrinsics, intrinsics
    return render_video_generic(
        gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth
    )


def render_video_interpolation(
gaussians: Gaussians,
decoder: Decoder,
batch: BatchedExample,
num_frames: int = 60,
smooth: bool = True,
loop_reverse: bool = True,
add_depth: bool = False,
) -> Tensor:
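    """Render a video that interpolates between the first and last context views."""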
_, v, _, _ = batch["context"]["extrinsics"].shape
def trajectory_fn(t):
extrinsics = interpolate_extrinsics(
batch["context"]["extrinsics"][0, 0],
batch["context"]["extrinsics"][0, -1],
t,
)
intrinsics = interpolate_intrinsics(
batch["context"]["intrinsics"][0, 0],
batch["context"]["intrinsics"][0, -1],
t,
)
return extrinsics[None], intrinsics[None]
    return render_video_generic(
        gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth
    )


def render_video_interpolation_exaggerated(
gaussians: Gaussians,
decoder: Decoder,
batch: BatchedExample,
num_frames: int = 300,
smooth: bool = False,
loop_reverse: bool = False,
add_depth: bool = False,
) -> Tensor:
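    """Render an exaggerated interpolation that overshoots the first and last
    context views and adds a wobble on top of the interpolated path.
    """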
# Two views are needed to get the wobble radius.
_, v, _, _ = batch["context"]["extrinsics"].shape
def trajectory_fn(t):
origin_a = batch["context"]["extrinsics"][:, 0, :3, 3]
origin_b = batch["context"]["extrinsics"][:, -1, :3, 3]
delta = (origin_a - origin_b).norm(dim=-1)
        # Wobble on top of the interpolated path, with a radius of half the
        # distance between the first and last context camera origins.
        tf = generate_wobble_transformation(
            delta * 0.5,
            t,
            5,
            scale_radius_with_t=False,
        )
        # Remap t from [0, 1] to [-2, 3] so the interpolation overshoots the
        # context views on both ends.
        extrinsics = interpolate_extrinsics(
            batch["context"]["extrinsics"][0, 0],
            batch["context"]["extrinsics"][0, -1],
            t * 5 - 2,
        )
        intrinsics = interpolate_intrinsics(
            batch["context"]["intrinsics"][0, 0],
            batch["context"]["intrinsics"][0, -1],
            t * 5 - 2,
        )
return extrinsics @ tf, intrinsics[None]
    return render_video_generic(
        gaussians, decoder, batch, trajectory_fn, num_frames, smooth, loop_reverse, add_depth
    )


def render_video_generic(
gaussians: Gaussians,
decoder: Decoder,
batch: BatchedExample,
trajectory_fn: TrajectoryFn,
num_frames: int = 30,
smooth: bool = True,
loop_reverse: bool = True,
add_depth: bool = False,
) -> Tensor:
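    """Sample camera poses from trajectory_fn and decode them into a video tensor."""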
device = gaussians.means.device
t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=device)
if smooth:
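        # Half-cosine ease: zero velocity at both endpoints for a smooth start and stop.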
t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
extrinsics, intrinsics = trajectory_fn(t)
_, _, _, h, w = batch["context"]["image"].shape
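    # Broadcast the near/far planes of the first context view to every frame.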
near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
output = decoder.forward(
gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
)
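    # Optionally stack a color-mapped depth visualization below each RGB frame.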
images = [
vcat(rgb, depth) if add_depth else rgb
for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
]
video = torch.stack(images)
# video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
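    # Play the video forward then backward, dropping the duplicated endpoint
    # frames so the clip loops seamlessly.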
if loop_reverse:
# video = pack([video, video[::-1][1:-1]], "* c h w")[0]
video = pack([video, video.flip(dims=(0,))[1:-1]], "* c h w")[0]
return video
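

# Minimal usage sketch (assumes `gaussians`, `decoder`, and `batch` are produced
# by the surrounding model and data pipeline; passing the video tensor directly
# to `save_video` is an assumption about its signature):
#
#     video = render_video_wobble(gaussians, decoder, batch, add_depth=True)
#     save_video(video, "wobble.mp4")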