noumanjavaid's picture
Update app.py
a64912b verified
raw
history blame
10.5 kB
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random
# import argparse # Not strictly needed for weights_only=False, but good practice if dealing with argparse.Namespace
# Fetch the model definitions from the reference implementation on first start only.
# Codes reference: https://github.com/luost26/diffusion-point-cloud
import subprocess

_REPO_URL = "https://github.com/luost26/diffusion-point-cloud"
_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(os.path.join(_REPO_DIR, "models")):
    # Argument list (no shell) and check=True: a failed clone raises here with git's exit
    # status instead of surfacing later as a ModuleNotFoundError on the imports below.
    # Skipping when the checkout already exists avoids a noisy, failing re-clone on restart.
    subprocess.run(["git", "clone", _REPO_URL, _REPO_DIR], check=True)
sys.path.append(_REPO_DIR)
# GaussianVAE / FlowVAE used by predict() are expected to come from these star imports.
from models.vae_gaussian import *  # noqa: E402,F401,F403
from models.vae_flow import *  # noqa: E402,F401,F403
airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
# GEN_chair.pt is NOT downloaded by this script: place it manually in the working directory
# (the original repository, https://github.com/luost26/diffusion-point-cloud, distributes its
# checkpoints via Google Drive). Without it the app still starts, but only the airplane model works.
chair_model_path = "./GEN_chair.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# PyTorch 2.6+ defaults torch.load to weights_only=True. These checkpoints are loaded with
# weights_only=False, which unpickles arbitrary Python objects -- acceptable only because the
# files come from trusted sources. A stricter alternative is to register the needed class via
# torch.serialization.add_safe_globals([argparse.Namespace]) and keep the default weights_only=True.
ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)
# Load the chair checkpoint only if it is present: a missing file now disables the "Chair"
# option instead of crashing the whole app at import time. Consumers must handle None.
if os.path.exists(chair_model_path):
    ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device), weights_only=False)
else:
    ckpt_chair = None
