import os
import random
import subprocess
import sys

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download
# import argparse  # Only needed for the torch.serialization.add_safe_globals alternative to weights_only=False.

# Code reference: https://github.com/luost26/diffusion-point-cloud
REPO_URL = "https://github.com/luost26/diffusion-point-cloud"
REPO_DIR = "diffusion-point-cloud"

# Clone the official implementation only if it is not already present: an
# unconditional `git clone` fails (and prints an error) on every restart once
# the directory exists. check=True surfaces a failed clone immediately instead
# of as a confusing import error below.
if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
sys.path.append(REPO_DIR)

# Star imports from the cloned repo; presumably these supply the GaussianVAE
# and FlowVAE classes used later in this file.
from models.vae_gaussian import *  # noqa: E402,F403
from models.vae_flow import *  # noqa: E402,F403

airplane_model_path = hf_hub_download(
    "SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main"
)

# GEN_chair.pt is NOT downloaded by this script: place it manually at this path.
# The original repository (https://github.com/luost26/diffusion-point-cloud)
# mentions downloading checkpoints from Google Drive.
chair_model_path = "./GEN_chair.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
def _load_checkpoint(path):
    """Load a checkpoint file onto ``device``.

    PyTorch 2.6+ defaults ``torch.load`` to ``weights_only=True``, which rejects
    the ``argparse.Namespace`` stored in these checkpoints, so full unpickling is
    enabled here. Unpickling can execute arbitrary code: only load trusted files.

    Stricter alternative (works only if the Namespace is the sole non-tensor
    object): call ``torch.serialization.add_safe_globals([argparse.Namespace])``
    once after importing torch, then drop ``weights_only=False``.
    """
    return torch.load(path, map_location=torch.device(device), weights_only=False)


ckpt_airplane = _load_checkpoint(airplane_model_path)
# GEN_chair.pt is placed manually and may be absent. Loading it unconditionally
# raises FileNotFoundError at import time, before the __main__ warning below can
# run; keep None instead and report the problem when "Chair" is selected.
ckpt_chair = (
    _load_checkpoint(chair_model_path) if os.path.exists(chair_model_path) else None
)


