from demo_utils.utils import (load_modules,
                              load_sample_from_path,
                              get_fov_mask,
                              inference_3d,
                              inference_rendered_2d)

import tempfile
import os
import sys
import yaml

sys.path.append("./sscbench")
from sscbench.gen_voxelgrid_npy import save_as_voxel_ply, classes_to_colors
from download_checkpoint_hf import download_scenedino_checkpoint

import torch

import numpy as np
import gradio as gr
import open3d as o3d
import spaces


# Load checkpoints from Hugging Face
download_scenedino_checkpoint("ssc-kitti-360-dino")
download_scenedino_checkpoint("ssc-kitti-360-dinov2")

# Load both model variants (DINO and DINOv2), their ray samplers, and the test dataset
ckpt_path = "out/scenedino-pretrained/seg-best-dino/"
ckpt_name = "checkpoint.pt"
net_v1, renderer_v1, ray_sampler_v1, test_dataset = load_modules(ckpt_path, ckpt_name)
renderer_v1.eval()

ckpt_path = "out/scenedino-pretrained/seg-best-dinov2/"
ckpt_name = "checkpoint.pt"
net_v2, renderer_v2, ray_sampler_v2, _ = load_modules(ckpt_path, ckpt_name)
renderer_v2.eval()


def convert_voxels(arr, map_dict):
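    """Remap class IDs element-wise using the given lookup dict."""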
    f = np.vectorize(map_dict.__getitem__)
    return f(arr)

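# Class-ID lookup tables (e.g. "cityscapes_to_label") shared by the 2D and 3D outputs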
with open("sscbench/label_maps.yaml", "r") as f:
    label_maps = yaml.safe_load(f)


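# ZeroGPU: borrow a GPU for up to 60 s per request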
@spaces.GPU(duration=60)
def demo_run(image: str, 
             backbone: str,
             mode: str,
             sigma_threshold: float, 
             resolution: float, 
             x_range: int, 
             y_range: int, 
             z_range: int):

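    # Pick the model stack matching the chosen backbone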
    if backbone == "DINO (ViT-B)":
        net, renderer, ray_sampler = net_v1, renderer_v1, ray_sampler_v1
    elif backbone == "DINOv2 (ViT-B)":
        net, renderer, ray_sampler = net_v2, renderer_v2, ray_sampler_v2
    else:
        raise gr.Error("Please select a backbone.", duration=5)

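    # Map the UI mode to a visualization: PCA feature slices or SSC segmentation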
    prediction_mode = "stego_kmeans"
    if mode == "Feature PCA 1-3":
        segmentation = False
        rgb_from_pca_dim = 0
    elif mode == "Feature PCA 4-6":
        segmentation = False
        rgb_from_pca_dim = 3
    elif mode == "Feature PCA 7-9":
        segmentation = False
        rgb_from_pca_dim = 6
    elif mode == "SSC (unsup.)":
        segmentation = True
    elif mode == "SSC (linear)":
        segmentation = True
        prediction_mode = "direct_linear"

    # Gradio examples may pass values as strings, so cast numeric fields explicitly
    sigma_threshold, resolution = float(sigma_threshold), float(resolution)
    x_range, y_range, z_range = int(x_range), int(y_range), int(z_range)

    # Reject requests whose voxel grid would be too large to process
    max_voxel_count = 5_000_000
    voxel_count = (x_range//resolution + 1) * (y_range//resolution + 1) * (z_range//resolution + 1)
    if voxel_count > max_voxel_count:
        raise gr.Error(f"Too many voxels ({int(voxel_count) / 1_000_000:.1f}M > {max_voxel_count / 1_000_000:.1f}M).\n" +
                        "Reduce voxel resolution or range.", duration=5)

    with torch.no_grad():
        images, poses, projs = load_sample_from_path(image, intrinsic=None)

        net.encode(images, projs, poses, ids_encoder=[0])
        net.set_scale(0)

        # 2D Features output
        dino_full_2d, depth_2d, seg_2d = inference_rendered_2d(net, poses, projs, ray_sampler, renderer, prediction_mode)
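        # Fit the PCA basis (used by transform_visualization) on the rendered 2D features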
        net.encoder.fit_visualization(dino_full_2d.flatten(0, -2))

        if segmentation:
            output_2d = convert_voxels(seg_2d.detach().cpu(), label_maps["cityscapes_to_label"])
            output_2d = classes_to_colors[output_2d].cpu().detach().numpy()
        else:
            output_2d = net.encoder.transform_visualization(dino_full_2d, from_dim=rgb_from_pca_dim)
            output_2d -= output_2d.min()
            output_2d /= output_2d.max()
            output_2d = output_2d.cpu().detach().numpy()

        # Run 3D inference in z-chunks so each call stays within the memory budget
        max_chunk_size = 100000
        z_layers_per_chunk = max_chunk_size // ((x_range//resolution + 1) * (y_range//resolution + 1))

        # 3D Features output: x/y ranges centered on the camera, z extending forward
        x_range = (-x_range/2, x_range/2)
        y_range = (-y_range/2, y_range/2)
        z_range = (0, z_range)

        is_occupied, output_3d, fov_mask = [], [], []
        current_z = 0

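        # Sweep the volume front-to-back, one z-chunk per inference call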
        while current_z <= z_range[1]:
            z_range_chunk = (current_z, min(current_z + z_layers_per_chunk*resolution, z_range[1]))
            current_z += (z_layers_per_chunk+1) * resolution

            xyz_chunk, dino_full_3d_chunk, sigma_3d_chunk, seg_3d_chunk = inference_3d(net, x_range, y_range, z_range_chunk, resolution, prediction_mode)
            fov_mask_chunk = get_fov_mask(projs[0, 0], xyz_chunk)

            is_occupied_chunk = sigma_3d_chunk > sigma_threshold

            if segmentation:
                output_3d_chunk = seg_3d_chunk
            else:
                output_3d_chunk = net.encoder.transform_visualization(dino_full_3d_chunk, from_dim=rgb_from_pca_dim)
                output_3d_chunk -= output_3d_chunk.min()
                output_3d_chunk /= output_3d_chunk.max()

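                # Boost contrast slightly, then quantize to 8-bit RGB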
                output_3d_chunk = torch.clamp(output_3d_chunk*1.2 - 0.1, 0.0, 1.0)
                output_3d_chunk = (255*output_3d_chunk).int()

            fov_mask_chunk = fov_mask_chunk.reshape(is_occupied_chunk.shape)

            is_occupied.append(is_occupied_chunk)
            output_3d.append(output_3d_chunk)
            fov_mask.append(fov_mask_chunk)

        is_occupied = torch.cat(is_occupied, dim=2)
        output_3d = torch.cat(output_3d, dim=2)
        fov_mask = torch.cat(fov_mask, dim=2)

    temp_dir = tempfile.gettempdir()
    ply_path = os.path.join(temp_dir, "output.ply")

    if segmentation:
        # Hide voxels whose predicted class maps to "unlabeled"
        is_occupied[output_3d == 10] = 0
        is_occupied[output_3d == 12] = 0

        save_as_voxel_ply(ply_path, 
                          is_occupied.detach().cpu(), 
                          voxel_size=resolution, 
                          size=is_occupied.size(), 
                          classes=torch.Tensor(
                              convert_voxels(
                                  output_3d.detach().cpu(), 
                                  label_maps["cityscapes_to_label"])),
                          fov_mask=fov_mask)
    else:
        save_as_voxel_ply(ply_path, 
                          is_occupied.detach().cpu(), 
                          voxel_size=resolution, 
                          size=is_occupied.size(), 
                          colors=output_3d.detach().cpu(),
                          fov_mask=fov_mask)

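    # Convert the PLY voxel mesh to GLB for Gradio's Model3D viewer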
    mesh = o3d.io.read_triangle_mesh(ply_path)
    glb_path = os.path.join(temp_dir, "output.glb")
    o3d.io.write_triangle_mesh(glb_path, mesh, write_ascii=True)

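    # Drop large tensors and free CUDA memory between requests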
    del dino_full_2d, depth_2d, seg_2d
    del dino_full_3d_chunk, sigma_3d_chunk, seg_3d_chunk, is_occupied_chunk
    del is_occupied, output_3d, fov_mask

    torch.cuda.empty_cache()

    return output_2d, glb_path


demo = gr.Interface(
    demo_run,
    inputs=[
        gr.Image(label="Input image", type="filepath"),
        gr.Radio(label="Backbone", choices=["DINO (ViT-B)", "DINOv2 (ViT-B)"]),
        gr.Radio(label="Mode", choices=["Feature PCA 1-3", "Feature PCA 4-6", "Feature PCA 7-9", "SSC (unsup.)", "SSC (linear)"]),
        gr.Slider(label="Density threshold", minimum=0, maximum=1, step=0.05, value=0.2),
        gr.Slider(label="Resolution [m]", minimum=0.05, maximum=0.5, step=0.1, value=0.2),
        gr.Slider(label="X Range [m]", minimum=1, maximum=50, step=1, value=10),
        gr.Slider(label="Y Range [m]", minimum=1, maximum=50, step=1, value=10),
        gr.Slider(label="Z Range [m]", minimum=1, maximum=100, step=1, value=20),
    ], 
    outputs=[
        gr.Image(label="Rendered 2D Visualization"),
        gr.Model3D(label="Voxel Surface 3D Visualization",
                   zoom_speed=0.5, pan_speed=0.5, 
                   clear_color=[0.0, 0.0, 0.0, 0.0], 
                   camera_position=[-90, 80, None], 
                   display_mode="solid"),
    ],
    title="SceneDINO Demo",
    examples="demo_utils/examples",
)

demo.launch()