|
import io |
|
import os |
|
from pathlib import Path |
|
from typing import Union |
|
|
|
import cv2 |
|
import numpy as np |
|
import skvideo |
|
import torch |
|
import torchvision.transforms as tf |
|
from einops import rearrange, repeat |
|
from jaxtyping import Float, UInt8 |
|
|
|
from matplotlib import pyplot as plt |
|
from matplotlib.figure import Figure |
|
from PIL import Image |
|
from torch import Tensor |
|
|
|
# Tensor layouts accepted by the image helpers in this module.
FloatImage = Union[
    # A 2-D map without a channel axis.
    Float[Tensor, "height width"],
    # A single image with a leading channel axis.
    Float[Tensor, "channel height width"],
    # A batch of images.
    Float[Tensor, "batch channel height width"],
]
|
|
|
|
|
def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Rasterize a matplotlib figure into an RGB float tensor.

    The figure is rendered into an in-memory raw RGBA buffer at ``dpi``. The
    alpha channel is dropped and the 8-bit values are divided by 255.

    Args:
        fig: Figure to render.
        dpi: Resolution used for rendering. It does not have to match
            ``fig.dpi``.
        device: Device on which the returned tensor is created.

    Returns:
        Float32 tensor of shape (3, height, width).
    """
    # The context manager releases the buffer even when rendering raises.
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="raw", dpi=dpi)
        # getvalue() returns an independent bytes object, so the array stays
        # valid after the buffer is closed.
        raw = np.frombuffer(buffer.getvalue(), dtype=np.uint8)

    # fig.bbox is measured in pixels at fig.dpi, not at the dpi requested
    # above. Deriving the pixel size from the size in inches keeps the reshape
    # consistent with the rendered buffer when the two resolutions differ.
    width_in, height_in = fig.get_size_inches()
    h = int(height_in * dpi)
    w = int(width_in * dpi)

    rgba = rearrange(raw, "(h w c) -> c h w", h=h, w=w, c=4)
    return (torch.tensor(rgba, device=device, dtype=torch.float32) / 255)[:3]
|
|
|
|
|
def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    """Convert a float image tensor into a channel-last uint8 numpy array.

    A batched input is tiled side by side along the width axis, and a
    single-channel input is replicated to three channels. Values are clipped
    to [0, 1] before being scaled by 255 and cast to uint8.
    """
    rank = image.ndim
    if rank == 4:
        # Lay the batch out horizontally as one wide image.
        image = rearrange(image, "b c h w -> c h (b w)")
    elif rank == 2:
        # Give channel-less inputs an explicit singleton channel axis.
        image = rearrange(image, "h w -> () h w")

    num_channels, _, _ = image.shape
    if num_channels == 1:
        image = repeat(image, "() h w -> c h w", c=3)
    # Only 3- or 4-channel images are supported from here on.
    assert image.shape[0] in (3, 4)

    quantized = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
    channel_last = rearrange(quantized, "c h w -> h w c")
    return channel_last.cpu().numpy()
|
|
|
|
|
def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Write an image to ``path``; pixel values are assumed to be in 0-1.

    Missing parent directories of ``path`` are created first.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # prep_image yields the channel-last uint8 array that PIL consumes.
    pixels = prep_image(image)
    Image.fromarray(pixels).save(destination)
|
|
|
|
|
def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    """Load an image file as a tensor, keeping at most the first 3 channels.

    Args:
        path: Location of the image file.

    Returns:
        The result of torchvision's ``ToTensor`` on the image, truncated to
        its first three channels (so an alpha channel, if any, is dropped).
        NOTE(review): an image with fewer than three channels is returned
        with fewer channels than the annotation promises - confirm callers
        only pass RGB/RGBA files.
    """
    # Image.open is lazy and may keep the file handle open; the context
    # manager guarantees it is released once the pixels have been converted.
    with Image.open(path) as image:
        return tf.ToTensor()(image)[:3]
|
|
|
|
|
def save_video(tensor, save_path, fps=10):
    """
    Save a tensor of shape (N, C, H, W) as a video file using imageio.

    Args:
        tensor: Tensor of shape (N, C, H, W) in range [0, 1]. Values outside
            that range are clipped before conversion to 8-bit.
        save_path: Path to save the video file. Missing parent directories
            are created; a bare filename is written to the current directory.
        fps: Frames per second for the video
    """
    # Kept as a function-level import, as in the original, so imageio is only
    # required when a video is actually written.
    import imageio

    # (N, C, H, W) -> (N, H, W, C), the frame layout appended below.
    frames = tensor.cpu().detach().numpy()
    frames = np.transpose(frames, (0, 2, 3, 1))

    # Clip first: without it, out-of-range values wrap around in the uint8
    # cast instead of saturating.
    frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)

    # Path(...).parent is "." for a bare filename, whereas
    # os.makedirs(os.path.dirname(...)) would raise on the empty string.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    writer = imageio.get_writer(save_path, fps=fps)
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Always finalize the file, even if appending a frame fails.
        writer.close()
|
|
|
|
|
def _interpolate_camera_path(pred_extrinsics, pred_intrinsics, b, t):
    """Insert ``t`` blended cameras between each pair of consecutive views.

    Returns a pair ``(extrinsics, intrinsics)`` concatenated along dim 1 that
    contains every input view in order, with ``t`` interpolated cameras
    between each pair of neighbours.
    """
    extrinsics = []
    intrinsics = []
    num_views = pred_extrinsics.shape[1]

    for i in range(num_views - 1):
        # Start each segment with the unmodified input view.
        extrinsics.append(pred_extrinsics[:, i : i + 1])
        intrinsics.append(pred_intrinsics[:, i : i + 1])

        # The segment endpoints do not depend on the interpolation step, so
        # slice them once per pair rather than once per step.
        start_extrinsic = pred_extrinsics[:, i]
        end_extrinsic = pred_extrinsics[:, i + 1]
        start_rot = start_extrinsic[:, :3, :3]
        end_rot = end_extrinsic[:, :3, :3]
        start_trans = start_extrinsic[:, :3, 3]
        end_trans = end_extrinsic[:, :3, 3]
        start_intrinsic = pred_intrinsics[:, i]
        end_intrinsic = pred_intrinsics[:, i + 1]

        for j in range(1, t + 1):
            # alpha lies strictly between 0 and 1, so the endpoints
            # themselves are never duplicated.
            alpha = j / (t + 1)

            interp_trans = (1 - alpha) * start_trans + alpha * end_trans

            # A linear blend of two rotation matrices is generally not
            # orthonormal; U @ V^T from its SVD maps it back to an
            # orthonormal matrix.
            blended_rot = (1 - alpha) * start_rot + alpha * end_rot
            u, _, v = torch.svd(blended_rot)
            interp_rot = torch.bmm(u, v.transpose(1, 2))

            # Assemble the 4x4 matrix; the identity supplies the bottom row.
            # Matching the input dtype avoids a silent downcast to float32.
            interp_extrinsic = (
                torch.eye(
                    4, device=pred_extrinsics.device, dtype=pred_extrinsics.dtype
                )
                .unsqueeze(0)
                .repeat(b, 1, 1)
            )
            interp_extrinsic[:, :3, :3] = interp_rot
            interp_extrinsic[:, :3, 3] = interp_trans

            interp_intrinsic = (1 - alpha) * start_intrinsic + alpha * end_intrinsic

            extrinsics.append(interp_extrinsic.unsqueeze(1))
            intrinsics.append(interp_intrinsic.unsqueeze(1))

    # Close the path with the final input view. This has to happen before the
    # concatenation; appending afterwards has no effect on the result, which
    # dropped the last view and made a single-view input fail on an empty cat.
    extrinsics.append(pred_extrinsics[:, -1:])
    intrinsics.append(pred_intrinsics[:, -1:])

    return torch.cat(extrinsics, dim=1), torch.cat(intrinsics, dim=1)


def save_interpolated_video(
    pred_extrinsics, pred_intrinsics, b, h, w, gaussians, save_path, decoder_func, t=10
):
    """Render a camera path interpolated between the given views to video.

    ``t`` extra cameras are inserted between each pair of consecutive views
    (translation and intrinsics are blended linearly, rotation is blended and
    re-orthonormalized). The densified path is rendered with
    ``decoder_func.forward`` and the colour and colour-mapped depth frames
    are written to ``rgb.mp4`` and ``depth.mp4`` inside ``save_path``.

    Args:
        pred_extrinsics: Camera matrices indexed as (batch, view, 4, 4).
        pred_intrinsics: Camera intrinsics indexed as (batch, view, ...).
        b: Batch size used to build the interpolated 4x4 matrices.
        h: Render height, passed to the decoder as part of ``(h, w)``.
        w: Render width, passed to the decoder as part of ``(h, w)``.
        gaussians: Scene representation, passed to the decoder unchanged.
        save_path: Directory that receives the two video files.
        decoder_func: Object whose ``forward`` returns a result with
            ``color`` and ``depth`` attributes.
        t: Number of cameras inserted between each pair of views.

    Returns:
        Tuple ``(rgb_path, depth_path)`` of the written files.
    """
    pred_all_extrinsic, pred_all_intrinsic = _interpolate_camera_path(
        pred_extrinsics, pred_intrinsics, b, t
    )
    num_frames = pred_all_extrinsic.shape[1]
    device = pred_all_extrinsic.device

    # NOTE(review): the two constant tensors are presumably the near and far
    # plane distances, and their batch dimension is fixed to 1 rather than
    # ``b`` - confirm both against the decoder's signature.
    interpolated_output = decoder_func.forward(
        gaussians,
        pred_all_extrinsic,
        pred_all_intrinsic.float(),
        torch.ones(1, num_frames, device=device) * 0.1,
        torch.ones(1, num_frames, device=device) * 100,
        (h, w),
    )

    # Only the first batch element is written to disk.
    video = interpolated_output.color[0].clip(min=0, max=1)
    depth = interpolated_output.depth[0]

    # Normalize depth by the 1st/99th percentiles, estimated on every
    # ``num_views``-th frame.
    num_views = pred_extrinsics.shape[1]
    depth_subset = depth[::num_views]
    depth_lo = depth_subset.quantile(0.01)
    depth_hi = depth_subset.quantile(0.99)
    depth_norm = (depth - depth_lo) / (depth_hi - depth_lo)

    # The colormap appends a trailing RGBA axis; keep RGB and move it to the
    # channel position expected by save_video.
    depth_norm = plt.cm.turbo(depth_norm.detach().cpu().numpy())
    depth_colored = (
        torch.from_numpy(depth_norm[..., :3]).permute(0, 3, 1, 2).to(depth.device)
    )
    depth_colored = depth_colored.clip(min=0, max=1)

    rgb_path = os.path.join(save_path, "rgb.mp4")
    depth_path = os.path.join(save_path, "depth.mp4")

    save_video(depth_colored, depth_path)
    save_video(video, rgb_path)

    return rgb_path, depth_path
|
|