from pathlib import Path
from random import randrange
from typing import Optional

import numpy as np
import torch
import wandb
from einops import rearrange, reduce, repeat
from jaxtyping import Bool, Float
from torch import Tensor

from ....dataset.types import BatchedViews
from ....misc.heterogeneous_pairings import generate_heterogeneous_index
from ....visualization.annotation import add_label
from ....visualization.color_map import apply_color_map, apply_color_map_to_image
from ....visualization.colors import get_distinct_color
from ....visualization.drawing.lines import draw_lines
from ....visualization.drawing.points import draw_points
from ....visualization.layout import add_border, hcat, vcat
from ...ply_export import export_ply

# NOTE: `EncoderEpipolar` and `EpipolarSampling` are referenced below but were not
# imported; the module paths here are assumed from the surrounding package layout.
from ..encoder_epipolar import EncoderEpipolar
from ..epipolar.epipolar_sampler import EpipolarSampling
from .encoder_visualizer import EncoderVisualizer
from .encoder_visualizer_epipolar_cfg import EncoderVisualizerEpipolarCfg


def box(
    image: Float[Tensor, "3 height width"],
) -> Float[Tensor, "3 new_height new_width"]:
    return add_border(add_border(image), 1, 0)


class EncoderVisualizerEpipolar(
    EncoderVisualizer[EncoderVisualizerEpipolarCfg, EncoderEpipolar]
):
    def visualize(
        self,
        context: BatchedViews,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        # Short-circuit execution when ablating the epipolar transformer.
        if self.encoder.epipolar_transformer is None:
            return {}

        visualization_dump = {}

        softmax_weights = []

        def hook(module, input, output):
            softmax_weights.append(output)

        # Register hooks to grab attention.
        handles = [
            layer[0].fn.attend.register_forward_hook(hook)
            for layer in self.encoder.epipolar_transformer.transformer.layers
        ]

        result = self.encoder.forward(
            context,
            global_step,
            visualization_dump=visualization_dump,
            deterministic=True,
        )

        # De-register hooks.
        for handle in handles:
            handle.remove()

        softmax_weights = torch.stack(softmax_weights)

        # Generate high-resolution context images that can be drawn on.
        context_images = context["image"]
        _, _, _, h, w = context_images.shape
        length = min(h, w)
        min_resolution = self.cfg.min_resolution
        scale_multiplier = (min_resolution + length - 1) // length
        if scale_multiplier > 1:
            context_images = repeat(
                context_images,
                "b v c h w -> b v c (h rh) (w rw)",
                rh=scale_multiplier,
                rw=scale_multiplier,
            )

        # This is kind of hacky for now, since we're using it for short experiments.
        if self.cfg.export_ply and wandb.run is not None:
            name = wandb.run._name.split(" ")[0]
            ply_path = Path(f"outputs/gaussians/{name}/{global_step:0>6}.ply")
            export_ply(
                context["extrinsics"][0, 0],
                result.means[0],
                visualization_dump["scales"][0],
                visualization_dump["rotations"][0],
                result.harmonics[0],
                result.opacities[0],
                ply_path,
            )

        return {
            # "attention": self.visualize_attention(
            #     context_images,
            #     visualization_dump["sampling"],
            #     softmax_weights,
            # ),
            "epipolar_samples": self.visualize_epipolar_samples(
                context_images,
                visualization_dump["sampling"],
            ),
            "epipolar_color_samples": self.visualize_epipolar_color_samples(
                context_images,
                context,
            ),
            "gaussians": self.visualize_gaussians(
                context["image"],
                result.opacities,
                result.covariances,
                result.harmonics[..., 0],  # Just visualize DC component.
), "overlaps": self.visualize_overlaps( context["image"], visualization_dump["sampling"], visualization_dump.get("is_monocular", None), ), "depth": self.visualize_depth( context, visualization_dump["depth"], ), } def visualize_attention( self, context_images: Float[Tensor, "batch view 3 height width"], sampling: EpipolarSampling, attention: Float[Tensor, "layer bvr head 1 sample"], ) -> Float[Tensor, "3 vis_height vis_width"]: device = context_images.device # Pick a random batch element, view, and other view. b, v, ov, r, s, _ = sampling.xy_sample.shape rb = randrange(b) rv = randrange(v) rov = randrange(ov) num_samples = self.cfg.num_samples rr = np.random.choice(r, num_samples, replace=False) rr = torch.tensor(rr, dtype=torch.int64, device=device) # Visualize the rays in the ray view. ray_view = draw_points( context_images[rb, rv], sampling.xy_ray[rb, rv, rr], 0, radius=4, x_range=(0, 1), y_range=(0, 1), ) ray_view = draw_points( ray_view, sampling.xy_ray[rb, rv, rr], [get_distinct_color(i) for i, _ in enumerate(rr)], radius=3, x_range=(0, 1), y_range=(0, 1), ) # Visualize attention in the sample view. attention = rearrange( attention, "l (b v r) hd () s -> l b v r hd s", b=b, v=v, r=r ) attention = attention[:, rb, rv, rr, :, :] num_layers, _, hd, _ = attention.shape vis = [] for il in range(num_layers): vis_layer = [] for ihd in range(hd): # Create colors according to attention. color = [get_distinct_color(i) for i, _ in enumerate(rr)] color = torch.tensor(color, device=attention.device) color = rearrange(color, "r c -> r () c") attn = rearrange(attention[il, :, ihd], "r s -> r s ()") color = rearrange(attn * color, "r s c -> (r s ) c") # Draw the alternating bucket lines. vis_layer_head = draw_lines( context_images[rb, self.encoder.sampler.index_v[rv, rov]], rearrange( sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy" ), rearrange( sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy" ), color, 3, cap="butt", x_range=(0, 1), y_range=(0, 1), ) vis_layer.append(vis_layer_head) vis.append(add_label(vcat(*vis_layer), f"Layer {il}")) vis = add_label(add_border(add_border(hcat(*vis)), 1, 0), "Keys & Values") vis = add_border(hcat(add_label(ray_view), vis, align="top")) return vis def visualize_depth( self, context: BatchedViews, multi_depth: Float[Tensor, "batch view height width surface spp"], ) -> Float[Tensor, "3 vis_width vis_height"]: multi_vis = [] *_, srf, _ = multi_depth.shape for i in range(srf): depth = multi_depth[..., i, :] depth = depth.mean(dim=-1) # Compute relative depth and disparity. 
near = rearrange(context["near"], "b v -> b v () ()") far = rearrange(context["far"], "b v -> b v () ()") relative_depth = (depth - near) / (far - near) relative_disparity = 1 - (1 / depth - 1 / far) / (1 / near - 1 / far) relative_depth = apply_color_map_to_image(relative_depth, "turbo") relative_depth = vcat(*[hcat(*x) for x in relative_depth]) relative_depth = add_label(relative_depth, "Depth") relative_disparity = apply_color_map_to_image(relative_disparity, "turbo") relative_disparity = vcat(*[hcat(*x) for x in relative_disparity]) relative_disparity = add_label(relative_disparity, "Disparity") multi_vis.append(add_border(hcat(relative_depth, relative_disparity))) return add_border(vcat(*multi_vis)) def visualize_overlaps( self, context_images: Float[Tensor, "batch view 3 height width"], sampling: EpipolarSampling, is_monocular: Optional[Bool[Tensor, "batch view height width"]] = None, ) -> Float[Tensor, "3 vis_width vis_height"]: device = context_images.device b, v, _, h, w = context_images.shape green = torch.tensor([0.235, 0.706, 0.294], device=device)[..., None, None] rb = randrange(b) valid = sampling.valid[rb].float() ds = self.encoder.cfg.epipolar_transformer.downscale valid = repeat( valid, "v ov (h w) -> v ov c (h rh) (w rw)", c=3, h=h // ds, w=w // ds, rh=ds, rw=ds, ) if is_monocular is not None: is_monocular = is_monocular[rb].float() is_monocular = repeat(is_monocular, "v h w -> v c h w", c=3, h=h, w=w) # Select context images in grid. context_images = context_images[rb] index, _ = generate_heterogeneous_index(v) valid = valid * (green + context_images[index]) / 2 vis = vcat(*(hcat(im, hcat(*v)) for im, v in zip(context_images, valid))) vis = add_label(vis, "Context Overlaps") if is_monocular is not None: vis = hcat(vis, add_label(vcat(*is_monocular), "Monocular?")) return add_border(vis) def visualize_gaussians( self, context_images: Float[Tensor, "batch view 3 height width"], opacities: Float[Tensor, "batch vrspp"], covariances: Float[Tensor, "batch vrspp 3 3"], colors: Float[Tensor, "batch vrspp 3"], ) -> Float[Tensor, "3 vis_height vis_width"]: b, v, _, h, w = context_images.shape rb = randrange(b) context_images = context_images[rb] opacities = repeat( opacities[rb], "(v h w spp) -> spp v c h w", v=v, c=3, h=h, w=w ) colors = rearrange(colors[rb], "(v h w spp) c -> spp v c h w", v=v, h=h, w=w) # Color-map Gaussian covariawnces. det = covariances[rb].det() det = apply_color_map(det / det.max(), "inferno") det = rearrange(det, "(v h w spp) c -> spp v c h w", v=v, h=h, w=w) return add_border( hcat( add_label(box(hcat(*context_images)), "Context"), add_label(box(vcat(*[hcat(*x) for x in opacities])), "Opacities"), add_label( box(vcat(*[hcat(*x) for x in (colors * opacities)])), "Colors" ), add_label(box(vcat(*[hcat(*x) for x in colors])), "Colors (Raw)"), add_label(box(vcat(*[hcat(*x) for x in det])), "Determinant"), ) ) def visualize_probabilities( self, context_images: Float[Tensor, "batch view 3 height width"], sampling: EpipolarSampling, pdf: Float[Tensor, "batch view ray sample"], ) -> Float[Tensor, "3 vis_height vis_width"]: device = context_images.device # Pick a random batch element, view, and other view. 
        b, v, ov, r, _, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)

        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)

        colors = [get_distinct_color(i) for i, _ in enumerate(rr)]
        colors = torch.tensor(colors, dtype=torch.float32, device=device)

        # Visualize the rays in the ray view.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            colors,
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize probabilities in the sample view.
        pdf = pdf[rb, rv, rr]
        pdf = rearrange(pdf, "r s -> r s ()")
        colors = rearrange(colors, "r c -> r () c")
        sample_view = draw_lines(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(pdf * colors, "r s c -> (r s) c"),
            6,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize rescaled probabilities in the sample view.
        pdf_magnified = pdf / reduce(pdf, "r s () -> r () ()", "max")
        sample_view_magnified = draw_lines(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(pdf_magnified * colors, "r s c -> (r s) c"),
            6,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        return add_border(
            hcat(
                add_label(ray_view, "Rays"),
                add_label(sample_view, "Samples"),
                add_label(sample_view_magnified, "Samples (Magnified PDF)"),
            )
        )

    def visualize_epipolar_samples(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        sampling: EpipolarSampling,
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        device = context_images.device

        # Pick a random batch element, view, and other view.
        b, v, ov, r, s, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)

        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)

        # Visualize the rays in the ray view.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            [get_distinct_color(i) for i, _ in enumerate(rr)],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize the samples and epipolar lines in the sample view.
        # First, draw the epipolar line in black.
        sample_view = draw_lines(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            sampling.xy_sample_near[rb, rv, rov, rr, 0],
            sampling.xy_sample_far[rb, rv, rov, rr, -1],
            0,
            5,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Create an alternating line color for the buckets.
        color = repeat(
            torch.tensor([0, 1], device=device),
            "ab -> r (s ab) c",
            r=len(rr),
            s=(s + 1) // 2,
            c=3,
        )
        color = rearrange(color[:, :s], "r s c -> (r s) c")

        # Draw the alternating bucket lines.
        sample_view = draw_lines(
            sample_view,
            rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            color,
            3,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Draw the sample points.
        sample_view = draw_points(
            sample_view,
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        sample_view = draw_points(
            sample_view,
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            [get_distinct_color(i // s) for i in range(s * len(rr))],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        return add_border(
            hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
        )

    def visualize_epipolar_color_samples(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        context: BatchedViews,
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        device = context_images.device
        sampling = self.encoder.sampler(
            context["image"],
            context["extrinsics"],
            context["intrinsics"],
            context["near"],
            context["far"],
        )

        # Pick a random batch element, view, and other view.
        b, v, ov, r, s, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)

        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)

        # Visualize the rays in the ray view.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            [get_distinct_color(i) for i, _ in enumerate(rr)],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize the samples in the sample view.
        sample_view = draw_points(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            [get_distinct_color(i // s) for i in range(s * len(rr))],
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        sample_view = draw_points(
            sample_view,
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.features[rb, rv, rov, rr], "r s c -> (r s) c"),
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        return add_border(
            hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
        )
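
# Illustrative usage (a hedged sketch, not part of this module): a training
# wrapper is expected to construct the visualizer with its config and encoder,
# call `visualize` on a context batch at the current step, and log the returned
# dict of [3, H, W] image tensors. The names `cfg.visualizer`, `batch`, `logger`,
# and `prep_image` below are assumptions for the sake of the example.
#
#     visualizer = EncoderVisualizerEpipolar(cfg.visualizer, encoder)
#     images = visualizer.visualize(batch["context"], global_step)
#     for name, image in images.items():
#         logger.log_image(name, [prep_image(image)], step=global_step)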