"""Gradio demo that samples 3D point clouds from pre-trained diffusion models.

On import this module clones the ``diffusion-point-cloud`` repository if it is
missing, loads every checkpoint listed in ``MODEL_CONFIGS`` and builds the
Gradio ``demo`` object. Running it as a script launches the demo.
"""
import copy
import os
import random
import subprocess
import sys
import tempfile  # For creating temporary files for download
import traceback  # For detailed error logging

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download

# --- Environment Setup ---
# Suppress TensorFlow oneDNN optimization messages if TensorFlow is
# inadvertently imported by a dependency.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

REPO_DIR = "diffusion-point-cloud"
REPO_URL = "https://github.com/luost26/diffusion-point-cloud"

# Clone the repository only if the directory doesn't exist.
if not os.path.exists(REPO_DIR):
    print("Cloning diffusion-point-cloud repository...")
    try:
        # Argument list (no shell) and an explicit status check, so a failed
        # clone is reported here instead of surfacing only as an ImportError.
        clone_result = subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=False)
        if clone_result.returncode != 0:
            print(f"WARNING: 'git clone' exited with status {clone_result.returncode}.")
    except OSError as clone_error:
        # Raised e.g. when the git executable is not installed.
        print(f"WARNING: could not run 'git clone': {clone_error}")
else:
    print("diffusion-point-cloud repository already exists.")
sys.path.append(REPO_DIR)

# --- Model Imports ---
try:
    from models.vae_gaussian import GaussianVAE
    from models.vae_flow import FlowVAE
except ImportError as e:
    print(f"CRITICAL Error importing models: {e}")
    print("Please ensure 'diffusion-point-cloud' directory is in sys.path and contains the model definitions.")
    sys.exit(1)

# --- Model Checkpoint Paths and Loading ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE.upper()}")

MODEL_CONFIGS = {
    "Airplane": {
        "path_function": lambda: hf_hub_download(
            "SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"
        ),
        "expected_model_type": "gaussian",
        "default_args": {
            'model': "gaussian",  # Should match expected_model_type
            'latent_dim': 128,
            'hyper': None,
            'residual': True,
            'num_points': 2048,  # For sampling
            # 'flexibility' will be taken from UI
        },
    },
    "Chair": {
        "path_function": lambda: "./GEN_chair.pt",
        "expected_model_type": "gaussian",  # Assuming Gaussian for chair as well
        "default_args": {
            'model': "gaussian",
            'latent_dim': 128,
            'hyper': None,
            'residual': True,
            'num_points': 2048,
        },
    },
    # To add more models:
    # "YourModelName": {
    #     "path_function": lambda: "path/to/your/model.pt",
    #     "expected_model_type": "gaussian",  # or "flow"
    #     "default_args": { ... }  # Model-specific defaults
    # }
}

# Load checkpoints. A model whose checkpoint cannot be loaded maps to None and
# is later hidden from the UI.
LOADED_CHECKPOINTS = {}
for model_name, config in MODEL_CONFIGS.items():
    model_path = ""  # Initialize for error message
    try:
        model_path = config["path_function"]()
        # Applies to every configured model, not only the local "Chair" file.
        if not os.path.exists(model_path):
            print(f"WARNING: Checkpoint for {model_name} not found at '{model_path}'. This model will not be available.")
            LOADED_CHECKPOINTS[model_name] = None
            continue
        print(f"Loading checkpoint for {model_name} from '{model_path}'...")
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # It is kept because predict() reads a pickled 'args' object from the
        # checkpoint; only load checkpoint files from sources you trust.
        LOADED_CHECKPOINTS[model_name] = torch.load(
            model_path, map_location=torch.device(DEVICE), weights_only=False
        )
        print(f"Successfully loaded {model_name}.")
    except Exception as e:
        # Top-level boundary: one broken checkpoint must not stop the others.
        print(f"ERROR loading checkpoint for {model_name} from '{model_path}': {e}")
        LOADED_CHECKPOINTS[model_name] = None


