AnySplat / src /model /encoder /visualization /encoder_visualizer_epipolar.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame
19.1 kB
from pathlib import Path
from random import randrange
from typing import Optional
import numpy as np
import torch
import wandb
from einops import rearrange, reduce, repeat
from jaxtyping import Bool, Float
from torch import Tensor
from ....dataset.types import BatchedViews
from ....misc.heterogeneous_pairings import generate_heterogeneous_index
from ....visualization.annotation import add_label
from ....visualization.color_map import apply_color_map, apply_color_map_to_image
from ....visualization.colors import get_distinct_color
from ....visualization.drawing.lines import draw_lines
from ....visualization.drawing.points import draw_points
from ....visualization.layout import add_border, hcat, vcat
from ...ply_export import export_ply
from .encoder_visualizer import EncoderVisualizer
from .encoder_visualizer_epipolar_cfg import EncoderVisualizerEpipolarCfg
def box(
image: Float[Tensor, "3 height width"],
) -> Float[Tensor, "3 new_height new_width"]:
return add_border(add_border(image), 1, 0)
class EncoderVisualizerEpipolar(
EncoderVisualizer[EncoderVisualizerEpipolarCfg, EncoderEpipolar]
):
def visualize(
self,
context: BatchedViews,
global_step: int,
) -> dict[str, Float[Tensor, "3 _ _"]]:
# Short-circuit execution when ablating the epipolar transformer.
if self.encoder.epipolar_transformer is None:
return {}
visualization_dump = {}
softmax_weights = []
def hook(module, input, output):
softmax_weights.append(output)
# Register hooks to grab attention.
handles = [
layer[0].fn.attend.register_forward_hook(hook)
for layer in self.encoder.epipolar_transformer.transformer.layers
]
result = self.encoder.forward(
context,
global_step,
visualization_dump=visualization_dump,
deterministic=True,
)
# De-register hooks.
for handle in handles:
handle.remove()
softmax_weights = torch.stack(softmax_weights)
# Generate high-resolution context images that can be drawn on.
context_images = context["image"]
_, _, _, h, w = context_images.shape
length = min(h, w)
min_resolution = self.cfg.min_resolution
scale_multiplier = (min_resolution + length - 1) // length
if scale_multiplier > 1:
context_images = repeat(
context_images,
"b v c h w -> b v c (h rh) (w rw)",
rh=scale_multiplier,
rw=scale_multiplier,
)
# This is kind of hacky for now, since we're using it for short experiments.
if self.cfg.export_ply and wandb.run is not None:
name = wandb.run._name.split(" ")[0]
ply_path = Path(f"outputs/gaussians/{name}/{global_step:0>6}.ply")
export_ply(
context["extrinsics"][0, 0],
result.means[0],
visualization_dump["scales"][0],
visualization_dump["rotations"][0],
result.harmonics[0],
result.opacities[0],
ply_path,
)
return {
# "attention": self.visualize_attention(
# context_images,
# visualization_dump["sampling"],
# softmax_weights,
# ),
"epipolar_samples": self.visualize_epipolar_samples(
context_images,
visualization_dump["sampling"],
),
"epipolar_color_samples": self.visualize_epipolar_color_samples(
context_images,
context,
),
"gaussians": self.visualize_gaussians(
context["image"],
result.opacities,
result.covariances,
result.harmonics[..., 0], # Just visualize DC component.
),
"overlaps": self.visualize_overlaps(
context["image"],
visualization_dump["sampling"],
visualization_dump.get("is_monocular", None),
),
"depth": self.visualize_depth(
context,
visualization_dump["depth"],
),
}
def visualize_attention(
self,
context_images: Float[Tensor, "batch view 3 height width"],
sampling: EpipolarSampling,
attention: Float[Tensor, "layer bvr head 1 sample"],
) -> Float[Tensor, "3 vis_height vis_width"]:
device = context_images.device
# Pick a random batch element, view, and other view.
b, v, ov, r, s, _ = sampling.xy_sample.shape
rb = randrange(b)
rv = randrange(v)
rov = randrange(ov)
num_samples = self.cfg.num_samples
rr = np.random.choice(r, num_samples, replace=False)
rr = torch.tensor(rr, dtype=torch.int64, device=device)
# Visualize the rays in the ray view.
ray_view = draw_points(
context_images[rb, rv],
sampling.xy_ray[rb, rv, rr],
0,
radius=4,
x_range=(0, 1),
y_range=(0, 1),
)
ray_view = draw_points(
ray_view,
sampling.xy_ray[rb, rv, rr],
[get_distinct_color(i) for i, _ in enumerate(rr)],
radius=3,
x_range=(0, 1),
y_range=(0, 1),
)
# Visualize attention in the sample view.
attention = rearrange(
attention, "l (b v r) hd () s -> l b v r hd s", b=b, v=v, r=r
)
attention = attention[:, rb, rv, rr, :, :]
num_layers, _, hd, _ = attention.shape
vis = []
for il in range(num_layers):
vis_layer = []
for ihd in range(hd):
# Create colors according to attention.
color = [get_distinct_color(i) for i, _ in enumerate(rr)]
color = torch.tensor(color, device=attention.device)
color = rearrange(color, "r c -> r () c")
attn = rearrange(attention[il, :, ihd], "r s -> r s ()")
color = rearrange(attn * color, "r s c -> (r s ) c")
# Draw the alternating bucket lines.
vis_layer_head = draw_lines(
context_images[rb, self.encoder.sampler.index_v[rv, rov]],
rearrange(
sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"
),
rearrange(
sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"
),
color,
3,
cap="butt",
x_range=(0, 1),
y_range=(0, 1),
)
vis_layer.append(vis_layer_head)
vis.append(add_label(vcat(*vis_layer), f"Layer {il}"))
vis = add_label(add_border(add_border(hcat(*vis)), 1, 0), "Keys & Values")
vis = add_border(hcat(add_label(ray_view), vis, align="top"))
return vis
def visualize_depth(
self,
context: BatchedViews,
multi_depth: Float[Tensor, "batch view height width surface spp"],
) -> Float[Tensor, "3 vis_width vis_height"]:
multi_vis = []
*_, srf, _ = multi_depth.shape
for i in range(srf):
depth = multi_depth[..., i, :]
depth = depth.mean(dim=-1)
# Compute relative depth and disparity.
near = rearrange(context["near"], "b v -> b v () ()")
far = rearrange(context["far"], "b v -> b v () ()")
relative_depth = (depth - near) / (far - near)
relative_disparity = 1 - (1 / depth - 1 / far) / (1 / near - 1 / far)
relative_depth = apply_color_map_to_image(relative_depth, "turbo")
relative_depth = vcat(*[hcat(*x) for x in relative_depth])
relative_depth = add_label(relative_depth, "Depth")
relative_disparity = apply_color_map_to_image(relative_disparity, "turbo")
relative_disparity = vcat(*[hcat(*x) for x in relative_disparity])
relative_disparity = add_label(relative_disparity, "Disparity")
multi_vis.append(add_border(hcat(relative_depth, relative_disparity)))
return add_border(vcat(*multi_vis))
def visualize_overlaps(
self,
context_images: Float[Tensor, "batch view 3 height width"],
sampling: EpipolarSampling,
is_monocular: Optional[Bool[Tensor, "batch view height width"]] = None,
) -> Float[Tensor, "3 vis_width vis_height"]:
device = context_images.device
b, v, _, h, w = context_images.shape
green = torch.tensor([0.235, 0.706, 0.294], device=device)[..., None, None]
rb = randrange(b)
valid = sampling.valid[rb].float()
ds = self.encoder.cfg.epipolar_transformer.downscale
valid = repeat(
valid,
"v ov (h w) -> v ov c (h rh) (w rw)",
c=3,
h=h // ds,
w=w // ds,
rh=ds,
rw=ds,
)
if is_monocular is not None:
is_monocular = is_monocular[rb].float()
is_monocular = repeat(is_monocular, "v h w -> v c h w", c=3, h=h, w=w)
# Select context images in grid.
context_images = context_images[rb]
index, _ = generate_heterogeneous_index(v)
valid = valid * (green + context_images[index]) / 2
vis = vcat(*(hcat(im, hcat(*v)) for im, v in zip(context_images, valid)))
vis = add_label(vis, "Context Overlaps")
if is_monocular is not None:
vis = hcat(vis, add_label(vcat(*is_monocular), "Monocular?"))
return add_border(vis)
def visualize_gaussians(
self,
context_images: Float[Tensor, "batch view 3 height width"],
opacities: Float[Tensor, "batch vrspp"],
covariances: Float[Tensor, "batch vrspp 3 3"],
colors: Float[Tensor, "batch vrspp 3"],
) -> Float[Tensor, "3 vis_height vis_width"]:
b, v, _, h, w = context_images.shape
rb = randrange(b)
context_images = context_images[rb]
opacities = repeat(
opacities[rb], "(v h w spp) -> spp v c h w", v=v, c=3, h=h, w=w
)
colors = rearrange(colors[rb], "(v h w spp) c -> spp v c h w", v=v, h=h, w=w)
# Color-map Gaussian covariawnces.
det = covariances[rb].det()
det = apply_color_map(det / det.max(), "inferno")
det = rearrange(det, "(v h w spp) c -> spp v c h w", v=v, h=h, w=w)
return add_border(
hcat(
add_label(box(hcat(*context_images)), "Context"),
add_label(box(vcat(*[hcat(*x) for x in opacities])), "Opacities"),
add_label(
box(vcat(*[hcat(*x) for x in (colors * opacities)])), "Colors"
),
add_label(box(vcat(*[hcat(*x) for x in colors])), "Colors (Raw)"),
add_label(box(vcat(*[hcat(*x) for x in det])), "Determinant"),
)
)
def visualize_probabilities(
self,
context_images: Float[Tensor, "batch view 3 height width"],
sampling: EpipolarSampling,
pdf: Float[Tensor, "batch view ray sample"],
) -> Float[Tensor, "3 vis_height vis_width"]:
device = context_images.device
# Pick a random batch element, view, and other view.
b, v, ov, r, _, _ = sampling.xy_sample.shape
rb = randrange(b)
rv = randrange(v)
rov = randrange(ov)
num_samples = self.cfg.num_samples
rr = np.random.choice(r, num_samples, replace=False)
rr = torch.tensor(rr, dtype=torch.int64, device=device)
colors = [get_distinct_color(i) for i, _ in enumerate(rr)]
colors = torch.tensor(colors, dtype=torch.float32, device=device)
# Visualize the rays in the ray view.
ray_view = draw_points(
context_images[rb, rv],
sampling.xy_ray[rb, rv, rr],
0,
radius=4,
x_range=(0, 1),
y_range=(0, 1),
)
ray_view = draw_points(
ray_view,
sampling.xy_ray[rb, rv, rr],
colors,
radius=3,
x_range=(0, 1),
y_range=(0, 1),
)
# Visualize probabilities in the sample view.
pdf = pdf[rb, rv, rr]
pdf = rearrange(pdf, "r s -> r s ()")
colors = rearrange(colors, "r c -> r () c")
sample_view = draw_lines(
context_images[rb, self.encoder.sampler.index_v[rv, rov]],
rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
rearrange(pdf * colors, "r s c -> (r s) c"),
6,
cap="butt",
x_range=(0, 1),
y_range=(0, 1),
)
# Visualize rescaled probabilities in the sample view.
pdf_magnified = pdf / reduce(pdf, "r s () -> r () ()", "max")
sample_view_magnified = draw_lines(
context_images[rb, self.encoder.sampler.index_v[rv, rov]],
rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
rearrange(pdf_magnified * colors, "r s c -> (r s) c"),
6,
cap="butt",
x_range=(0, 1),
y_range=(0, 1),
)
return add_border(
hcat(
add_label(ray_view, "Rays"),
add_label(sample_view, "Samples"),
add_label(sample_view_magnified, "Samples (Magnified PDF)"),
)
)
def visualize_epipolar_samples(
self,
context_images: Float[Tensor, "batch view 3 height width"],
sampling: EpipolarSampling,
) -> Float[Tensor, "3 vis_height vis_width"]:
device = context_images.device
# Pick a random batch element, view, and other view.
b, v, ov, r, s, _ = sampling.xy_sample.shape
rb = randrange(b)
rv = randrange(v)
rov = randrange(ov)
num_samples = self.cfg.num_samples
rr = np.random.choice(r, num_samples, replace=False)
rr = torch.tensor(rr, dtype=torch.int64, device=device)
# Visualize the rays in the ray view.
ray_view = draw_points(
context_images[rb, rv],
sampling.xy_ray[rb, rv, rr],
0,
radius=4,
x_range=(0, 1),
y_range=(0, 1),
)
ray_view = draw_points(
ray_view,
sampling.xy_ray[rb, rv, rr],
[get_distinct_color(i) for i, _ in enumerate(rr)],
radius=3,
x_range=(0, 1),
y_range=(0, 1),
)
# Visualize the samples and epipolar lines in the sample view.
# First, draw the epipolar line in black.
sample_view = draw_lines(
context_images[rb, self.encoder.sampler.index_v[rv, rov]],
sampling.xy_sample_near[rb, rv, rov, rr, 0],
sampling.xy_sample_far[rb, rv, rov, rr, -1],
0,
5,
cap="butt",
x_range=(0, 1),
y_range=(0, 1),
)
# Create an alternating line color for the buckets.
color = repeat(
torch.tensor([0, 1], device=device),
"ab -> r (s ab) c",
r=len(rr),
s=(s + 1) // 2,
c=3,
)
color = rearrange(color[:, :s], "r s c -> (r s) c")
# Draw the alternating bucket lines.
sample_view = draw_lines(
sample_view,
rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
color,
3,
cap="butt",
x_range=(0, 1),
y_range=(0, 1),
)
# Draw the sample points.
sample_view = draw_points(
sample_view,
rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
0,
radius=4,
x_range=(0, 1),
y_range=(0, 1),
)
sample_view = draw_points(
sample_view,
rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
[get_distinct_color(i // s) for i in range(s * len(rr))],
radius=3,
x_range=(0, 1),
y_range=(0, 1),
)
return add_border(
hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
)
def visualize_epipolar_color_samples(
self,
context_images: Float[Tensor, "batch view 3 height width"],
context: BatchedViews,
) -> Float[Tensor, "3 vis_height vis_width"]:
device = context_images.device
sampling = self.encoder.sampler(
context["image"],
context["extrinsics"],
context["intrinsics"],
context["near"],
context["far"],
)
# Pick a random batch element, view, and other view.
b, v, ov, r, s, _ = sampling.xy_sample.shape
rb = randrange(b)
rv = randrange(v)
rov = randrange(ov)
num_samples = self.cfg.num_samples
rr = np.random.choice(r, num_samples, replace=False)
rr = torch.tensor(rr, dtype=torch.int64, device=device)
# Visualize the rays in the ray view.
ray_view = draw_points(
context_images[rb, rv],
sampling.xy_ray[rb, rv, rr],
0,
radius=4,
x_range=(0, 1),
y_range=(0, 1),
)
ray_view = draw_points(
ray_view,
sampling.xy_ray[rb, rv, rr],
[get_distinct_color(i) for i, _ in enumerate(rr)],
radius=3,
x_range=(0, 1),
y_range=(0, 1),
)
# Visualize the samples and in the sample view.
sample_view = draw_points(
context_images[rb, self.encoder.sampler.index_v[rv, rov]],
rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
[get_distinct_color(i // s) for i in range(s * len(rr))],
radius=4,
x_range=(0, 1),
y_range=(0, 1),
)
sample_view = draw_points(
sample_view,
rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
rearrange(sampling.features[rb, rv, rov, rr], "r s c -> (r s) c"),
radius=3,
x_range=(0, 1),
y_range=(0, 1),
)
return add_border(
hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
)