def seed_all(seed):
    """Seed torch, NumPy and Python's ``random`` module with the same value."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def normalize_point_clouds(pcs, mode):
    """Normalize every point cloud in ``pcs`` in place and return ``pcs``.

    Args:
        pcs: tensor where ``pcs[i]`` is one cloud of 3-D points; the
            ``reshape(1, 3)`` calls assume shape (B, N, 3).
        mode: ``'shape_unit'`` (zero mean, unit std over all coordinates),
            ``'shape_bbox'`` (bounding box centred on the origin, longest side
            mapped to [-1, 1]), any other string (values left unchanged), or
            ``None`` (return ``pcs`` untouched).

    Returns:
        The same ``pcs`` object, modified in place.
    """
    if mode is None:
        return pcs
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:
            # Unrecognized mode: identity transform.
            shift = 0
            scale = 1
        # Avoid dividing by ~0 for degenerate (single-point / flat) clouds.
        if scale < 1e-8:
            scale = torch.tensor(1.0).reshape(1, 1)
        pcs[i] = (pc - shift) / scale
    return pcs


def predict(Seed, ckpt):
    """Sample one point cloud from a diffusion-point-cloud checkpoint.

    Args:
        Seed: RNG seed; ``None`` means 777. Converted with ``int``.
        ckpt: checkpoint dict holding ``'state_dict'`` and, normally,
            ``'args'`` (the training ``argparse.Namespace``).

    Returns:
        A CPU tensor for a single cloud, normalized with ``'shape_bbox'``.

    Raises:
        ValueError: if the checkpoint names an unknown model type.
    """
    from types import SimpleNamespace

    if Seed is None:
        Seed = 777
    seed_all(int(Seed))

    # ckpt is a dict: look up the 'args' key. (hasattr(ckpt, 'args') is always
    # False on a dict and would force the fallback for every checkpoint.)
    args = ckpt.get('args')
    if args is not None and hasattr(args, 'model'):
        model_type = args.model
        latent_dim = getattr(args, 'latent_dim', 128)
        flexibility = getattr(args, 'flexibility', 0.0)
        # Presumably the same Namespace the model was constructed from during
        # training, so it carries every attribute the constructor reads.
        model_args = args
    else:
        print("Warning: Checkpoint 'args' or 'args.model' not found. Assuming 'gaussian'.")
        model_type = 'gaussian'
        latent_dim = ckpt.get('latent_dim', 128)
        flexibility = ckpt.get('flexibility', 0.0)
        # NOTE(review): minimal stand-in; the model constructor may need more
        # attributes than these - confirm against models/vae_gaussian.py.
        model_args = SimpleNamespace(
            latent_dim=latent_dim, hyper=getattr(args, 'hyper', None)
        )

    model_classes = {'gaussian': GaussianVAE, 'flow': FlowVAE}
    if model_type not in model_classes:
        raise ValueError(f"Unknown model type: {model_type}")
    model = model_classes[model_type](model_args).to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    num_points = getattr(args, 'num_points', 2048)  # 2048 when args lacks it
    with torch.no_grad():
        # Draw z on the CPU generator, then move it, so a given seed yields the
        # same sample regardless of device.
        z = torch.randn([1, latent_dim]).to(device)
        x = model.sample(z, num_points, flexibility=flexibility)
    # Keep a single cloud; clone because normalize_point_clouds writes in place.
    gen_pcs = x.detach().cpu()[:1].clone()
    gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
    return gen_pcs[0]


def generate(seed, value):
    """Gradio callback: sample a point cloud and return it as a Plotly figure.

    Args:
        seed: slider value; ``None`` or ``0`` draws a fresh random seed, as the
            slider label ("0 for random") advertises.
        value: ``"Airplane"`` or ``"Chair"``; anything else falls back to Airplane.

    Returns:
        A ``plotly.graph_objects.Figure`` with a 3-D scatter of the points.

    Raises:
        gr.Error: if "Chair" is selected but its checkpoint was not loaded.
    """
    if value == "Chair":
        if ckpt_chair is None:
            raise gr.Error(
                f"Chair checkpoint '{chair_model_path}' not found. "
                "Place GEN_chair.pt there and restart the app."
            )
        ckpt = ckpt_chair
    else:
        if value != "Airplane":
            print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
        ckpt = ckpt_airplane

    if seed is None or int(seed) == 0:
        # seed_all() reseeds the global `random` module on every call, so draw
        # from the OS entropy source to get a genuinely new seed each time.
        current_seed = random.SystemRandom().randint(1, 2**16 - 1)
    else:
        current_seed = int(seed)

    # Hand Plotly NumPy arrays rather than torch tensors.
    points = predict(current_seed, ckpt).numpy()

    colors = (238, 75, 43)  # RGB
    axis_style = dict(
        visible=True,
        backgroundcolor="rgb(230, 230,230)",
        gridcolor="white",
        zerolinecolor="white",
    )
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, color=f'rgb({colors[0]},{colors[1]},{colors[2]})'),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(title='X', **axis_style),
                yaxis=dict(title='Y', **axis_style),
                zaxis=dict(title='Z', **axis_style),
                aspectmode='data',  # keep axes proportional
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            title=f"Generated {value} (Seed: {current_seed})",
        ),
    )
    return fig


markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation

[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)

[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)

### Future Work based on interest
- Adding new models for new type objects
- New Customization

It is running on **{device.upper()}**

---

**Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
It is not downloaded automatically by this script.
Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.

---
'''

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777)
            model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane")
        btn = gr.Button(value="Generate Point Cloud")
        point_cloud_plot = gr.Plot()
    btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)

if __name__ == "__main__":
    # The chair checkpoint is optional; warn at startup if it is missing.
    if not os.path.exists(chair_model_path):
        print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found.")
        print("The 'Chair' option in the UI may not work unless this file is present.")
        print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")
    demo.launch()