File size: 9,441 Bytes
0d0e451 7872317 ada67da 0d0e451 933cc55 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 8a2c015 7872317 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 933cc55 ada67da 0d0e451 ada67da 933cc55 ada67da 61d2ddd ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random
# --- Upstream model code ----------------------------------------------------
# The VAE model classes used by predict() (GaussianVAE, FlowVAE) come from the
# official implementation: https://github.com/luost26/diffusion-point-cloud
# It is cloned next to this script and its packages made importable.
import subprocess

_REPO_URL = "https://github.com/luost26/diffusion-point-cloud"
_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(_REPO_DIR):
    # Clone only once: re-running `git clone` into an existing directory just
    # errors out. An argument list avoids going through a shell; check=False
    # keeps the previous os.system() behavior of not aborting here -- a failed
    # clone surfaces as an ImportError on the lines below.
    subprocess.run(["git", "clone", _REPO_URL, _REPO_DIR], check=False)
sys.path.append(_REPO_DIR)
from models.vae_gaussian import *  # noqa: E402,F403 -- expected to provide GaussianVAE
from models.vae_flow import *  # noqa: E402,F403 -- expected to provide FlowVAE

# --- Checkpoints --------------------------------------------------------------
airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
# GEN_chair.pt is NOT downloaded by this script: place it manually in the
# working directory (the original repository distributes its checkpoints via
# Google Drive).
chair_model_path = "./GEN_chair.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# PyTorch 2.6 switched torch.load's default to weights_only=True, which refuses
# the pickled training config stored under 'args' (an argparse.Namespace, per
# the load error this code was adapted for). weights_only=False unpickles
# arbitrary objects, so it is only acceptable for trusted files such as these.
# Stricter alternative: torch.serialization.add_safe_globals([argparse.Namespace])
# and keep the default weights_only=True.
ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)
# If the chair file is missing, ckpt_chair is None instead of crashing the app
# at import time; the airplane model keeps working.
ckpt_chair = (
    torch.load(chair_model_path, map_location=torch.device(device), weights_only=False)
    if os.path.exists(chair_model_path)
    else None
)
def seed_all(seed):
    """Seed every random number generator this app draws from.

    Applies the same seed to PyTorch, NumPy's global generator and Python's
    ``random`` module, in that order.

    Args:
        seed: Integer seed.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def normalize_point_clouds(pcs, mode):
    """Normalize every point cloud of a batch in place and return the batch.

    Args:
        pcs: Batch tensor where ``pcs[i]`` is one cloud of shape ``(N, 3)``
            (the ``reshape(1, 3)`` / ``view(1, 3)`` below require exactly three
            coordinates per point). It is modified in place -- pass a clone to
            keep the original values.
        mode: ``None`` returns ``pcs`` untouched; ``'shape_unit'`` centers each
            cloud on its mean and divides by the standard deviation of all its
            coordinates; ``'shape_bbox'`` centers each cloud on its bounding-box
            midpoint and divides by half the longest box side, so that side
            spans ``[-1, 1]``. Any other value leaves the data unchanged.

    Returns:
        ``pcs`` itself (the same object), normalized.
    """
    if mode is None:
        return pcs
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:
            # Unrecognized mode: identity transform (kept for compatibility).
            shift = 0
            scale = 1
        # Degenerate cloud (all points coincide / zero spread): fall back to a
        # unit scale instead of dividing by ~0. ones_like keeps the replacement
        # on the cloud's device and dtype; a bare torch.tensor(1.0) would always
        # be a CPU tensor and fail against CUDA inputs. Only tensor scales can
        # reach this branch (the int fallback above is 1).
        if scale < 1e-8:
            scale = torch.ones_like(scale)
        pcs[i] = (pc - shift) / scale
    return pcs
def predict(Seed, ckpt):
    """Sample one point cloud from a diffusion-point-cloud checkpoint.

    Args:
        Seed: RNG seed (anything ``int()`` accepts); ``None`` means 777.
        ckpt: Loaded checkpoint dict with a ``'state_dict'`` key and, normally,
            an ``'args'`` key holding the training config, whose ``model``
            (``'gaussian'`` or ``'flow'``), ``latent_dim`` and ``flexibility``
            attributes are read here.

    Returns:
        A CPU tensor of shape ``(N, 3)`` (N presumably the requested number of
        points), normalized with ``normalize_point_clouds(..., "shape_bbox")``.

    Raises:
        ValueError: if the checkpoint names an unknown model type.
    """
    if Seed is None:
        Seed = 777
    seed_all(int(Seed))

    # The checkpoint is a dict, so its config must be looked up as a key.
    # hasattr(ckpt, 'args') is always False on a dict, which previously sent
    # every checkpoint down the 'gaussian' fallback path.
    ckpt_args = ckpt.get('args')
    if ckpt_args is not None and hasattr(ckpt_args, 'model'):
        model_type = ckpt_args.model
        latent_dim = ckpt_args.latent_dim
        flexibility = ckpt_args.flexibility
        # Build the network from the exact config it was trained with, so every
        # constructor option (not only latent_dim) matches the saved weights.
        model_args = ckpt_args
    else:
        print("Warning: Checkpoint 'args' or 'args.model' not found. Assuming 'gaussian'.")
        model_type = 'gaussian'
        latent_dim = ckpt.get('latent_dim', 128)  # a common default
        flexibility = ckpt.get('flexibility', 0.0)  # a common default
        # Best-effort stand-in config; the model constructor may require more
        # attributes than this provides.
        model_args = type('Args', (), {'latent_dim': latent_dim, 'hyper': getattr(ckpt_args, 'hyper', None)})()

    if model_type == 'gaussian':
        model = GaussianVAE(model_args).to(device)
    elif model_type == 'flow':
        model = FlowVAE(model_args).to(device)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    with torch.no_grad():
        # z is drawn from the CPU generator seeded by seed_all() and then moved
        # to the target device.
        z = torch.randn([1, latent_dim]).to(device)
        num_points_to_generate = getattr(ckpt_args, 'num_points', 2048)
        x = model.sample(z, num_points_to_generate, flexibility=flexibility)

    gen_pcs = x.detach().cpu()[:1]  # keep exactly one point cloud
    # normalize_point_clouds writes in place; normalize a copy.
    gen_pcs_normalized = normalize_point_clouds(gen_pcs.clone(), mode="shape_bbox")
    return gen_pcs_normalized[0]
def generate(seed, value):
    """Gradio callback: sample a point cloud and return it as a Plotly figure.

    Args:
        seed: Value of the seed slider. ``None`` or ``0`` draws a fresh seed in
            ``[0, 2**16 - 1]``, matching the slider's "0 for random" label.
        value: Dropdown choice, ``"Airplane"`` or ``"Chair"``. Any other value
            falls back to the airplane model with a printed warning.

    Returns:
        A ``plotly.graph_objects.Figure`` 3D scatter, titled with the model name
        and the seed that was actually used.

    Raises:
        gr.Error: if the selected checkpoint is ``None`` (not loaded).
    """
    checkpoints = {"Airplane": ckpt_airplane, "Chair": ckpt_chair}
    if value in checkpoints:
        ckpt = checkpoints[value]
    else:
        print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
        ckpt = ckpt_airplane
    if ckpt is None:
        # Show a readable message in the UI instead of failing inside predict().
        raise gr.Error(
            f"The '{value}' checkpoint is not loaded. Make sure its file "
            "(e.g. ./GEN_chair.pt) is present, then restart the app."
        )

    # SystemRandom rather than the global `random` module: seed_all() reseeds
    # `random` on every prediction, which would make each "random" seed a
    # deterministic function of the previous request's seed.
    if seed is None or int(seed) == 0:
        current_seed = random.SystemRandom().randint(0, 2**16 - 1)
    else:
        current_seed = int(seed)

    # predict() returns a CPU tensor; hand Plotly a plain numpy array.
    points = np.asarray(predict(current_seed, ckpt))

    colors = (238, 75, 43)  # RGB
    axis_style = dict(backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white")
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, color=f'rgb({colors[0]},{colors[1]},{colors[2]})'),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=True, title='X', **axis_style),
                yaxis=dict(visible=True, title='Y', **axis_style),
                zaxis=dict(visible=True, title='Z', **axis_style),
                aspectmode='data',  # keep the cloud's true proportions
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            title=f"Generated {value} (Seed: {current_seed})",
        ),
    )
    return fig
# Markdown header rendered at the top of the Gradio page (see gr.Markdown in
# the layout below). It is an f-string, so {device.upper()} is filled in once,
# when this module runs.
markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
### Future Work based on interest
- Adding new models for new type objects
- New Customization
It is running on **{device.upper()}**
---
**Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
It is not downloaded automatically by this script.
Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.
---
'''
# Gradio page layout. The nesting of the `with` blocks defines the on-page
# arrangement: a header row, a row of controls, then the button and the plot.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Seed forwarded to generate(); starts at 777.
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777)
            # Selects which loaded checkpoint generate() uses; starts at "Airplane".
            model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane")
        btn = gr.Button(value="Generate Point Cloud")
        point_cloud_plot = gr.Plot()  # receives the Plotly figure returned by generate()
        # demo.load(generate, [seed_slider, model_dropdown], point_cloud_plot) # demo.load usually runs on page load
        # Clicking the button calls generate(seed, model_choice) and renders the
        # returned figure in point_cloud_plot.
        btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)
if __name__ == "__main__":
    # The chair checkpoint is supplied manually; warn (rather than fail) when it
    # is absent, since only the "Chair" option depends on it.
    if not os.path.exists(chair_model_path):
        print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found.")
        print("The 'Chair' option in the UI may not work unless this file is present.")
        print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")
    demo.launch()