def seed_all(seed):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` module with one value.

    Args:
        seed: integer seed applied to all three generators, in that order.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def normalize_point_clouds(pcs, mode):
    """Normalize every point cloud in ``pcs`` in place and return ``pcs``.

    Args:
        pcs: tensor where ``pcs[i]`` is one cloud whose rows are points with 3 coordinates
            (the per-cloud shift is reshaped to (1, 3)). It is modified in place.
        mode: ``None`` returns ``pcs`` untouched; ``'shape_unit'`` centres each cloud on its
            centroid and divides by the std of all its coordinates; ``'shape_bbox'`` centres
            each cloud on its bounding-box midpoint and divides by the largest half-extent.
            Any other value leaves the values unchanged.

    Returns:
        The same ``pcs`` object.
    """
    if mode is None:
        return pcs
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:
            # Unrecognized mode: identity transform.
            shift = 0
            scale = 1
        # Degenerate cloud (all points coincide): avoid dividing by ~0. ones_like keeps the
        # replacement on the scale's device/dtype; a fresh CPU (1, 1) tensor would raise a
        # device-mismatch error when dividing a CUDA cloud.
        if scale < 1e-8:
            scale = torch.ones_like(scale)
        pcs[i] = (pc - shift) / scale
    return pcs
# Sampling defaults used when the checkpoint's args lack them; these feed model.sample() only.
_DEFAULT_NUM_POINTS = 2048
_DEFAULT_FLEXIBILITY = 0.0


def _resolve_args(ckpt):
    """Return a private args object for the model in ``ckpt``, with sampling fields filled in.

    Uses a shallow copy of ``ckpt['args']`` when it exists and has a ``model`` attribute, so the
    defaults added below never mutate the shared module-level checkpoint. Otherwise builds a
    stand-in object from top-level ``ckpt`` keys with hard-coded fallbacks.
    """
    import copy  # local import keeps this block self-contained

    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
        args = copy.copy(ckpt['args'])
        print("Using 'args' found in checkpoint.")
    else:
        print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
        # NOTE(review): these attribute names/values are this file's guesses; confirm they match
        # what GaussianVAE / FlowVAE actually read from their args.
        fallbacks = {
            'model': 'gaussian',
            'latent_dim': 128,
            'hyper': None,
            'residual': True,
            'flow_depth': 10,
            'flow_hidden_dim': 256,
            'num_points': _DEFAULT_NUM_POINTS,
            'flexibility': _DEFAULT_FLEXIBILITY,
        }
        args = type('Args', (), {key: ckpt.get(key, default) for key, default in fallbacks.items()})()

    # Training-time args may omit the sampling-only parameters.
    for name, default in (('num_points', _DEFAULT_NUM_POINTS), ('flexibility', _DEFAULT_FLEXIBILITY)):
        if not hasattr(args, name):
            print(f"Attribute '{name}' not found in actual_args. Setting default: {default}")
            setattr(args, name, default)
    if args.model == 'gaussian' and not hasattr(args, 'residual'):
        print("Attribute 'residual' not found in actual_args for Gaussian model. Setting default: True")
        args.residual = True  # Default for GaussianVAE
    return args


def _build_model(args, state_dict):
    """Instantiate the VAE selected by ``args.model`` on ``device``, load weights, set eval mode."""
    if args.model == 'gaussian':
        model = GaussianVAE(args).to(device)
    elif args.model == 'flow':
        model = FlowVAE(args).to(device)
    else:
        raise ValueError(f"Unknown model type: {args.model}")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(Seed, ckpt):
    """Sample one point cloud from the model stored in ``ckpt``.

    A fresh model is built from the checkpoint on every call.

    Args:
        Seed: RNG seed passed through ``int()``; ``None`` falls back to 777.
        ckpt: checkpoint dict holding ``'state_dict'`` and, ideally, ``'args'``.

    Returns:
        CPU tensor for the single sampled cloud, normalized with mode ``'shape_bbox'``
        (assumes model.sample returns (1, num_points, 3) -- TODO confirm).

    Raises:
        ValueError: if the resolved ``args.model`` is neither ``'gaussian'`` nor ``'flow'``.
    """
    if Seed is None:
        Seed = 777
    seed_all(int(Seed))

    # Order matters for reproducibility: model construction may consume the torch RNG
    # before the latent code is drawn.
    args = _resolve_args(ckpt)
    model = _build_model(args, ckpt['state_dict'])

    with torch.no_grad():
        z = torch.randn([1, args.latent_dim]).to(device)
        x = model.sample(z, args.num_points, flexibility=args.flexibility)

    # Normalize a detached CPU copy (normalize_point_clouds works in place).
    pcs = x.detach().cpu()[:1].clone()
    return normalize_point_clouds(pcs, mode="shape_bbox")[0]
_POINT_COLOR = "rgb(238,75,43)"  # plotly expects an "rgb(r,g,b)" string
_AXIS_STYLE = dict(visible=True, backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white")


def _point_cloud_figure(points, title):
    """Build a Plotly 3-D scatter of ``points`` (columns 0, 1, 2 are x, y, z)."""
    return go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, color=_POINT_COLOR),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(title='X', **_AXIS_STYLE),
                yaxis=dict(title='Y', **_AXIS_STYLE),
                zaxis=dict(title='Z', **_AXIS_STYLE),
                aspectmode='data',  # keep true proportions instead of stretching each axis
            ),
            margin=dict(l=0, r=0, b=0, t=40),
            title=title,
        ),
    )


def generate(seed, value):
    """Gradio callback: sample a point cloud for the chosen object type and plot it.

    Args:
        seed: seed slider value. ``None`` or 0 picks a fresh random seed, as the slider's
            "Seed (0 for random)" label promises; anything else is used via ``int()``.
        value: dropdown choice, "Airplane" or "Chair"; unknown values fall back to Airplane.

    Returns:
        A ``plotly.graph_objects.Figure`` titled with the model type and the seed used.

    Raises:
        gr.Error: if the selected checkpoint is not loaded (its module-level value is None).
    """
    checkpoints = {"Airplane": ckpt_airplane, "Chair": ckpt_chair}
    if value not in checkpoints:
        print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
        value = "Airplane"
    ckpt = checkpoints[value]
    if ckpt is None:
        raise gr.Error(f"The '{value}' model checkpoint is not loaded on this server.")

    if seed is None or int(seed) == 0:
        # Use an OS-backed RNG: the module-level `random` state was reseeded by the previous
        # predict() call (via seed_all), so drawing from it would give a predictable sequence.
        current_seed = random.SystemRandom().randint(1, 2**16 - 1)
    else:
        current_seed = int(seed)

    points = predict(current_seed, ckpt).numpy()
    return _point_cloud_figure(points, f"Generated {value} (Seed: {current_seed})")
# Markdown header shown at the top of the page via gr.Markdown(markdown) in the UI below.
# It is an f-string: {device.upper()} interpolates the compute device chosen at startup
# ('cuda' or 'cpu'), so users can see where inference runs.
markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
### Future Work based on interest
- Adding new models for new type objects
- New Customization
It is running on **{device.upper()}**
---
**Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
It is not downloaded automatically by this script.
Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.
---
'''
# Gradio UI: description header, seed/model inputs, a generate button and a plot output.
# The layout is defined by nested context managers, so nesting and statement order shape the page.
# NOTE(review): confirm the Button belongs in the inputs Row (a layout-only choice).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # Initial seed 777 matches predict()'s fallback when Seed is None.
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777)
            model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane")  # initial selection
            btn = gr.Button(value="Generate Point Cloud")
        point_cloud_plot = gr.Plot()  # receives the Plotly figure returned by generate()
    # Disabled: would also run generate() on page load; as written, sampling happens only on click.
    # demo.load(generate, [seed_slider, model_dropdown], point_cloud_plot) # demo.load usually runs on page load
    # Button click -> generate(slider value, dropdown value) -> figure shown in point_cloud_plot.
    btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)
if __name__ == "__main__":
    # Script entry point: warn up front if the optional chair checkpoint is absent, then serve the UI.
    chair_ckpt_present = os.path.exists(chair_model_path)
    if not chair_ckpt_present:
        warning_lines = (
            f"WARNING: Chair model checkpoint '{chair_model_path}' not found.",
            "The 'Chair' option in the UI may not work unless this file is present.",
            f"Please download it from the original project repository and place it at '{chair_model_path}'.",
        )
        for line in warning_lines:
            print(line)
    demo.launch()