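# SceneDINO demo (Gradio app on ZeroGPU): renders a 2D feature/segmentation
# visualization and a voxelized 3D scene reconstruction from a single image.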
import os
import sys
import tempfile

import gradio as gr
import numpy as np
import open3d as o3d
import spaces
import torch
import yaml

sys.path.append("./sscbench")

from demo_utils.utils import (load_modules,
                              load_sample_from_path,
                              load_sample_from_dataset,
                              get_fov_mask,
                              inference_3d,
                              inference_rendered_2d)
from download_checkpoint_hf import download_scenedino_checkpoint
from sscbench.gen_voxelgrid_npy import save_as_voxel_ply, classes_to_colors

# Load checkpoints from Hugging Face
download_scenedino_checkpoint("ssc-kitti-360-dino")
download_scenedino_checkpoint("ssc-kitti-360-dinov2")
# Load model, ray sampler, datasets
ckpt_path = "out/scenedino-pretrained/seg-best-dino/"
ckpt_name = "checkpoint.pt"
net_v1, renderer_v1, ray_sampler_v1, test_dataset = load_modules(ckpt_path, ckpt_name)
renderer_v1.eval()
ckpt_path = "out/scenedino-pretrained/seg-best-dinov2/"
ckpt_name = "checkpoint.pt"
net_v2, renderer_v2, ray_sampler_v2, _ = load_modules(ckpt_path, ckpt_name)
renderer_v2.eval()
def convert_voxels(arr, map_dict):
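    # Element-wise lookup: map every class ID in arr through map_dict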
    f = np.vectorize(map_dict.__getitem__)
    return f(arr)

with open("sscbench/label_maps.yaml", "r") as f:
    label_maps = yaml.safe_load(f)

@spaces.GPU(duration=60)
def demo_run(image: str,
             backbone: str,
             mode: str,
             sigma_threshold: float,
             resolution: float,
             x_range: int,
             y_range: int,
             z_range: int):
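    """Run SceneDINO on a single image and return the rendered 2D
    visualization plus the path to a GLB voxel mesh for the 3D viewer."""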
    if backbone == "DINO (ViT-B)":
        net, renderer, ray_sampler = net_v1, renderer_v1, ray_sampler_v1
    elif backbone == "DINOv2 (ViT-B)":
        net, renderer, ray_sampler = net_v2, renderer_v2, ray_sampler_v2
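
    # Map the UI mode to visualization / prediction settings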
    prediction_mode = "stego_kmeans"
    if mode == "Feature PCA 1-3":
        segmentation = False
        rgb_from_pca_dim = 0
    elif mode == "Feature PCA 4-6":
        segmentation = False
        rgb_from_pca_dim = 3
    elif mode == "Feature PCA 7-9":
        segmentation = False
        rgb_from_pca_dim = 6
    elif mode == "SSC (unsup.)":
        segmentation = True
    elif mode == "SSC (linear)":
        segmentation = True
        prediction_mode = "direct_linear"
    # Cast from str: values arrive as strings when populated from the examples
    sigma_threshold, resolution = float(sigma_threshold), float(resolution)
    x_range, y_range, z_range = int(x_range), int(y_range), int(z_range)
    # Reject settings that would produce too many voxels
    max_voxel_count = 5_000_000
    voxel_count = (x_range//resolution + 1) * (y_range//resolution + 1) * (z_range//resolution + 1)
    if voxel_count > max_voxel_count:
        raise gr.Error(f"Too many voxels ({int(voxel_count) / 1_000_000:.1f}M > {max_voxel_count / 1_000_000:.1f}M).\n" +
                       "Reduce voxel resolution or range.", duration=5)
    with torch.no_grad():
        images, poses, projs = load_sample_from_path(image, intrinsic=None)
        net.encode(images, projs, poses, ids_encoder=[0])
        net.set_scale(0)
        # 2D features output
        dino_full_2d, depth_2d, seg_2d = inference_rendered_2d(net, poses, projs, ray_sampler, renderer, prediction_mode)
        net.encoder.fit_visualization(dino_full_2d.flatten(0, -2))

        if segmentation:
            output_2d = convert_voxels(seg_2d.detach().cpu(), label_maps["cityscapes_to_label"])
            output_2d = classes_to_colors[output_2d].cpu().detach().numpy()
        else:
            output_2d = net.encoder.transform_visualization(dino_full_2d, from_dim=rgb_from_pca_dim)
            output_2d -= output_2d.min()
            output_2d /= output_2d.max()
            output_2d = output_2d.cpu().detach().numpy()
        # Chunking: bound the number of voxels queried per 3D inference call
        max_chunk_size = 100_000
        z_layers_per_chunk = max_chunk_size // ((x_range//resolution + 1) * (y_range//resolution + 1))
        # 3D features output: scalar ranges become (min, max) world-space extents
        x_range = (-x_range/2, x_range)
        y_range = (-y_range/2, y_range)
        z_range = (0, z_range)
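
        # Sweep the voxel grid along z in chunks and accumulate per-chunk results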
        is_occupied, output_3d, fov_mask = [], [], []
        current_z = 0
        while current_z <= z_range[1]:
            z_range_chunk = (current_z, min(current_z + z_layers_per_chunk*resolution, z_range[1]))
            current_z += (z_layers_per_chunk+1) * resolution

            xyz_chunk, dino_full_3d_chunk, sigma_3d_chunk, seg_3d_chunk = inference_3d(net, x_range, y_range, z_range_chunk, resolution, prediction_mode)
            fov_mask_chunk = get_fov_mask(projs[0, 0], xyz_chunk)

            is_occupied_chunk = sigma_3d_chunk > sigma_threshold

            if segmentation:
                output_3d_chunk = seg_3d_chunk
            else:
                output_3d_chunk = net.encoder.transform_visualization(dino_full_3d_chunk, from_dim=rgb_from_pca_dim)
                output_3d_chunk -= output_3d_chunk.min()
                output_3d_chunk /= output_3d_chunk.max()
                output_3d_chunk = torch.clamp(output_3d_chunk*1.2 - 0.1, 0.0, 1.0)
                output_3d_chunk = (255*output_3d_chunk).int()

            fov_mask_chunk = fov_mask_chunk.reshape(is_occupied_chunk.shape)

            is_occupied.append(is_occupied_chunk)
            output_3d.append(output_3d_chunk)
            fov_mask.append(fov_mask_chunk)
        is_occupied = torch.cat(is_occupied, dim=2)
        output_3d = torch.cat(output_3d, dim=2)
        fov_mask = torch.cat(fov_mask, dim=2)
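
    # Export the occupied voxels as a PLY mesh in the system temp dir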
    temp_dir = tempfile.gettempdir()
    ply_path = os.path.join(temp_dir, "output.ply")
    if segmentation:
        # Classes 10 and 12 map to "unlabeled"; hide those voxels
        is_occupied[output_3d == 10] = 0
        is_occupied[output_3d == 12] = 0
        save_as_voxel_ply(ply_path,
                          is_occupied.detach().cpu(),
                          voxel_size=resolution,
                          size=is_occupied.size(),
                          classes=torch.Tensor(
                              convert_voxels(
                                  output_3d.detach().cpu(),
                                  label_maps["cityscapes_to_label"])),
                          fov_mask=fov_mask)
    else:
        save_as_voxel_ply(ply_path,
                          is_occupied.detach().cpu(),
                          voxel_size=resolution,
                          size=is_occupied.size(),
                          colors=output_3d.detach().cpu(),
                          fov_mask=fov_mask)
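
    # Convert the PLY voxel mesh to GLB for the gr.Model3D viewer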
    mesh = o3d.io.read_triangle_mesh(ply_path)
    glb_path = os.path.join(temp_dir, "output.glb")
    o3d.io.write_triangle_mesh(glb_path, mesh, write_ascii=True)
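
    # Release GPU memory before the ZeroGPU worker is reclaimed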
    del dino_full_2d, depth_2d, seg_2d
    del dino_full_3d_chunk, sigma_3d_chunk, seg_3d_chunk, is_occupied_chunk
    del is_occupied, output_3d, fov_mask
    torch.cuda.empty_cache()

    return output_2d, glb_path
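
# Gradio UI: the input widgets mirror demo_run's parameters in order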
demo = gr.Interface(
    demo_run,
    inputs=[
        gr.Image(label="Input image", type="filepath"),
        gr.Radio(label="Backbone", choices=["DINO (ViT-B)", "DINOv2 (ViT-B)"]),
        gr.Radio(label="Mode", choices=["Feature PCA 1-3", "Feature PCA 4-6", "Feature PCA 7-9", "SSC (unsup.)", "SSC (linear)"]),
        gr.Slider(label="Density threshold", minimum=0, maximum=1, step=0.05, value=0.2),
        gr.Slider(label="Resolution [m]", minimum=0.05, maximum=0.5, step=0.1, value=0.2),
        gr.Slider(label="X Range [m]", minimum=1, maximum=50, step=1, value=10),
        gr.Slider(label="Y Range [m]", minimum=1, maximum=50, step=1, value=10),
        gr.Slider(label="Z Range [m]", minimum=1, maximum=100, step=1, value=20),
    ],
    outputs=[
        gr.Image(label="Rendered 2D Visualization"),
        gr.Model3D(label="Voxel Surface 3D Visualization",
                   zoom_speed=0.5, pan_speed=0.5,
                   clear_color=[0.0, 0.0, 0.0, 0.0],
                   camera_position=[-90, 80, None],
                   display_mode="solid"),
    ],
    title="SceneDINO Demo",
    examples="demo_utils/examples",
)
demo.launch()