|
import os |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import sys |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
import numpy as np |
|
import random |
|
|
|
|
|
import subprocess

# Upstream implementation; the star imports below presumably provide the model
# classes used by predict() (GaussianVAE, FlowVAE). Clone only once: re-running
# `git clone` into an existing directory just fails noisily on every restart.
_UPSTREAM_REPO_URL = "https://github.com/luost26/diffusion-point-cloud"
_UPSTREAM_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(_UPSTREAM_REPO_DIR):
    # Argument list, no shell. check=False keeps the original best-effort
    # behaviour: a failed clone surfaces as an ImportError just below.
    subprocess.run(["git", "clone", _UPSTREAM_REPO_URL, _UPSTREAM_REPO_DIR], check=False)

# Relative to the current working directory, as the clone target is.
sys.path.append(_UPSTREAM_REPO_DIR)

from models.vae_gaussian import *  # noqa: E402,F403
from models.vae_flow import *  # noqa: E402,F403

airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")

# Not downloaded automatically; must be placed here by hand (see the UI note).
chair_model_path = "./GEN_chair.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# point these loads at trusted checkpoints. It appears to be needed because the
# checkpoints carry a non-tensor 'args' object (predict() reads ckpt['args'].model).
ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)

# Load the chair checkpoint only if it is present. An unconditional load raised
# FileNotFoundError at import time and took the whole app (including the
# Airplane model) down with it. None marks "Chair unavailable".
ckpt_chair = (
    torch.load(chair_model_path, map_location=torch.device(device), weights_only=False)
    if os.path.exists(chair_model_path)
    else None
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seed_all(seed):
    """Seed every random number generator the app draws from.

    Applies the same *seed* to PyTorch, NumPy's legacy global generator, and
    Python's ``random`` module so a given seed reproduces the same sample.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
|
|
|
def normalize_point_clouds(pcs, mode):
    """Normalize each point cloud of a batch, writing the results back into *pcs*.

    Args:
        pcs: Batched tensor where ``pcs[i]`` is one cloud of shape
            ``(num_points, 3)`` (the per-cloud ``reshape(1, 3)``/``view(1, 3)``
            below require three coordinates per point).
        mode: Normalization to apply to each cloud:
            - ``None``: return *pcs* untouched.
            - ``'shape_unit'``: subtract the per-axis mean, then divide by the
              standard deviation of all coordinates.
            - ``'shape_bbox'``: center the axis-aligned bounding box at the
              origin and divide by half its longest side, so the cloud fits
              in ``[-1, 1]``.
            - anything else: each cloud is left unchanged.

    Returns:
        *pcs* itself, modified in place. Pass a clone if the input must be kept.
    """
    if mode is None:
        return pcs

    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)
            pc_min, _ = pc.min(dim=0, keepdim=True)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:
            shift = 0
            scale = 1

        if scale < 1e-8:
            # Degenerate cloud (e.g. all points identical): avoid dividing by ~0.
            # Create the replacement on pc's device/dtype; a plain CPU tensor of
            # shape (1, 1) would raise a device-mismatch error for GPU clouds.
            scale = pc.new_ones((1, 1))

        pcs[i] = (pc - shift) / scale
    return pcs
|
|
|
|
|
def _checkpoint_config(ckpt):
    """Return ``(model_type, latent_dim, flexibility, model_args)`` read from *ckpt*.

    When ``ckpt['args']`` exists and has a ``model`` attribute, those stored
    args are returned as ``model_args``, so the network is rebuilt with the
    hyperparameters it was saved with. Otherwise a minimal stand-in args object
    (``latent_dim``, ``hyper``) is synthesised and ``'gaussian'`` is assumed.
    """
    # ckpt is a dict (subscripted / .get() elsewhere); hasattr(ckpt, 'args') is
    # always False for a dict, so the key must be looked up instead.
    args = ckpt.get('args')
    if args is not None and hasattr(args, 'model'):
        return (
            args.model,
            getattr(args, 'latent_dim', ckpt.get('latent_dim', 128)),
            getattr(args, 'flexibility', ckpt.get('flexibility', 0.0)),
            args,
        )

    print("Warning: Checkpoint 'args' or 'args.model' not found. Assuming 'gaussian'.")
    latent_dim = ckpt.get('latent_dim', 128)
    stand_in = type('Args', (), {'latent_dim': latent_dim, 'hyper': getattr(args, 'hyper', None)})()
    return 'gaussian', latent_dim, ckpt.get('flexibility', 0.0), stand_in


def _build_model(model_type, model_args):
    """Instantiate the VAE class named by *model_type* on the module-level ``device``."""
    if model_type == 'gaussian':
        return GaussianVAE(model_args).to(device)
    if model_type == 'flow':
        return FlowVAE(model_args).to(device)
    raise ValueError(f"Unknown model type: {model_type}")


def predict(Seed, ckpt):
    """Sample one point cloud from the generative model stored in *ckpt*.

    Args:
        Seed: RNG seed; ``None`` falls back to 777. Seeding happens before the
            model is built, so any RNG draws during construction and the latent
            draw ``z`` are both reproducible for a given seed.
        ckpt: Checkpoint dict (as loaded by ``torch.load``) containing
            ``'state_dict'`` and normally ``'args'`` (the stored training args).

    Returns:
        A CPU tensor holding a single cloud, bbox-normalized into ``[-1, 1]``.
        Presumably of shape ``(num_points, 3)``; confirm against ``model.sample``.

    Raises:
        ValueError: if *ckpt* is ``None`` or names an unknown model type.
    """
    if ckpt is None:
        raise ValueError("No checkpoint loaded for the selected model.")
    if Seed is None:
        Seed = 777
    seed_all(int(Seed))

    # Order matters for reproducibility: seed -> build -> load -> sample z.
    model_type, latent_dim, flexibility, model_args = _checkpoint_config(ckpt)
    model = _build_model(model_type, model_args)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    num_points_to_generate = getattr(ckpt.get('args'), 'num_points', 2048)
    with torch.no_grad():
        z = torch.randn([1, latent_dim]).to(device)
        x = model.sample(z, num_points_to_generate, flexibility=flexibility)

    # Keep only the first cloud of the batch, on CPU. Normalize a clone so the
    # sampled tensor itself is not mutated.
    gen_pcs = x.detach().cpu()[:1]
    return normalize_point_clouds(gen_pcs.clone(), mode="shape_bbox")[0]
|
|
|
def generate(seed, value):
    """Generate a point cloud for the chosen object class and return a 3D scatter figure.

    Args:
        seed: Value from the seed slider. ``None`` or ``0`` means "random", as
            the slider label promises. The seed actually used appears in the
            figure title so the result can be reproduced.
        value: Dropdown choice, ``"Airplane"`` or ``"Chair"``. Anything else
            falls back to Airplane with a warning.

    Returns:
        A ``plotly.graph_objects.Figure`` with the generated points.

    Raises:
        gr.Error: if no checkpoint is loaded for the chosen model, shown to the
            user in the UI.
    """
    if value == "Airplane":
        ckpt = ckpt_airplane
    elif value == "Chair":
        ckpt = ckpt_chair
    else:
        print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
        ckpt = ckpt_airplane

    if ckpt is None:
        raise gr.Error(f"No checkpoint is loaded for the '{value}' model.")

    colors = (238, 75, 43)

    # 0/None -> random. Use an independent RNG: the global `random` state was
    # reseeded by the previous predict() call, so random.randint would yield a
    # predictable chain of "random" seeds. Start at 1 so the seed shown in the
    # title, entered back into the slider, reproduces this exact cloud.
    current_seed = int(seed) if seed else random.SystemRandom().randint(1, 2**16 - 1)

    # Hand plotly a plain NumPy array rather than relying on it accepting tensors.
    points = np.asarray(predict(current_seed, ckpt))

    axis_style = dict(backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white")
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, color=f'rgb({colors[0]},{colors[1]},{colors[2]})'),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=True, title='X', **axis_style),
                yaxis=dict(visible=True, title='Y', **axis_style),
                zaxis=dict(visible=True, title='Z', **axis_style),
                aspectmode='data',
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            title=f"Generated {value} (Seed: {current_seed})",
        ),
    )
    return fig
|
|
|
# Header text for the demo page. An f-string so it can report the compute
# device (`device`, chosen at startup). Horizontal rules (`---`) are surrounded
# by blank lines: a `---` directly under a text line is a Markdown setext
# heading and would render the whole note as an H2 title.
markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation

[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)

[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)

### Future Work based on interest
- Adding new models for new type objects
- New Customization

It is running on **{device.upper()}**

---

**Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
It is not downloaded automatically by this script.
Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.

---
'''
|
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page layout: header text on top, then the two input controls side by side.
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            seed_slider = gr.Slider(
                minimum=0,
                maximum=2**16 - 1,
                step=1,
                label='Seed (0 for random)',
                value=777,
            )
            model_dropdown = gr.Dropdown(
                choices=["Airplane", "Chair"],
                label="Choose Model Type",
                value="Airplane",
            )

    btn = gr.Button(value="Generate Point Cloud")
    point_cloud_plot = gr.Plot()

    # A click calls generate(seed, model_name) and renders the returned figure.
    btn.click(
        fn=generate,
        inputs=[seed_slider, model_dropdown],
        outputs=point_cloud_plot,
    )
|
|
|
if __name__ == "__main__":
    # Warn at startup, before the UI comes up, if the manually supplied Chair
    # checkpoint is missing. Warnings go to stderr so they are not mixed into
    # normal output.
    if not os.path.exists(chair_model_path):
        print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found.", file=sys.stderr)
        print("The 'Chair' option in the UI may not work unless this file is present.", file=sys.stderr)
        print(f"Please download it from the original project repository and place it at '{chair_model_path}'.", file=sys.stderr)

    demo.launch()