# AnySplat: src/misc/image_io.py
import io
import os
from pathlib import Path
from typing import Union
import imageio
import numpy as np
import torch
import torchvision.transforms as tf
from einops import rearrange, repeat
from jaxtyping import Float, UInt8
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
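
# Floating-point images may be HW, CHW, or BCHW, with values expected in [0, 1].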
FloatImage = Union[
Float[Tensor, "height width"],
Float[Tensor, "channel height width"],
Float[Tensor, "batch channel height width"],
]
def fig_to_image(
fig: Figure,
dpi: int = 100,
device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
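    """Rasterize a Matplotlib figure into an RGB float tensor with values in [0, 1]."""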
buffer = io.BytesIO()
fig.savefig(buffer, format="raw", dpi=dpi)
buffer.seek(0)
data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
h = int(fig.bbox.bounds[3])
w = int(fig.bbox.bounds[2])
data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
buffer.close()
return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3]
def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
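    """Convert a float image (2D, 3D, or batched 4D) to an HWC uint8 array,
    tiling any batch dimension horizontally."""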
# Handle batched images.
if image.ndim == 4:
image = rearrange(image, "b c h w -> c h (b w)")
# Handle single-channel images.
if image.ndim == 2:
image = rearrange(image, "h w -> () h w")
# Ensure that there are 3 or 4 channels.
channel, _, _ = image.shape
if channel == 1:
image = repeat(image, "() h w -> c h w", c=3)
assert image.shape[0] in (3, 4)
image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
return rearrange(image, "c h w -> h w c").cpu().numpy()
def save_image(
image: FloatImage,
path: Union[Path, str],
) -> None:
"""Save an image. Assumed to be in range 0-1."""
# Create the parent directory if it doesn't already exist.
path = Path(path)
path.parent.mkdir(exist_ok=True, parents=True)
# Save the image.
Image.fromarray(prep_image(image)).save(path)
def load_image(
path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
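    """Load an image from disk as an RGB float tensor with values in [0, 1]."""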
return tf.ToTensor()(Image.open(path))[:3]
def save_video(tensor, save_path, fps=10):
"""
Save a tensor of shape (N, C, H, W) as a video file using imageio.
Args:
tensor: Tensor of shape (N, C, H, W) in range [0, 1]
save_path: Path to save the video file
fps: Frames per second for the video
"""
# Convert tensor to numpy array and adjust dimensions
video = tensor.cpu().detach().numpy() # (N, C, H, W)
video = np.transpose(video, (0, 2, 3, 1)) # (N, H, W, C)
# Scale to [0, 255] and convert to uint8
video = (video * 255).astype(np.uint8)
    # Ensure the output directory exists ("." if save_path has no directory part).
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    # Use imageio to write the video (it handles codec compatibility automatically).
    writer = imageio.get_writer(save_path, fps=fps)
for frame in video:
writer.append_data(frame)
writer.close()
def save_interpolated_video(
pred_extrinsics, pred_intrinsics, b, h, w, gaussians, save_path, decoder_func, t=10
):
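    """Render and save RGB and depth videos along a camera path obtained by
    interpolating `t` extra poses between each pair of predicted views.

    Translations and intrinsics are blended linearly; rotations are blended
    linearly and then projected back onto the nearest rotation matrix via SVD.
    Returns the paths of the saved RGB and depth videos.
    """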
# Interpolate between neighboring frames
# t: Number of extra views to interpolate between each pair
interpolated_extrinsics = []
interpolated_intrinsics = []
    # For each pair of neighboring frames
for i in range(pred_extrinsics.shape[1] - 1):
# Add the current frame
interpolated_extrinsics.append(pred_extrinsics[:, i : i + 1])
interpolated_intrinsics.append(pred_intrinsics[:, i : i + 1])
# Interpolate between current and next frame
for j in range(1, t + 1):
alpha = j / (t + 1)
# Interpolate extrinsics
start_extrinsic = pred_extrinsics[:, i]
end_extrinsic = pred_extrinsics[:, i + 1]
# Separate rotation and translation
start_rot = start_extrinsic[:, :3, :3]
end_rot = end_extrinsic[:, :3, :3]
start_trans = start_extrinsic[:, :3, 3]
end_trans = end_extrinsic[:, :3, 3]
# Interpolate translation (linear)
interp_trans = (1 - alpha) * start_trans + alpha * end_trans
            # Interpolate rotation: blend the matrices linearly, then project
            # back onto the nearest valid rotation via SVD below.
start_rot_flat = start_rot.reshape(b, 9)
end_rot_flat = end_rot.reshape(b, 9)
interp_rot_flat = (1 - alpha) * start_rot_flat + alpha * end_rot_flat
interp_rot = interp_rot_flat.reshape(b, 3, 3)
# Normalize rotation matrix to ensure it's orthogonal
u, _, v = torch.svd(interp_rot)
interp_rot = torch.bmm(u, v.transpose(1, 2))
# Combine interpolated rotation and translation
interp_extrinsic = (
torch.eye(4, device=pred_extrinsics.device).unsqueeze(0).repeat(b, 1, 1)
)
interp_extrinsic[:, :3, :3] = interp_rot
interp_extrinsic[:, :3, 3] = interp_trans
# Interpolate intrinsics (linear)
start_intrinsic = pred_intrinsics[:, i]
end_intrinsic = pred_intrinsics[:, i + 1]
interp_intrinsic = (1 - alpha) * start_intrinsic + alpha * end_intrinsic
# Add interpolated frame
interpolated_extrinsics.append(interp_extrinsic.unsqueeze(1))
interpolated_intrinsics.append(interp_intrinsic.unsqueeze(1))
    # Add the last original frame, which the loop above does not include.
    interpolated_extrinsics.append(pred_extrinsics[:, -1:])
    interpolated_intrinsics.append(pred_intrinsics[:, -1:])
    # Concatenate all frames.
    pred_all_extrinsic = torch.cat(interpolated_extrinsics, dim=1)
    pred_all_intrinsic = torch.cat(interpolated_intrinsics, dim=1)
    # Total number of frames after interpolation.
    num_frames = pred_all_extrinsic.shape[1]
# Render interpolated views
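    # The two `torch.ones(...)` tensors below are presumably the per-frame near
    # (0.1) and far (100.0) plane distances expected by the decoder.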
interpolated_output = decoder_func.forward(
gaussians,
pred_all_extrinsic,
pred_all_intrinsic.float(),
torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 0.1,
torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 100,
(h, w),
)
# Convert to video format
video = interpolated_output.color[0].clip(min=0, max=1)
depth = interpolated_output.depth[0]
    # Normalize depth for visualization. Quantiles are computed on every
    # `num_views`-th frame to avoid `quantile() input tensor is too large`.
    num_views = pred_extrinsics.shape[1]
    near_q = depth[::num_views].quantile(0.01)
    far_q = depth[::num_views].quantile(0.99)
    depth_norm = (depth - near_q) / (far_q - near_q)
depth_norm = plt.cm.turbo(depth_norm.cpu().numpy())
depth_colored = (
torch.from_numpy(depth_norm[..., :3]).permute(0, 3, 1, 2).to(depth.device)
)
depth_colored = depth_colored.clip(min=0, max=1)
    # Save the depth and RGB videos.
    depth_path = os.path.join(save_path, "depth.mp4")
    rgb_path = os.path.join(save_path, "rgb.mp4")
    save_video(depth_colored, depth_path)
    save_video(video, rgb_path)
    return rgb_path, depth_path
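

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises the basic
    # image and video helpers with random data. Assumes an imageio backend with
    # mp4 support (e.g. imageio-ffmpeg) is available for the video writer.
    demo_dir = Path("image_io_demo")

    image = torch.rand(3, 64, 64)
    save_image(image, demo_dir / "random.png")
    reloaded = load_image(demo_dir / "random.png")
    print("reloaded image shape:", tuple(reloaded.shape))

    frames = torch.rand(12, 3, 64, 64)
    save_video(frames, str(demo_dir / "random.mp4"), fps=4)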