noumanjavaid committed on
Commit df60a80 · verified · 1 Parent(s): a64912b

Update app.py

Files changed (1)
  1. app.py +260 -169
app.py CHANGED
@@ -6,42 +6,84 @@ import torch
  from huggingface_hub import hf_hub_download
  import numpy as np
  import random
- # import argparse # Not strictly needed for weights_only=False, but good practice if dealing with argparse.Namespace
-
- os.system("git clone https://github.com/luost26/diffusion-point-cloud")
  sys.path.append("diffusion-point-cloud")

- #Codes reference : https://github.com/luost26/diffusion-point-cloud
-
- from models.vae_gaussian import *
- from models.vae_flow import *
-
- airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
- # IMPORTANT: GEN_chair.pt must be present in the root directory where this script is run.
- # This script does NOT download GEN_chair.pt. You need to manually place it there.
- # The original repository (https://github.com/luost26/diffusion-point-cloud)
- # mentions downloading checkpoints from Google Drive.
- chair_model_path = "./GEN_chair.pt"
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- # --- Start of PyTorch 2.6+ loading considerations ---
- # Option 1: Set weights_only=False for each load (Simpler, if you trust the source)
- # This is the approach being applied here as per previous interactions.
-
- ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)
- ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device), weights_only=False) # <--- FIX APPLIED HERE
-
- # Option 2: For a more robust/secure approach with PyTorch 2.6+ (if you have many models)
- # You could do this at the top, after importing torch and argparse:
- # import argparse
- # torch.serialization.add_safe_globals([argparse.Namespace])
- # Then, the torch.load calls below would not need weights_only=False (they'd use the default weights_only=True)
- # ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device))
- # ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device))
- # --- End of PyTorch 2.6+ loading considerations ---
-
-
  def seed_all(seed):
      torch.manual_seed(seed)
      np.random.seed(seed)
@@ -56,182 +98,231 @@ def normalize_point_clouds(pcs, mode):
              shift = pc.mean(dim=0).reshape(1, 3)
              scale = pc.flatten().std().reshape(1, 1)
          elif mode == 'shape_bbox':
-             pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
-             pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
              shift = ((pc_min + pc_max) / 2).view(1, 3)
              scale = (pc_max - pc_min).max().reshape(1, 1) / 2
-         else: # Fallback if mode is not recognized, though your code doesn't use this branch with current inputs
-             shift = 0
-             scale = 1

-         # Prevent division by zero or very small scale
-         if scale < 1e-8:
-             scale = torch.tensor(1.0).reshape(1,1)
-
-         pc = (pc - shift) / scale
-         pcs[i] = pc
      return pcs


- def predict(Seed, ckpt):
-     if Seed is None:
-         Seed = 777
-     seed_all(int(Seed))

-     # --- MODIFICATION START ---
      actual_args = None
      if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
          actual_args = ckpt['args']
-         print("Using 'args' found in checkpoint.")
      else:
-         # This fallback should ideally not be hit if 'args' is usually present
-         print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
-         # Define all necessary defaults if we have to construct from scratch
-         default_model_type = 'gaussian'
-         default_latent_dim = 128
-         default_hyper = None
-         default_residual = True
-         default_flow_depth = 10
-         default_flow_hidden_dim = 256
-         default_num_points = 2048 # Default for sampling
-         default_flexibility = 0.0 # Default for sampling
-
-         actual_args = type('Args', (), {
-             'model': ckpt.get('model', default_model_type),
-             'latent_dim': ckpt.get('latent_dim', default_latent_dim),
-             'hyper': ckpt.get('hyper', default_hyper),
-             'residual': ckpt.get('residual', default_residual),
-             'flow_depth': ckpt.get('flow_depth', default_flow_depth),
-             'flow_hidden_dim': ckpt.get('flow_hidden_dim', default_flow_hidden_dim),
-             'num_points': ckpt.get('num_points', default_num_points), # Try to get from ckpt top-level too
-             'flexibility': ckpt.get('flexibility', default_flexibility) # Try to get from ckpt top-level too
-         })()
-
-     # Ensure essential attributes for sampling exist on actual_args, even if 'args' was found
-     # These are parameters for the .sample() method, not necessarily model construction.
-     # The original training args might not have included these if they were fixed in the sampling script.
-
-     # Default values for sampling parameters if not present in actual_args
-     default_num_points_sampling = 2048
-     default_flexibility_sampling = 0.0

      if not hasattr(actual_args, 'num_points'):
-         print(f"Attribute 'num_points' not found in actual_args. Setting default: {default_num_points_sampling}")
-         setattr(actual_args, 'num_points', default_num_points_sampling)

-     if not hasattr(actual_args, 'flexibility'):
-         print(f"Attribute 'flexibility' not found in actual_args. Setting default: {default_flexibility_sampling}")
-         setattr(actual_args, 'flexibility', default_flexibility_sampling)
-
-     # Also ensure 'residual' is present if it's a Gaussian model, as it was an issue before
-     # This is more for model construction, but good to double-check if the 'args' from ckpt might be incomplete
-     if actual_args.model == 'gaussian' and not hasattr(actual_args, 'residual'):
-         print(f"Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
-         setattr(actual_args, 'residual', True) # Default for GaussianVAE

-     # --- MODIFICATION END ---

      if actual_args.model == 'gaussian':
-         model = GaussianVAE(actual_args).to(device)
      elif actual_args.model == 'flow':
-         model = FlowVAE(actual_args).to(device)
      else:
-         raise ValueError(f"Unknown model type: {actual_args.model}")

      model.load_state_dict(ckpt['state_dict'])
      model.eval()

      gen_pcs = []
      with torch.no_grad():
-         # Use the (potentially now augmented) actual_args for sampling
-         z = torch.randn([1, actual_args.latent_dim]).to(device)
-         x = model.sample(z, actual_args.num_points, flexibility=actual_args.flexibility)
          gen_pcs.append(x.detach().cpu())
-
      gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
      gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")

      return gen_pcs_normalized[0]
- def generate(seed, value):
-     if value == "Airplane":
-         ckpt = ckpt_airplane
-     elif value == "Chair":
-         ckpt = ckpt_chair
-     else:
-         # Default case or handle error
-         # For now, defaulting to airplane if 'value' is unexpected
-         print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
-         ckpt = ckpt_airplane

-     colors = (238, 75, 43) # RGB tuple for plotly
-
-     # Ensure seed is not None and is an int for the predict function
-     current_seed = seed
-     if current_seed is None:
-         current_seed = random.randint(0, 2**16 -1) # Generate a random seed if None
-     current_seed = int(current_seed)
-
-     points = predict(current_seed, ckpt)
-     # num_points = points.shape[0] # Not used directly in fig
-
-     fig = go.Figure(
-         data=[
-             go.Scatter3d(
-                 x=points[:, 0], y=points[:, 1], z=points[:, 2],
-                 mode='markers',
-                 marker=dict(size=2, color=f'rgb({colors[0]},{colors[1]},{colors[2]})') # plotly expects rgb string
              )
-         ],
-         layout=dict(
-             scene=dict(
-                 xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
-                 yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
-                 zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
-                 aspectmode='data' # Ensures proportional axes
-             ),
-             margin=dict(l=0, r=0, b=0, t=40), # Adjust margins
-             title=f"Generated {value} (Seed: {current_seed})"
          )
-     )
-     return fig

- markdown = f'''
- # Diffusion Probabilistic Models for 3D Point Cloud Generation

- [The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)

- [For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
- ### Future Work based on interest
- - Adding new models for new type objects
- - New Customization

- It is running on **{device.upper()}**

- ---
- **Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
- It is not downloaded automatically by this script.
- Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.
- ---
  '''
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     with gr.Column():
-         with gr.Row():
-             gr.Markdown(markdown)
-         with gr.Row():
-             seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777) # Set initial value
-             model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane") # Set initial value

-         btn = gr.Button(value="Generate Point Cloud")
-         point_cloud_plot = gr.Plot() # Changed variable name for clarity

-         # demo.load(generate, [seed_slider, model_dropdown], point_cloud_plot) # demo.load usually runs on page load
-         btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)

  if __name__ == "__main__":
-     # Ensure GEN_chair.pt exists if Chair model might be selected
-     if not os.path.exists(chair_model_path):
-         print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found.")
-         print(f"The 'Chair' option in the UI may not work unless this file is present.")
-         print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")

-     demo.launch()

  from huggingface_hub import hf_hub_download
  import numpy as np
  import random
+ import tempfile  # For creating temporary files for download
+ import traceback  # For detailed error logging
+
+ # --- Environment Setup ---
+ # Suppress TensorFlow oneDNN optimization messages if TensorFlow is inadvertently imported by a dependency
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+ # Clone the repository only if the directory doesn't exist
+ if not os.path.exists("diffusion-point-cloud"):
+     print("Cloning diffusion-point-cloud repository...")
+     os.system("git clone https://github.com/luost26/diffusion-point-cloud")
+ else:
+     print("diffusion-point-cloud repository already exists.")
  sys.path.append("diffusion-point-cloud")

+ # --- Model Imports ---
+ try:
+     from models.vae_gaussian import GaussianVAE
+     from models.vae_flow import FlowVAE
+ except ImportError as e:
+     print(f"CRITICAL Error importing models: {e}")
+     print("Please ensure 'diffusion-point-cloud' directory is in sys.path and contains the model definitions.")
+     sys.exit(1)
+
+ # --- Model Checkpoint Paths and Loading ---
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print(f"Using device: {DEVICE.upper()}")
+
+ MODEL_CONFIGS = {
+     "Airplane": {
+         "path_function": lambda: hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"),
+         "expected_model_type": "gaussian",
+         "default_args": {
+             'model': "gaussian",  # Should match expected_model_type
+             'latent_dim': 128,
+             'hyper': None,
+             'residual': True,
+             'num_points': 2048,  # For sampling
+             # 'flexibility' will be taken from UI
+         }
+     },
+     "Chair": {
+         "path_function": lambda: "./GEN_chair.pt",
+         "expected_model_type": "gaussian",  # Assuming Gaussian for chair as well
+         "default_args": {
+             'model': "gaussian",
+             'latent_dim': 128,
+             'hyper': None,
+             'residual': True,
+             'num_points': 2048,
+         }
+     }
+     # To add more models:
+     # "YourModelName": {
+     #     "path_function": lambda: "path/to/your/model.pt",
+     #     "expected_model_type": "gaussian",  # or "flow"
+     #     "default_args": { ... }  # Model-specific defaults
+     # }
+ }
+
+
+ # Load checkpoints
+ LOADED_CHECKPOINTS = {}
+ for model_name, config in MODEL_CONFIGS.items():
+     model_path = ""  # Initialize for error message
+     try:
+         model_path = config["path_function"]()
+         if model_name == "Chair" and not os.path.exists(model_path):  # Specific check for local file
+             print(f"WARNING: Checkpoint for {model_name} not found at '{model_path}'. This model will not be available.")
+             LOADED_CHECKPOINTS[model_name] = None
+             continue
+         print(f"Loading checkpoint for {model_name} from '{model_path}'...")
+         LOADED_CHECKPOINTS[model_name] = torch.load(model_path, map_location=torch.device(DEVICE), weights_only=False)
+         print(f"Successfully loaded {model_name}.")
+     except Exception as e:
+         print(f"ERROR loading checkpoint for {model_name} from '{model_path}': {e}")
+         LOADED_CHECKPOINTS[model_name] = None
+
+ # --- Helper Functions ---
  def seed_all(seed):
      torch.manual_seed(seed)
      np.random.seed(seed)

              shift = pc.mean(dim=0).reshape(1, 3)
              scale = pc.flatten().std().reshape(1, 1)
          elif mode == 'shape_bbox':
+             pc_max, _ = pc.max(dim=0, keepdim=True)
+             pc_min, _ = pc.min(dim=0, keepdim=True)
              shift = ((pc_min + pc_max) / 2).view(1, 3)
              scale = (pc_max - pc_min).max().reshape(1, 1) / 2
+         else:  # Fallback
+             shift = torch.zeros_like(pc.mean(dim=0).reshape(1, 3))
+             scale = torch.ones_like(pc.flatten().std().reshape(1, 1))

+         if scale.abs().item() < 1e-8:  # Prevent division by zero or very small scale
+             scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)
+
+         pcs[i] = (pc - shift) / scale
      return pcs

+ # --- Core Prediction Logic ---
+ def predict(seed_val, selected_model_name, flexibility_val):
+     seed_all(int(seed_val))
+
+     ckpt = LOADED_CHECKPOINTS.get(selected_model_name)
+     if ckpt is None:
+         raise ValueError(f"Checkpoint for model '{selected_model_name}' not loaded or unavailable.")

+     model_specific_defaults = MODEL_CONFIGS[selected_model_name].get("default_args", {})

+     # --- Argument Handling for Model Instantiation and Sampling ---
      actual_args = None
+     # Prioritize args from the checkpoint if available and valid
      if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
          actual_args = ckpt['args']
+         print(f"Using 'args' found in checkpoint for {selected_model_name}.")
+         # Augment with model-specific defaults if attributes are missing from ckpt['args']
+         for key, default_value in model_specific_defaults.items():
+             if not hasattr(actual_args, key):
+                 print(f"Checkpoint 'args' missing '{key}'. Setting default: {default_value}")
+                 setattr(actual_args, key, default_value)
      else:
+         print(f"Warning: 'args' not found or 'args.model' missing in checkpoint for {selected_model_name}. Constructing mock_args from defaults.")
+         # Fallback: construct args using model_specific_defaults, trying to get values from the top level of ckpt
+         actual_args_dict = {}
+         for key, default_value in model_specific_defaults.items():
+             # Try the ckpt top level first, then use the model-specific default
+             actual_args_dict[key] = ckpt.get(key, default_value)
+         actual_args = type('Args', (), actual_args_dict)()
+
+     # Ensure essential attributes for model construction and sampling are present on actual_args
+     # These might have been set by the defaults above, but good to double-check or enforce
+     if not hasattr(actual_args, 'model'):  # Critical
+         raise ValueError("Resolved 'actual_args' is missing the 'model' attribute.")
+     if not hasattr(actual_args, 'latent_dim'): setattr(actual_args, 'latent_dim', 128)  # A common default

+     if actual_args.model == 'gaussian':
+         if not hasattr(actual_args, 'residual'):
+             print("Setting default 'residual=True' for GaussianVAE.")
+             setattr(actual_args, 'residual', True)
+     elif actual_args.model == 'flow':  # Parameters for FlowVAE
+         if not hasattr(actual_args, 'flow_depth'): setattr(actual_args, 'flow_depth', 10)
+         if not hasattr(actual_args, 'flow_hidden_dim'): setattr(actual_args, 'flow_hidden_dim', 256)
+
+     # Sampling parameters
      if not hasattr(actual_args, 'num_points'):
+         print("Setting default 'num_points=2048' for sampling.")
+         setattr(actual_args, 'num_points', 2048)

+     # Use flexibility from the UI slider; this overrides any 'flexibility' in args
+     setattr(actual_args, 'flexibility', flexibility_val)
+     print(f"Using flexibility: {actual_args.flexibility} for sampling.")


+     # --- Model Instantiation ---
+     model = None
      if actual_args.model == 'gaussian':
+         model = GaussianVAE(actual_args).to(DEVICE)
      elif actual_args.model == 'flow':
+         model = FlowVAE(actual_args).to(DEVICE)
      else:
+         raise ValueError(f"Unknown model type in args: '{actual_args.model}'. Expected 'gaussian' or 'flow'.")

      model.load_state_dict(ckpt['state_dict'])
      model.eval()

+     # --- Point Cloud Generation ---
      gen_pcs = []
      with torch.no_grad():
+         z = torch.randn([1, actual_args.latent_dim], device=DEVICE)
+         x = model.sample(z, int(actual_args.num_points), flexibility=actual_args.flexibility)
          gen_pcs.append(x.detach().cpu())
+
      gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1]
      gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")

      return gen_pcs_normalized[0]

+
+ # --- Gradio Interface Function ---
+ def generate_gradio(seed, model_choice, flexibility, point_color_hex, marker_size):
+     error_message = ""
+     figure_plot = None
+     download_file_path = None
+
+     try:
+         if seed is None:
+             seed = random.randint(0, 2**16 - 1)
+         seed = int(seed)
+
+         if not model_choice:
+             error_message = "Please choose a model type."
+             # Return an empty plot and no file if no model was chosen
+             return go.Figure(), None, error_message
+
+         print(f"Generating {model_choice} with Seed: {seed}, Flex: {flexibility}, Color: {point_color_hex}, Size: {marker_size}")
+
+         points = predict(seed, model_choice, flexibility)
+
+         # Create the Plotly figure
+         figure_plot = go.Figure(
+             data=[
+                 go.Scatter3d(
+                     x=points[:, 0], y=points[:, 1], z=points[:, 2],
+                     mode='markers',
+                     marker=dict(size=marker_size, color=point_color_hex)  # Use the hex color directly
+                 )
+             ],
+             layout=dict(
+                 title=f"Generated {model_choice} (Seed: {seed}, Flex: {flexibility:.2f})",
+                 scene=dict(
+                     xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
+                     yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
+                     zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230,230,230)", gridcolor="white", zerolinecolor="white"),
+                     aspectmode='data'
+                 ),
+                 margin=dict(l=0, r=0, b=0, t=40)
              )
          )

+         # Prepare a file for download
+         with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".xyz", encoding='utf-8') as tmp_file:
+             for point in points:
+                 tmp_file.write(f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n")
+             download_file_path = tmp_file.name
+         print(f"Point cloud saved for download at: {download_file_path}")
+
+     except ValueError as ve:
+         error_message = f"Configuration Error: {str(ve)}"
+         print(error_message)
+     except AttributeError as ae:
+         error_message = f"Model Configuration Issue: {str(ae)}. The checkpoint might be missing expected parameters or they are incompatible."
+         print(error_message)
+     except Exception as e:
+         error_message = f"An unexpected error occurred: {str(e)}"
+         print(f"{error_message}\nFull Traceback:\n{traceback.format_exc()}")
+
+     # Ensure we always return three values, even on error
+     if figure_plot is None: figure_plot = go.Figure()  # Empty plot on error
+     return figure_plot, download_file_path, error_message

+ # --- Gradio UI Definition ---
+ available_models = [name for name, ckpt in LOADED_CHECKPOINTS.items() if ckpt is not None]
+ if not available_models:
+     print("CRITICAL: No models were loaded successfully. The application may not function as expected.")

+ markdown_description = f'''
+ # Diffusion Probabilistic Models for 3D Point Cloud Generation

+ [CVPR 2021 Paper: "Diffusion Probabilistic Models for 3D Point Cloud Generation"](https://arxiv.org/abs/2103.01458) | [Official GitHub](https://github.com/luost26/diffusion-point-cloud)

+ This demo allows you to generate 3D point clouds using pre-trained models.
+ - Adjust the **Seed** for different random initializations.
+ - Choose a **Model Type** (e.g., Airplane, Chair).
+ - Control **Sampling Flexibility**: lower values tend towards the mean shape, higher values increase diversity.
+ - Customize **Point Color** and **Marker Size**.
+
+ Running on: **{DEVICE.upper()}**
  '''
+ if "Chair" in MODEL_CONFIGS and "Chair" not in available_models:  # Check if Chair was intended but failed to load
+     markdown_description += "\n\n**Warning:** The 'Chair' model checkpoint (`GEN_chair.pt`) was not found or failed to load. Please ensure it's in the root directory if you intend to use it."


+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(markdown_description)
+
+     with gr.Row():
+         with gr.Column(scale=1):  # Controls column
+             model_dropdown = gr.Dropdown(choices=available_models, label="Choose Model Type", value=available_models[0] if available_models else None)
+             seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed', value=777, randomize=True)
+             flexibility_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label='Sampling Flexibility', value=0.0)
+
+             with gr.Row():
+                 color_picker = gr.ColorPicker(label="Point Color", value="#EE4B2B")  # Default red-orange
+                 marker_size_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Marker Size", value=2)
+
+             generate_btn = gr.Button(value="Generate Point Cloud", variant="primary")
+
+         with gr.Column(scale=2):  # Output column
+             plot_output = gr.Plot(label="Generated Point Cloud")
+             file_download_output = gr.File(label="Download Point Cloud (.xyz)")
+             error_display = gr.Markdown("")  # For displaying error messages
+
+     generate_btn.click(
+         fn=generate_gradio,
+         inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
+         outputs=[plot_output, file_download_output, error_display]
+     )
+
+     if available_models:
+         example_list = [
+             [777, available_models[0], 0.0, "#EE4B2B", 2],
+             [1234, available_models[0], 0.5, "#1E90FF", 3],  # DodgerBlue
+         ]
+         if len(available_models) > 1:  # If Chair (or another model) is available
+             example_list.append([100, available_models[1], 0.2, "#32CD32", 2.5])  # LimeGreen
+
+         gr.Examples(
+             examples=example_list,
+             inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
+             outputs=[plot_output, file_download_output, error_display],
+             fn=generate_gradio,
+             cache_examples=False,  # Generation is fast enough, no need to cache potentially large plots
+         )

+ # --- Application Launch ---
  if __name__ == "__main__":
+     if not available_models:
+         print("No models available to run the Gradio demo. You might want to check checkpoint paths and errors above.")
+         # Optionally, you could still launch a limited UI that just shows an error.
+         # For now, we'll just print and let it potentially launch an empty UI if Gradio is set up.

+     print("Launching Gradio demo...")
+     demo.launch()  # Add share=True if you want a public link when running locally
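
The removed comments above describe an alternative to loading every checkpoint with weights_only=False: allowlist argparse.Namespace so torch.load can keep its safer default. A minimal sketch of that option, assuming PyTorch 2.4 or newer and that the checkpoints pickle nothing beyond tensors, plain containers, and argparse.Namespace:

import argparse
import torch

# Allowlist argparse.Namespace once so torch.load can keep the default
# weights_only=True (PyTorch 2.6+) instead of falling back to weights_only=False.
torch.serialization.add_safe_globals([argparse.Namespace])

ckpt = torch.load("GEN_airplane.pt", map_location="cpu")  # weights_only defaults to True on 2.6+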
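
For reference, the "shape_bbox" normalization applied before plotting centers each cloud at its bounding-box midpoint and divides by half of the largest bounding-box extent, so the result roughly fits in [-1, 1]^3. A standalone NumPy sketch of the same transform (the function name is illustrative, not part of app.py):

import numpy as np

def normalize_shape_bbox(pc):
    # pc: (N, 3) array; center at the bbox midpoint and scale by half of
    # the largest bbox extent (mirrors normalize_point_clouds with mode="shape_bbox").
    pc_min = pc.min(axis=0, keepdims=True)
    pc_max = pc.max(axis=0, keepdims=True)
    shift = (pc_min + pc_max) / 2.0
    scale = (pc_max - pc_min).max() / 2.0
    if scale < 1e-8:  # degenerate cloud; avoid division by zero
        scale = 1.0
    return (pc - shift) / scale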
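
The new download output is a plain .xyz file: one "x y z" triple per line, written with six decimal places. A minimal sketch of reading it back (the file name is hypothetical):

import numpy as np

# Load the exported point cloud back into an (N, 3) float array.
points = np.loadtxt("generated_point_cloud.xyz")  # hypothetical file name
print(points.shape)  # e.g. (2048, 3)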