# diffusionsfm/utils/visualization.py

import io
import ipdb  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import trimesh
import torch
import torchvision
from pytorch3d.loss import chamfer_distance
from scipy.spatial.transform import Rotation
from diffusionsfm.inference.ddim import inference_ddim
from diffusionsfm.utils.rays import (
Rays,
cameras_to_rays,
rays_to_cameras,
rays_to_cameras_homography,
)
from diffusionsfm.utils.geometry import (
compute_optimal_alignment,
)

cmap = plt.get_cmap("hsv")


def create_training_visualizations(
model,
images,
device,
cameras_gt,
num_images,
crop_parameters,
pred_x0=False,
no_crop_param_device="cpu",
visualize_pred=False,
return_first=False,
calculate_intrinsics=False,
mode=None,
depths=None,
scale_min=-1,
scale_max=1,
diffuse_depths=False,
vis_mode=None,
average_centers=True,
full_num_patches_x=16,
full_num_patches_y=16,
use_homogeneous=False,
distortion_coefficients=None,
):
    """Run DDIM inference and build per-sample visualization grids.

    Returns (vis_images, pred_cameras_batched, pred_rays): one stacked
    visualization per batch element, the cameras recovered from the predicted
    rays, and the corresponding Rays objects.
    """
    if model.depth_resolution == 1:
        W_in = W_out = full_num_patches_x
        H_in = H_out = full_num_patches_y
    else:
        W_in = H_in = model.width
        W_out = model.width * model.depth_resolution
        H_out = model.width * model.depth_resolution
rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
model,
images,
device,
crop_parameters=crop_parameters,
eta=[1, 0],
num_patches_x=W_in,
num_patches_y=H_in,
visualize=True,
)
if vis_mode is None:
vis_mode = mode
T = model.noise_scheduler.max_timesteps
if T == 1000:
ts = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999]
else:
ts = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]
# Get predicted cameras from rays
pred_cameras_batched = []
vis_images = []
pred_rays = []
for index in range(len(images)):
pred_cameras = []
per_sample_images = []
for ii in range(num_images):
rays_gt = cameras_to_rays(
cameras_gt[index],
crop_parameters[index],
no_crop_param_device=no_crop_param_device,
num_patches_x=W_in,
num_patches_y=H_in,
depths=None if depths is None else depths[index],
mode=mode,
depth_resolution=model.depth_resolution,
distortion_coefficients=(
None
if distortion_coefficients is None
else distortion_coefficients[index]
),
)
image_vis = (images[index, ii].cpu().permute(1, 2, 0).numpy() + 1) / 2
if diffuse_depths:
fig, axs = plt.subplots(3, 13, figsize=(15, 4.5), dpi=100)
else:
fig, axs = plt.subplots(3, 9, figsize=(12, 4.5), dpi=100)
for i, t in enumerate(ts):
r, c = i // 4, i % 4
if visualize_pred:
curr = pred_intermediate[t][index]
else:
curr = rays_intermediate[t][index]
rays = Rays.from_spatial(
curr,
mode=mode,
num_patches_x=H_in,
num_patches_y=W_in,
use_homogeneous=use_homogeneous,
)
if vis_mode == "segment":
vis = (
torch.clip(
rays.get_segments()[ii], min=scale_min, max=scale_max
)
- scale_min
) / (scale_max - scale_min)
else:
vis = (
torch.nn.functional.normalize(rays.get_moments()[ii], dim=-1)
+ 1
) / 2
axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
axs[r, c].set_title(f"T={T - t}")
            # Reuse the final loop index to place the GT panel in the next slot
            i += 1
            r, c = i // 4, i % 4
if vis_mode == "segment":
vis = (
torch.clip(rays_gt.get_segments()[ii], min=scale_min, max=scale_max)
- scale_min
) / (scale_max - scale_min)
else:
vis = (
torch.nn.functional.normalize(rays_gt.get_moments()[ii], dim=-1) + 1
) / 2
axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
type_str = "Endpoints" if vis_mode == "segment" else "Moments"
axs[r, c].set_title(f"GT {type_str}")
for i, t in enumerate(ts):
r, c = i // 4, i % 4 + 4
if visualize_pred:
curr = pred_intermediate[t][index]
else:
curr = rays_intermediate[t][index]
rays = Rays.from_spatial(
curr,
mode,
num_patches_x=H_in,
num_patches_y=W_in,
use_homogeneous=use_homogeneous,
)
if vis_mode == "segment":
vis = (
torch.clip(
rays.get_origins(high_res=True)[ii],
min=scale_min,
max=scale_max,
)
- scale_min
) / (scale_max - scale_min)
else:
vis = (
torch.nn.functional.normalize(rays.get_directions()[ii], dim=-1)
+ 1
) / 2
axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
axs[r, c].set_title(f"T={T - t}")
i += 1
r, c = i // 4, i % 4 + 4
if vis_mode == "segment":
vis = (
torch.clip(
rays_gt.get_origins(high_res=True)[ii],
min=scale_min,
max=scale_max,
)
- scale_min
) / (scale_max - scale_min)
else:
vis = (
torch.nn.functional.normalize(rays_gt.get_directions()[ii], dim=-1)
+ 1
) / 2
axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
type_str = "Origins" if vis_mode == "segment" else "Directions"
axs[r, c].set_title(f"GT {type_str}")
if diffuse_depths:
for i, t in enumerate(ts):
r, c = i // 4, i % 4 + 8
if visualize_pred:
curr = pred_intermediate[t][index]
else:
curr = rays_intermediate[t][index]
rays = Rays.from_spatial(
curr,
mode,
num_patches_x=H_in,
num_patches_y=W_in,
use_homogeneous=use_homogeneous,
)
                    vis = rays.depths[ii]
                    if len(vis.shape) < 2:
                        vis = vis.reshape(H_out, W_out)
axs[r, c].imshow(vis.cpu())
axs[r, c].set_title(f"T={T - t}")
i += 1
r, c = i // 4, i % 4 + 8
                vis = depths[index][ii]
                if len(vis.shape) < 2:
                    vis = depths[index][ii].reshape(256, 256)
                axs[r, c].imshow(vis.cpu())
                axs[r, c].set_title("GT Depths")
axs[2, -1].imshow(image_vis)
axs[2, -1].set_title("Input Image")
for s in ["bottom", "top", "left", "right"]:
axs[2, -1].spines[s].set_color(cmap(ii / (num_images)))
axs[2, -1].spines[s].set_linewidth(5)
for ax in axs.flatten():
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout()
img = plot_to_image(fig)
plt.close()
per_sample_images.append(img)
if return_first:
rays_camera = pred_intermediate[0][index]
elif pred_x0:
rays_camera = pred_intermediate[-1][index]
else:
rays_camera = rays_final[index]
rays = Rays.from_spatial(
rays_camera,
mode=mode,
num_patches_x=H_in,
num_patches_y=W_in,
use_homogeneous=use_homogeneous,
)
if calculate_intrinsics:
pred_camera = rays_to_cameras_homography(
rays=rays[ii, None],
crop_parameters=crop_parameters[index],
num_patches_x=W_in,
num_patches_y=H_in,
average_centers=average_centers,
depth_resolution=model.depth_resolution,
)
else:
pred_camera = rays_to_cameras(
rays=rays[ii, None],
crop_parameters=crop_parameters[index],
no_crop_param_device=no_crop_param_device,
num_patches_x=W_in,
num_patches_y=H_in,
depth_resolution=model.depth_resolution,
average_centers=average_centers,
)
pred_cameras.append(pred_camera[0])
pred_rays.append(rays)
pred_cameras_batched.append(pred_cameras)
vis_images.append(np.vstack(per_sample_images))
return vis_images, pred_cameras_batched, pred_rays
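

# Illustrative sketch (not part of the original file): how a training loop
# might call create_training_visualizations. The `batch` keys and the
# `ray_mode` attribute below are hypothetical stand-ins for objects the
# caller already has; only the call signature itself comes from this file.
def _sketch_training_visualization(model, batch, device):
    vis_images, pred_cameras, pred_rays = create_training_visualizations(
        model,
        batch["images"].to(device),  # (B, N, 3, H, W) in [-1, 1]
        device,
        batch["cameras"],  # per-sample PyTorch3D cameras
        num_images=batch["images"].shape[1],
        crop_parameters=batch["crop_parameters"],
        mode=getattr(model, "ray_mode", None),
    )
    return vis_images, pred_cameras, pred_rays
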
def plot_to_image(figure, dpi=100):
    """Convert a matplotlib figure to an RGB uint8 array (e.g., for tf.summary.image)."""
buffer = io.BytesIO()
figure.savefig(buffer, format="raw", dpi=dpi)
plt.close(figure)
buffer.seek(0)
image = np.reshape(
np.frombuffer(buffer.getvalue(), dtype=np.uint8),
newshape=(int(figure.bbox.bounds[3]), int(figure.bbox.bounds[2]), -1),
)
return image[..., :3]
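

# Illustrative sketch (not part of the original file): round-trip a tiny
# figure through plot_to_image; at dpi=100 a 2x2-inch figure yields a
# 200x200 RGB uint8 array.
def _demo_plot_to_image():
    fig, ax = plt.subplots(figsize=(2, 2), dpi=100)
    ax.plot([0, 1], [0, 1])
    img = plot_to_image(fig)
    assert img.dtype == np.uint8 and img.shape == (200, 200, 3)
    return img
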
def view_color_coded_images_from_tensor(images, depth=False):
    """Show up to 9 frames in a 3x3 grid, color-coding each border by frame index."""
    num_frames = images.shape[0]
cmap = plt.get_cmap("hsv")
num_rows = 3
num_cols = 3
figsize = (num_cols * 2, num_rows * 2)
fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
axs = axs.flatten()
for i in range(num_rows * num_cols):
if i < num_frames:
if images[i].shape[0] == 3:
image = images[i].permute(1, 2, 0)
else:
image = images[i].unsqueeze(-1)
if not depth:
image = image * 0.5 + 0.5
else:
image = image.repeat((1, 1, 3)) / torch.max(image)
axs[i].imshow(image)
for s in ["bottom", "top", "left", "right"]:
axs[i].spines[s].set_color(cmap(i / (num_frames)))
axs[i].spines[s].set_linewidth(5)
axs[i].set_xticks([])
axs[i].set_yticks([])
else:
axs[i].axis("off")
plt.tight_layout()
return fig
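

# Illustrative sketch (not part of the original file): render a color-coded
# grid for four random images normalized to [-1, 1], matching the scaling
# this function assumes for non-depth inputs.
def _demo_view_color_coded_images():
    images = torch.rand(4, 3, 32, 32) * 2 - 1  # (N, C, H, W)
    fig = view_color_coded_images_from_tensor(images)
    plt.close(fig)
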
def color_and_filter_points(points, images, mask, num_show, resolution):
    """Flatten per-pixel points, keep those with mask > 0.5, and attach RGB colors."""
    # Resize images to the point-map resolution
resize = torchvision.transforms.Resize(resolution)
images = resize(images) * 0.5 + 0.5
# Reshape points and calculate mask
points = points.reshape(num_show * resolution * resolution, 3)
mask = mask.reshape(num_show * resolution * resolution)
depth_mask = torch.argwhere(mask > 0.5)[:, 0]
points = points[depth_mask]
# Mask and reshape colors
colors = images.permute(0, 2, 3, 1).reshape(num_show * resolution * resolution, 3)
colors = colors[depth_mask]
return points, colors
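

# Illustrative sketch (not part of the original file): the shape contract of
# color_and_filter_points on random inputs. Points and masks may come in any
# layout that flattens to num_show * resolution**2 entries.
def _demo_color_and_filter_points():
    num_show, resolution = 2, 16
    points = torch.rand(num_show, resolution, resolution, 3)
    images = torch.rand(num_show, 3, 64, 64)  # resized to 16x16 internally
    mask = (torch.rand(num_show, resolution, resolution) > 0.5).float()
    pts, colors = color_and_filter_points(points, images, mask, num_show, resolution)
    assert pts.shape == colors.shape  # one RGB color per surviving point
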
def filter_and_align_point_clouds(
num_frames,
gt_points,
pred_points,
gt_masks,
pred_masks,
images,
metrics=False,
num_patches_x=16,
):
    """Mask and color both point clouds, align predictions to GT, and
    optionally compute Chamfer distance and mean per-point error."""
    # Filter and color points
gt_points, gt_colors = color_and_filter_points(
gt_points, images, gt_masks, num_show=num_frames, resolution=num_patches_x
)
pred_points, pred_colors = color_and_filter_points(
pred_points, images, pred_masks, num_show=num_frames, resolution=num_patches_x
)
pred_points, _, _, _ = compute_optimal_alignment(
gt_points.float(), pred_points.float()
)
    # Normalize both point clouds by the mean distance to the GT centroid
centroid = torch.mean(gt_points, dim=0)
dists = torch.norm(gt_points - centroid.unsqueeze(0), dim=-1)
scale = torch.mean(dists)
gt_points_scaled = (gt_points - centroid) / scale
pred_points_scaled = (pred_points - centroid) / scale
if metrics:
cd, _ = chamfer_distance(
pred_points_scaled.unsqueeze(0), gt_points_scaled.unsqueeze(0)
)
cd = cd.item()
        # Despite the name, this is the mean per-point Euclidean distance
        mse = torch.mean(
            torch.norm(pred_points_scaled - gt_points_scaled, dim=-1), dim=-1
        ).item()
else:
mse, cd = None, None
return (
gt_points,
pred_points,
gt_colors,
pred_colors,
[mse, cd, None],
)
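

# Illustrative sketch (not part of the original file): exercise the alignment
# and metrics path on synthetic data. Assumes compute_optimal_alignment
# accepts two (N, 3) point sets, as the call above implies.
def _demo_filter_and_align_point_clouds():
    num_frames, res = 2, 16
    gt = torch.rand(num_frames, res, res, 3)
    pred = gt + 0.01 * torch.randn_like(gt)  # near-perfect prediction
    masks = torch.ones(num_frames, res, res)
    images = torch.rand(num_frames, 3, res, res)
    gt_pts, pred_pts, _, _, (mse, cd, _) = filter_and_align_point_clouds(
        num_frames, gt, pred, masks, masks, images,
        metrics=True, num_patches_x=res,
    )
    assert gt_pts.shape == pred_pts.shape and mse is not None
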
def add_scene_cam(scene, c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
    """Add a wireframe camera frustum (a four-sided cone) to a trimesh scene."""
OPENGL = np.array([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]
])
if image is not None:
H, W, THREE = image.shape
assert THREE == 3
if image.dtype != np.uint8:
image = np.uint8(255*image)
elif imsize is not None:
W, H = imsize
elif focal is not None:
H = W = focal / 1.1
else:
H = W = 1
if focal is None:
focal = min(H, W) * 1.1 # default value
elif isinstance(focal, np.ndarray):
focal = focal[0]
# create fake camera
height = focal * screen_width / H
width = screen_width * 0.5**0.5
rot45 = np.eye(4)
rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
rot45[2, 3] = -height # set the tip of the cone = optical center
aspect_ratio = np.eye(4)
aspect_ratio[0, 0] = W/H
transform = c2w @ OPENGL @ aspect_ratio @ rot45
cam = trimesh.creation.cone(width, height, sections=4)
# this is the camera mesh
rot2 = np.eye(4)
rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(4)).as_matrix()
vertices = cam.vertices
vertices_offset = 0.9 * cam.vertices
vertices = np.r_[vertices, vertices_offset, geotrf(rot2, cam.vertices)]
vertices = geotrf(transform, vertices)
faces = []
for face in cam.faces:
if 0 in face:
continue
a, b, c = face
a2, b2, c2 = face + len(cam.vertices)
# add 3 pseudo-edges
faces.append((a, b, b2))
faces.append((a, a2, c))
faces.append((c2, b, c))
faces.append((a, b2, a2))
faces.append((a2, c, c2))
faces.append((c2, b2, b))
# no culling
faces += [(c, b, a) for a, b, c in faces]
    for i, face in enumerate(cam.faces):
if 0 in face:
continue
if i == 1 or i == 5:
a, b, c = face
faces.append((a, b, c))
cam = trimesh.Trimesh(vertices=vertices, faces=faces)
cam.visual.face_colors[:, :3] = edge_color
scene.add_geometry(cam)
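

# Illustrative sketch (not part of the original file): drop a single camera
# marker into an empty scene. The focal length and image size here are
# arbitrary example values.
def _demo_add_scene_cam():
    scene = trimesh.Scene()
    c2w = np.eye(4)  # camera-to-world pose
    add_scene_cam(scene, c2w, edge_color=(255, 0, 0), focal=500.0, imsize=(64, 48))
    return scene
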
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    Trf: 3x3 or 4x4 projection matrix (typically a homography)
    pts: numpy/torch/tuple of coordinates, with shape (..., 2) or (..., 3)
    ncol: int, number of columns of the result (2 or 3)
    norm: float; if != 0, the result is projected onto the z=norm plane
    Returns an array of transformed points with ncol columns.
    """
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# optimized code
if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
Trf.ndim == 3 and pts.ndim == 4):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d+1:
pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
else:
raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
else:
if Trf.ndim >= 3:
n = Trf.ndim-2
assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1]+1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
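

# Illustrative sketch (not part of the original file): geotrf applied to
# (N, 3) points with a 4x4 rigid transform, on both numpy and torch inputs.
def _demo_geotrf():
    Trf = np.eye(4)
    Trf[:3, 3] = [1.0, 2.0, 3.0]  # pure translation
    pts = np.random.rand(5, 3)
    out = geotrf(Trf, pts)
    assert np.allclose(out, pts + Trf[:3, 3])
    out_t = geotrf(torch.from_numpy(Trf), torch.from_numpy(pts))
    assert torch.allclose(out_t, torch.from_numpy(out))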