Update app.py
Browse files
app.py
CHANGED
@@ -6,42 +6,84 @@ import torch
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
import numpy as np
|
8 |
import random
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
sys.path.append("diffusion-point-cloud")
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
from models.vae_gaussian import
|
17 |
-
from models.vae_flow import
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
#
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def seed_all(seed):
|
46 |
torch.manual_seed(seed)
|
47 |
np.random.seed(seed)
|
@@ -56,182 +98,231 @@ def normalize_point_clouds(pcs, mode):
|
|
56 |
shift = pc.mean(dim=0).reshape(1, 3)
|
57 |
scale = pc.flatten().std().reshape(1, 1)
|
58 |
elif mode == 'shape_bbox':
|
59 |
-
pc_max, _ = pc.max(dim=0, keepdim=True)
|
60 |
-
pc_min, _ = pc.min(dim=0, keepdim=True)
|
61 |
shift = ((pc_min + pc_max) / 2).view(1, 3)
|
62 |
scale = (pc_max - pc_min).max().reshape(1, 1) / 2
|
63 |
-
else: # Fallback
|
64 |
-
shift = 0
|
65 |
-
scale = 1
|
66 |
|
67 |
-
# Prevent division by zero or very small scale
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
pc = (pc - shift) / scale
|
72 |
-
pcs[i] = pc
|
73 |
return pcs
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
-
|
77 |
-
if Seed is None:
|
78 |
-
Seed = 777
|
79 |
-
seed_all(int(Seed))
|
80 |
|
81 |
-
# ---
|
82 |
actual_args = None
|
|
|
83 |
if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
|
84 |
actual_args = ckpt['args']
|
85 |
-
print("Using 'args' found in checkpoint.")
|
|
|
|
|
|
|
|
|
|
|
86 |
else:
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
'model': ckpt.get('model', default_model_type),
|
101 |
-
'latent_dim': ckpt.get('latent_dim', default_latent_dim),
|
102 |
-
'hyper': ckpt.get('hyper', default_hyper),
|
103 |
-
'residual': ckpt.get('residual', default_residual),
|
104 |
-
'flow_depth': ckpt.get('flow_depth', default_flow_depth),
|
105 |
-
'flow_hidden_dim': ckpt.get('flow_hidden_dim', default_flow_hidden_dim),
|
106 |
-
'num_points': ckpt.get('num_points', default_num_points), # Try to get from ckpt top-level too
|
107 |
-
'flexibility': ckpt.get('flexibility', default_flexibility) # Try to get from ckpt top-level too
|
108 |
-
})()
|
109 |
-
|
110 |
-
# Ensure essential attributes for sampling exist on actual_args, even if 'args' was found
|
111 |
-
# These are parameters for the .sample() method, not necessarily model construction.
|
112 |
-
# The original training args might not have included these if they were fixed in the sampling script.
|
113 |
-
|
114 |
-
# Default values for sampling parameters if not present in actual_args
|
115 |
-
default_num_points_sampling = 2048
|
116 |
-
default_flexibility_sampling = 0.0
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
if not hasattr(actual_args, 'num_points'):
|
119 |
-
print(
|
120 |
-
setattr(actual_args, 'num_points',
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
# Also ensure 'residual' is present if it's a Gaussian model, as it was an issue before
|
127 |
-
# This is more for model construction, but good to double-check if the 'args' from ckpt might be incomplete
|
128 |
-
if actual_args.model == 'gaussian' and not hasattr(actual_args, 'residual'):
|
129 |
-
print(f"Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
|
130 |
-
setattr(actual_args, 'residual', True) # Default for GaussianVAE
|
131 |
|
132 |
-
# --- MODIFICATION END ---
|
133 |
|
|
|
|
|
134 |
if actual_args.model == 'gaussian':
|
135 |
-
model = GaussianVAE(actual_args).to(
|
136 |
elif actual_args.model == 'flow':
|
137 |
-
model = FlowVAE(actual_args).to(
|
138 |
else:
|
139 |
-
raise ValueError(f"Unknown model type: {actual_args.model}")
|
140 |
|
141 |
model.load_state_dict(ckpt['state_dict'])
|
142 |
model.eval()
|
143 |
|
|
|
144 |
gen_pcs = []
|
145 |
with torch.no_grad():
|
146 |
-
|
147 |
-
|
148 |
-
x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
|
149 |
gen_pcs.append(x.detach().cpu())
|
150 |
-
|
151 |
gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
|
152 |
gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
|
153 |
|
154 |
return gen_pcs_normalized[0]
|
155 |
-
def generate(seed, value):
|
156 |
-
if value == "Airplane":
|
157 |
-
ckpt = ckpt_airplane
|
158 |
-
elif value == "Chair":
|
159 |
-
ckpt = ckpt_chair
|
160 |
-
else:
|
161 |
-
# Default case or handle error
|
162 |
-
# For now, defaulting to airplane if 'value' is unexpected
|
163 |
-
print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
|
164 |
-
ckpt = ckpt_airplane
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
)
|
184 |
-
],
|
185 |
-
layout=dict(
|
186 |
-
scene=dict(
|
187 |
-
xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
|
188 |
-
yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
|
189 |
-
zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
|
190 |
-
aspectmode='data' # Ensures proportional axes
|
191 |
-
),
|
192 |
-
margin=dict(l=0, r=0, b=0, t=40), # Adjust margins
|
193 |
-
title=f"Generated {value} (Seed: {current_seed})"
|
194 |
)
|
195 |
-
)
|
196 |
-
return fig
|
197 |
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
-
|
|
|
|
|
|
|
202 |
|
203 |
-
|
204 |
-
|
205 |
-
- Adding new models for new type objects
|
206 |
-
- New Customization
|
207 |
|
208 |
-
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
215 |
'''
|
216 |
-
|
217 |
-
|
218 |
-
with gr.Row():
|
219 |
-
gr.Markdown(markdown)
|
220 |
-
with gr.Row():
|
221 |
-
seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777) # Set initial value
|
222 |
-
model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane") # Set initial value
|
223 |
|
224 |
-
btn = gr.Button(value="Generate Point Cloud")
|
225 |
-
point_cloud_plot = gr.Plot() # Changed variable name for clarity
|
226 |
|
227 |
-
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
|
|
|
230 |
if __name__ == "__main__":
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")
|
236 |
|
237 |
-
demo
|
|
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
import numpy as np
|
8 |
import random
|
9 |
+
import tempfile # For creating temporary files for download
|
10 |
+
import traceback # For detailed error logging
|
11 |
+
|
12 |
+
# --- Environment Setup ---
|
13 |
+
# Suppress TensorFlow oneDNN optimization messages if TensorFlow is inadvertently imported by a dependency
|
14 |
+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
15 |
+
# Clone the repository only if the directory doesn't exist
|
16 |
+
if not os.path.exists("diffusion-point-cloud"):
|
17 |
+
print("Cloning diffusion-point-cloud repository...")
|
18 |
+
os.system("git clone https://github.com/luost26/diffusion-point-cloud")
|
19 |
+
else:
|
20 |
+
print("diffusion-point-cloud repository already exists.")
|
21 |
sys.path.append("diffusion-point-cloud")
|
22 |
|
23 |
+
# --- Model Imports ---
|
24 |
+
try:
|
25 |
+
from models.vae_gaussian import GaussianVAE
|
26 |
+
from models.vae_flow import FlowVAE
|
27 |
+
except ImportError as e:
|
28 |
+
print(f"CRITICAL Error importing models: {e}")
|
29 |
+
print("Please ensure 'diffusion-point-cloud' directory is in sys.path and contains the model definitions.")
|
30 |
+
sys.exit(1)
|
31 |
+
|
32 |
+
# --- Model Checkpoint Paths and Loading ---
|
33 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
34 |
+
print(f"Using device: {DEVICE.upper()}")
|
35 |
+
|
36 |
+
MODEL_CONFIGS = {
|
37 |
+
"Airplane": {
|
38 |
+
"path_function": lambda: hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"),
|
39 |
+
"expected_model_type": "gaussian",
|
40 |
+
"default_args": {
|
41 |
+
'model': "gaussian", # Should match expected_model_type
|
42 |
+
'latent_dim': 128,
|
43 |
+
'hyper': None,
|
44 |
+
'residual': True,
|
45 |
+
'num_points': 2048, # For sampling
|
46 |
+
# 'flexibility' will be taken from UI
|
47 |
+
}
|
48 |
+
},
|
49 |
+
"Chair": {
|
50 |
+
"path_function": lambda: "./GEN_chair.pt",
|
51 |
+
"expected_model_type": "gaussian", # Assuming Gaussian for chair as well
|
52 |
+
"default_args": {
|
53 |
+
'model': "gaussian",
|
54 |
+
'latent_dim': 128,
|
55 |
+
'hyper': None,
|
56 |
+
'residual': True,
|
57 |
+
'num_points': 2048,
|
58 |
+
}
|
59 |
+
}
|
60 |
+
# To add more models:
|
61 |
+
# "YourModelName": {
|
62 |
+
# "path_function": lambda: "path/to/your/model.pt",
|
63 |
+
# "expected_model_type": "gaussian", # or "flow"
|
64 |
+
# "default_args": { ... } # Model-specific defaults
|
65 |
+
# }
|
66 |
+
}
|
67 |
+
|
68 |
+
|
69 |
+
# Load checkpoints
|
70 |
+
LOADED_CHECKPOINTS = {}
|
71 |
+
for model_name, config in MODEL_CONFIGS.items():
|
72 |
+
model_path = "" # Initialize for error message
|
73 |
+
try:
|
74 |
+
model_path = config["path_function"]()
|
75 |
+
if model_name == "Chair" and not os.path.exists(model_path): # Specific check for local file
|
76 |
+
print(f"WARNING: Checkpoint for {model_name} not found at '{model_path}'. This model will not be available.")
|
77 |
+
LOADED_CHECKPOINTS[model_name] = None
|
78 |
+
continue
|
79 |
+
print(f"Loading checkpoint for {model_name} from '{model_path}'...")
|
80 |
+
LOADED_CHECKPOINTS[model_name] = torch.load(model_path, map_location=torch.device(DEVICE), weights_only=False)
|
81 |
+
print(f"Successfully loaded {model_name}.")
|
82 |
+
except Exception as e:
|
83 |
+
print(f"ERROR loading checkpoint for {model_name} from '{model_path}': {e}")
|
84 |
+
LOADED_CHECKPOINTS[model_name] = None
|
85 |
+
|
86 |
+
# --- Helper Functions ---
|
87 |
def seed_all(seed):
|
88 |
torch.manual_seed(seed)
|
89 |
np.random.seed(seed)
|
|
|
98 |
shift = pc.mean(dim=0).reshape(1, 3)
|
99 |
scale = pc.flatten().std().reshape(1, 1)
|
100 |
elif mode == 'shape_bbox':
|
101 |
+
pc_max, _ = pc.max(dim=0, keepdim=True)
|
102 |
+
pc_min, _ = pc.min(dim=0, keepdim=True)
|
103 |
shift = ((pc_min + pc_max) / 2).view(1, 3)
|
104 |
scale = (pc_max - pc_min).max().reshape(1, 1) / 2
|
105 |
+
else: # Fallback
|
106 |
+
shift = torch.zeros_like(pc.mean(dim=0).reshape(1, 3))
|
107 |
+
scale = torch.ones_like(pc.flatten().std().reshape(1, 1))
|
108 |
|
109 |
+
if scale.abs().item() < 1e-8: # Prevent division by zero or very small scale
|
110 |
+
scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)
|
111 |
+
|
112 |
+
pcs[i] = (pc - shift) / scale
|
|
|
|
|
113 |
return pcs
|
114 |
|
115 |
+
# --- Core Prediction Logic ---
|
116 |
+
def predict(seed_val, selected_model_name, flexibility_val):
|
117 |
+
seed_all(int(seed_val))
|
118 |
+
|
119 |
+
ckpt = LOADED_CHECKPOINTS.get(selected_model_name)
|
120 |
+
if ckpt is None:
|
121 |
+
raise ValueError(f"Checkpoint for model '{selected_model_name}' not loaded or unavailable.")
|
122 |
|
123 |
+
model_specific_defaults = MODEL_CONFIGS[selected_model_name].get("default_args", {})
|
|
|
|
|
|
|
124 |
|
125 |
+
# --- Argument Handling for Model Instantiation and Sampling ---
|
126 |
actual_args = None
|
127 |
+
# Prioritize args from checkpoint if available and seems valid
|
128 |
if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
|
129 |
actual_args = ckpt['args']
|
130 |
+
print(f"Using 'args' found in checkpoint for {selected_model_name}.")
|
131 |
+
# Augment with model-specific defaults if attributes are missing from ckpt['args']
|
132 |
+
for key, default_value in model_specific_defaults.items():
|
133 |
+
if not hasattr(actual_args, key):
|
134 |
+
print(f"Checkpoint 'args' missing '{key}'. Setting default: {default_value}")
|
135 |
+
setattr(actual_args, key, default_value)
|
136 |
else:
|
137 |
+
print(f"Warning: 'args' not found or 'args.model' missing in checkpoint for {selected_model_name}. Constructing mock_args from defaults.")
|
138 |
+
# Fallback: construct args using model_specific_defaults, trying to get values from top-level of ckpt
|
139 |
+
actual_args_dict = {}
|
140 |
+
for key, default_value in model_specific_defaults.items():
|
141 |
+
# Try to get from ckpt top-level first, then use the model-specific default
|
142 |
+
actual_args_dict[key] = ckpt.get(key, default_value)
|
143 |
+
actual_args = type('Args', (), actual_args_dict)()
|
144 |
+
|
145 |
+
# Ensure essential attributes for model construction and sampling are present on actual_args
|
146 |
+
# These might have been set by defaults above, but good to double check or enforce
|
147 |
+
if not hasattr(actual_args, 'model'): # Critical
|
148 |
+
raise ValueError("Resolved 'actual_args' is missing the 'model' attribute.")
|
149 |
+
if not hasattr(actual_args, 'latent_dim'): setattr(actual_args, 'latent_dim', 128) # A common default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
+
if actual_args.model == 'gaussian':
|
152 |
+
if not hasattr(actual_args, 'residual'):
|
153 |
+
print("Setting default 'residual=True' for GaussianVAE.")
|
154 |
+
setattr(actual_args, 'residual', True)
|
155 |
+
elif actual_args.model == 'flow': # Parameters for FlowVAE
|
156 |
+
if not hasattr(actual_args, 'flow_depth'): setattr(actual_args, 'flow_depth', 10)
|
157 |
+
if not hasattr(actual_args, 'flow_hidden_dim'): setattr(actual_args, 'flow_hidden_dim', 256)
|
158 |
+
|
159 |
+
# Sampling parameters
|
160 |
if not hasattr(actual_args, 'num_points'):
|
161 |
+
print("Setting default 'num_points=2048' for sampling.")
|
162 |
+
setattr(actual_args, 'num_points', 2048)
|
163 |
|
164 |
+
# Use flexibility from UI slider, this overrides any 'flexibility' in args
|
165 |
+
setattr(actual_args, 'flexibility', flexibility_val)
|
166 |
+
print(f"Using flexibility: {actual_args.flexibility} for sampling.")
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
|
|
168 |
|
169 |
+
# --- Model Instantiation ---
|
170 |
+
model = None
|
171 |
if actual_args.model == 'gaussian':
|
172 |
+
model = GaussianVAE(actual_args).to(DEVICE)
|
173 |
elif actual_args.model == 'flow':
|
174 |
+
model = FlowVAE(actual_args).to(DEVICE)
|
175 |
else:
|
176 |
+
raise ValueError(f"Unknown model type in args: '{actual_args.model}'. Expected 'gaussian' or 'flow'.")
|
177 |
|
178 |
model.load_state_dict(ckpt['state_dict'])
|
179 |
model.eval()
|
180 |
|
181 |
+
# --- Point Cloud Generation ---
|
182 |
gen_pcs = []
|
183 |
with torch.no_grad():
|
184 |
+
z = torch.randn([1, actual_args.latent_dim], device=DEVICE)
|
185 |
+
x = model.sample(z, int(actual_args.num_points), flexibility=actual_args.flexibility)
|
|
|
186 |
gen_pcs.append(x.detach().cpu())
|
187 |
+
|
188 |
gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
|
189 |
gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
|
190 |
|
191 |
return gen_pcs_normalized[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
+
|
194 |
+
# --- Gradio Interface Function ---
|
195 |
+
def generate_gradio(seed, model_choice, flexibility, point_color_hex, marker_size):
|
196 |
+
error_message = ""
|
197 |
+
figure_plot = None
|
198 |
+
download_file_path = None
|
199 |
+
|
200 |
+
try:
|
201 |
+
if seed is None:
|
202 |
+
seed = random.randint(0, 2**16 - 1)
|
203 |
+
seed = int(seed)
|
204 |
+
|
205 |
+
if not model_choice:
|
206 |
+
error_message = "Please choose a model type."
|
207 |
+
# Return empty plot and no file if model not chosen
|
208 |
+
return go.Figure(), None, error_message
|
209 |
+
|
210 |
+
print(f"Generating {model_choice} with Seed: {seed}, Flex: {flexibility}, Color: {point_color_hex}, Size: {marker_size}")
|
211 |
+
|
212 |
+
points = predict(seed, model_choice, flexibility)
|
213 |
+
|
214 |
+
# Create Plotly figure
|
215 |
+
figure_plot = go.Figure(
|
216 |
+
data=[
|
217 |
+
go.Scatter3d(
|
218 |
+
x=points[:, 0], y=points[:, 1], z=points[:, 2],
|
219 |
+
mode='markers',
|
220 |
+
marker=dict(size=marker_size, color=point_color_hex) # Use hex color directly
|
221 |
+
)
|
222 |
+
],
|
223 |
+
layout=dict(
|
224 |
+
title=f"Generated {model_choice} (Seed: {seed}, Flex: {flexibility:.2f})",
|
225 |
+
scene=dict(
|
226 |
+
xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
|
227 |
+
yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
|
228 |
+
zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
|
229 |
+
aspectmode='data'
|
230 |
+
),
|
231 |
+
margin=dict(l=0, r=0, b=0, t=40)
|
232 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
)
|
|
|
|
|
234 |
|
235 |
+
# Prepare file for download
|
236 |
+
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".xyz", encoding='utf-8') as tmp_file:
|
237 |
+
for point in points:
|
238 |
+
tmp_file.write(f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n")
|
239 |
+
download_file_path = tmp_file.name
|
240 |
+
print(f"Point cloud saved for download at: {download_file_path}")
|
241 |
+
|
242 |
+
except ValueError as ve:
|
243 |
+
error_message = f"Configuration Error: {str(ve)}"
|
244 |
+
print(error_message)
|
245 |
+
except AttributeError as ae:
|
246 |
+
error_message = f"Model Configuration Issue: {str(ae)}. The checkpoint might be missing expected parameters or they are incompatible."
|
247 |
+
print(error_message)
|
248 |
+
except Exception as e:
|
249 |
+
error_message = f"An unexpected error occurred: {str(e)}"
|
250 |
+
print(f"{error_message}\nFull Traceback:\n{traceback.format_exc()}")
|
251 |
+
|
252 |
+
# Ensure we always return three values, even on error
|
253 |
+
if figure_plot is None: figure_plot = go.Figure() # Empty plot on error
|
254 |
+
return figure_plot, download_file_path, error_message
|
255 |
|
256 |
+
# --- Gradio UI Definition ---
|
257 |
+
available_models = [name for name, ckpt in LOADED_CHECKPOINTS.items() if ckpt is not None]
|
258 |
+
if not available_models:
|
259 |
+
print("CRITICAL: No models were loaded successfully. The application may not function as expected.")
|
260 |
|
261 |
+
markdown_description = f'''
|
262 |
+
# Diffusion Probabilistic Models for 3D Point Cloud Generation
|
|
|
|
|
263 |
|
264 |
+
[CVPR 2021 Paper: "Diffusion Probabilistic Models for 3D Point Cloud Generation"](https://arxiv.org/abs/2103.01458) | [Official GitHub](https://github.com/luost26/diffusion-point-cloud)
|
265 |
|
266 |
+
This demo allows you to generate 3D point clouds using pre-trained models.
|
267 |
+
- Adjust the **Seed** for different random initializations.
|
268 |
+
- Choose a **Model Type** (e.g., Airplane, Chair).
|
269 |
+
- Control **Sampling Flexibility**: Lower values tend towards the mean shape, higher values increase diversity.
|
270 |
+
- Customize **Point Color** and **Marker Size**.
|
271 |
+
|
272 |
+
Running on: **{DEVICE.upper()}**
|
273 |
'''
|
274 |
+
if "Chair" in MODEL_CONFIGS and "Chair" not in available_models: # Check if Chair was intended but failed to load
|
275 |
+
markdown_description += "\n\n**Warning:** The 'Chair' model checkpoint (`GEN_chair.pt`) was not found or failed to load. Please ensure it's in the root directory if you intend to use it."
|
|
|
|
|
|
|
|
|
|
|
276 |
|
|
|
|
|
277 |
|
278 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
279 |
+
gr.Markdown(markdown_description)
|
280 |
+
|
281 |
+
with gr.Row():
|
282 |
+
with gr.Column(scale=1): # Controls Column
|
283 |
+
model_dropdown = gr.Dropdown(choices=available_models, label="Choose Model Type", value=available_models[0] if available_models else None)
|
284 |
+
seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed', value=777, randomize=True)
|
285 |
+
flexibility_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label='Sampling Flexibility', value=0.0)
|
286 |
+
|
287 |
+
with gr.Row():
|
288 |
+
color_picker = gr.ColorPicker(label="Point Color", value="#EE4B2B") # Default orange
|
289 |
+
marker_size_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Marker Size", value=2)
|
290 |
+
|
291 |
+
generate_btn = gr.Button(value="Generate Point Cloud", variant="primary")
|
292 |
+
|
293 |
+
with gr.Column(scale=2): # Output Column
|
294 |
+
plot_output = gr.Plot(label="Generated Point Cloud")
|
295 |
+
file_download_output = gr.File(label="Download Point Cloud (.xyz)")
|
296 |
+
error_display = gr.Markdown("") # For displaying error messages
|
297 |
+
|
298 |
+
generate_btn.click(
|
299 |
+
fn=generate_gradio,
|
300 |
+
inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
|
301 |
+
outputs=[plot_output, file_download_output, error_display]
|
302 |
+
)
|
303 |
+
|
304 |
+
if available_models:
|
305 |
+
example_list = [
|
306 |
+
[777, available_models[0], 0.0, "#EE4B2B", 2],
|
307 |
+
[1234, available_models[0], 0.5, "#1E90FF", 3], # DodgerBlue
|
308 |
+
]
|
309 |
+
if len(available_models) > 1: # If Chair (or another model) is available
|
310 |
+
example_list.append([100, available_models[1], 0.2, "#32CD32", 2.5]) # LimeGreen
|
311 |
+
|
312 |
+
gr.Examples(
|
313 |
+
examples=example_list,
|
314 |
+
inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
|
315 |
+
outputs=[plot_output, file_download_output, error_display],
|
316 |
+
fn=generate_gradio,
|
317 |
+
cache_examples=False, # Generation is fast enough, no need to cache potentially large plots
|
318 |
+
)
|
319 |
|
320 |
+
# --- Application Launch ---
|
321 |
if __name__ == "__main__":
|
322 |
+
if not available_models:
|
323 |
+
print("No models available to run the Gradio demo. You might want to check checkpoint paths and errors above.")
|
324 |
+
# Optionally, you could still launch a limited UI that just shows an error.
|
325 |
+
# For now, we'll just print and let it potentially launch an empty UI if Gradio is set up.
|
|
|
326 |
|
327 |
+
print("Launching Gradio demo...")
|
328 |
+
demo.launch() # Add share=True if you want a public link when running locally
|