# --- Helper Functions ---
def seed_all(seed):
    """Seed the torch, NumPy and ``random`` generators with ``seed``."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def normalize_point_clouds(pcs, mode):
    """Normalize each point cloud in ``pcs`` in place and return ``pcs``.

    Args:
        pcs: Tensor iterated along dim 0; each entry is treated as a set of
            3-D points (the shift is reshaped to ``(1, 3)``).
        mode: ``None`` returns ``pcs`` untouched. ``'shape_unit'`` centres on
            the mean and divides by the std of all coordinates.
            ``'shape_bbox'`` centres on the bounding-box midpoint and divides
            by half of the longest bounding-box edge. Any other value applies
            an identity transform (zero shift, unit scale).

    Returns:
        The same tensor object that was passed in, modified in place.
    """
    if mode is None:
        return pcs
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)
            pc_min, _ = pc.min(dim=0, keepdim=True)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:  # Fallback
            shift = torch.zeros_like(pc.mean(dim=0).reshape(1, 3))
            scale = torch.ones_like(pc.flatten().std().reshape(1, 1))
        if scale.abs().item() < 1e-8:  # Prevent division by zero or very small scale
            scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)
        pcs[i] = (pc - shift) / scale
    return pcs


# --- Core Prediction Logic ---
def _resolve_args(ckpt, selected_model_name, flexibility_val):
    """Build the argument object used to construct the model and to sample."""
    model_specific_defaults = MODEL_CONFIGS[selected_model_name].get("default_args", {})

    # Prioritize args from checkpoint if available and seems valid
    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
        # Work on a shallow copy: the checkpoint is cached in
        # LOADED_CHECKPOINTS and shared by every request, so per-request
        # values such as 'flexibility' must not be written back into it.
        actual_args = copy.copy(ckpt['args'])
        print(f"Using 'args' found in checkpoint for {selected_model_name}.")
        # Augment with model-specific defaults if attributes are missing from ckpt['args']
        for key, default_value in model_specific_defaults.items():
            if not hasattr(actual_args, key):
                print(f"Checkpoint 'args' missing '{key}'. Setting default: {default_value}")
                setattr(actual_args, key, default_value)
    else:
        print(f"Warning: 'args' not found or 'args.model' missing in checkpoint for {selected_model_name}. Constructing mock_args from defaults.")
        # Fallback: try the top level of the checkpoint first, then the
        # model-specific default.
        actual_args_dict = {
            key: ckpt.get(key, default_value)
            for key, default_value in model_specific_defaults.items()
        }
        actual_args = type('Args', (), actual_args_dict)()

    # Ensure essential attributes for model construction and sampling are present
    if not hasattr(actual_args, 'model'):  # Critical
        raise ValueError("Resolved 'actual_args' is missing the 'model' attribute.")
    if not hasattr(actual_args, 'latent_dim'):
        setattr(actual_args, 'latent_dim', 128)  # A common default

    if actual_args.model == 'gaussian':
        if not hasattr(actual_args, 'residual'):
            print("Setting default 'residual=True' for GaussianVAE.")
            setattr(actual_args, 'residual', True)
    elif actual_args.model == 'flow':
        # Parameters for FlowVAE
        if not hasattr(actual_args, 'flow_depth'):
            setattr(actual_args, 'flow_depth', 10)
        if not hasattr(actual_args, 'flow_hidden_dim'):
            setattr(actual_args, 'flow_hidden_dim', 256)

    # Sampling parameters
    if not hasattr(actual_args, 'num_points'):
        print("Setting default 'num_points=2048' for sampling.")
        setattr(actual_args, 'num_points', 2048)

    # Use flexibility from UI slider, this overrides any 'flexibility' in args
    setattr(actual_args, 'flexibility', flexibility_val)
    print(f"Using flexibility: {actual_args.flexibility} for sampling.")
    return actual_args


def _build_model(actual_args, ckpt):
    """Instantiate the VAE named by ``actual_args.model`` and load its weights."""
    if actual_args.model == 'gaussian':
        model = GaussianVAE(actual_args).to(DEVICE)
    elif actual_args.model == 'flow':
        model = FlowVAE(actual_args).to(DEVICE)
    else:
        raise ValueError(f"Unknown model type in args: '{actual_args.model}'. Expected 'gaussian' or 'flow'.")
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def predict(seed_val, selected_model_name, flexibility_val):
    """Sample one point cloud from the selected model.

    Args:
        seed_val: Seed applied to all random generators before sampling.
        selected_model_name: Key into ``LOADED_CHECKPOINTS`` / ``MODEL_CONFIGS``.
        flexibility_val: Passed to ``model.sample`` as ``flexibility``.

    Returns:
        The first generated point cloud after ``'shape_bbox'`` normalization.

    Raises:
        ValueError: If the checkpoint is unavailable, the resolved args lack a
            ``model`` attribute, or the model type is unknown.
    """
    seed_all(int(seed_val))

    ckpt = LOADED_CHECKPOINTS.get(selected_model_name)
    if ckpt is None:
        raise ValueError(f"Checkpoint for model '{selected_model_name}' not loaded or unavailable.")

    actual_args = _resolve_args(ckpt, selected_model_name, flexibility_val)
    model = _build_model(actual_args, ckpt)

    # --- Point Cloud Generation ---
    with torch.no_grad():
        z = torch.randn([1, actual_args.latent_dim], device=DEVICE)
        x = model.sample(z, int(actual_args.num_points), flexibility=actual_args.flexibility)
    gen_pcs_tensor = x.detach().cpu()[:1]
    # Clone first: normalize_point_clouds writes into the tensor it receives.
    gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox")
    return gen_pcs_normalized[0]


# --- Gradio Interface Function ---
def generate_gradio(seed, model_choice, flexibility, point_color_hex, marker_size):
    """Gradio callback: generate a point cloud, plot it and export it.

    Args:
        seed: Random seed; a random one is drawn when ``None``.
        model_choice: Name of the model to sample from.
        flexibility: Sampling flexibility forwarded to ``predict``.
        point_color_hex: Marker colour passed to Plotly.
        marker_size: Marker size passed to Plotly.

    Returns:
        A ``(figure, download_path, error_message)`` tuple. On failure the
        figure is empty, the path is ``None`` and the message describes the
        error; the function itself does not raise.
    """
    error_message = ""
    figure_plot = None
    download_file_path = None

    try:
        if seed is None:
            seed = random.randint(0, 2**16 - 1)
        seed = int(seed)

        if not model_choice:
            error_message = "Please choose a model type."
            # Return empty plot and no file if model not chosen
            return go.Figure(), None, error_message

        print(f"Generating {model_choice} with Seed: {seed}, Flex: {flexibility}, Color: {point_color_hex}, Size: {marker_size}")

        # Convert once to a NumPy array so Plotly and the text export work on
        # plain arrays rather than torch tensors.
        points = np.asarray(predict(seed, model_choice, flexibility))

        axis_style = dict(
            visible=True,
            backgroundcolor="rgb(230,230,230)",
            gridcolor="white",
            zerolinecolor="white",
        )
        # Create Plotly figure
        figure_plot = go.Figure(
            data=[
                go.Scatter3d(
                    x=points[:, 0],
                    y=points[:, 1],
                    z=points[:, 2],
                    mode='markers',
                    marker=dict(size=marker_size, color=point_color_hex),  # Use hex color directly
                )
            ],
            layout=dict(
                title=f"Generated {model_choice} (Seed: {seed}, Flex: {flexibility:.2f})",
                scene=dict(
                    xaxis=dict(title='X', **axis_style),
                    yaxis=dict(title='Y', **axis_style),
                    zaxis=dict(title='Z', **axis_style),
                    aspectmode='data',
                ),
                margin=dict(l=0, r=0, b=0, t=40),
            ),
        )

        # Prepare file for download. delete=False because Gradio serves the
        # file after this function has returned.
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".xyz", encoding='utf-8') as tmp_file:
            tmp_file.writelines(
                f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n" for point in points
            )
            download_file_path = tmp_file.name
        print(f"Point cloud saved for download at: {download_file_path}")

    except ValueError as ve:
        error_message = f"Configuration Error: {str(ve)}"
        print(error_message)
    except AttributeError as ae:
        error_message = f"Model Configuration Issue: {str(ae)}. The checkpoint might be missing expected parameters or they are incompatible."
        print(error_message)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        error_message = f"An unexpected error occurred: {str(e)}"
        print(f"{error_message}\nFull Traceback:\n{traceback.format_exc()}")

    # Ensure we always return three values, even on error
    if figure_plot is None:
        figure_plot = go.Figure()  # Empty plot on error
    return figure_plot, download_file_path, error_message


# --- Gradio UI Definition ---
available_models = [name for name, ckpt in LOADED_CHECKPOINTS.items() if ckpt is not None]
if not available_models:
    print("CRITICAL: No models were loaded successfully. The application may not function as expected.")

markdown_description = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation

[CVPR 2021 Paper: "Diffusion Probabilistic Models for 3D Point Cloud Generation"](https://arxiv.org/abs/2103.01458) | [Official GitHub](https://github.com/luost26/diffusion-point-cloud)

This demo allows you to generate 3D point clouds using pre-trained models.

- Adjust the **Seed** for different random initializations.
- Choose a **Model Type** (e.g., Airplane, Chair).
- Control **Sampling Flexibility**: Lower values tend towards the mean shape, higher values increase diversity.
- Customize **Point Color** and **Marker Size**.

Running on: **{DEVICE.upper()}**
'''

if "Chair" in MODEL_CONFIGS and "Chair" not in available_models:  # Check if Chair was intended but failed to load
    markdown_description += "\n\n**Warning:** The 'Chair' model checkpoint (`GEN_chair.pt`) was not found or failed to load. Please ensure it's in the root directory if you intend to use it."

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(markdown_description)

    with gr.Row():
        with gr.Column(scale=1):  # Controls Column
            model_dropdown = gr.Dropdown(
                choices=available_models,
                label="Choose Model Type",
                value=available_models[0] if available_models else None,
            )
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed', value=777, randomize=True)
            flexibility_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label='Sampling Flexibility', value=0.0)
            with gr.Row():
                color_picker = gr.ColorPicker(label="Point Color", value="#EE4B2B")  # Default: bright red
                marker_size_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Marker Size", value=2)
            generate_btn = gr.Button(value="Generate Point Cloud", variant="primary")

        with gr.Column(scale=2):  # Output Column
            plot_output = gr.Plot(label="Generated Point Cloud")
            file_download_output = gr.File(label="Download Point Cloud (.xyz)")
            error_display = gr.Markdown("")  # For displaying error messages

    generate_btn.click(
        fn=generate_gradio,
        inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
        outputs=[plot_output, file_download_output, error_display],
    )

    if available_models:
        example_list = [
            [777, available_models[0], 0.0, "#EE4B2B", 2],
            [1234, available_models[0], 0.5, "#1E90FF", 3],  # DodgerBlue
        ]
        if len(available_models) > 1:  # If Chair (or another model) is available
            # Marker size is an integer so it matches the slider's step=1.
            example_list.append([100, available_models[1], 0.2, "#32CD32", 3])  # LimeGreen

        gr.Examples(
            examples=example_list,
            inputs=[seed_slider, model_dropdown, flexibility_slider, color_picker, marker_size_slider],
            outputs=[plot_output, file_download_output, error_display],
            fn=generate_gradio,
            cache_examples=False,  # Generation is fast enough, no need to cache potentially large plots
        )

# --- Application Launch ---
if __name__ == "__main__":
    if not available_models:
        print("No models available to run the Gradio demo. You might want to check checkpoint paths and errors above.")
        # Optionally, you could still launch a limited UI that just shows an error.
        # For now, we'll just print and let it potentially launch an empty UI if Gradio is set up.
    print("Launching Gradio demo...")
    demo.launch()  # Add share=True if you want a public link when running locally