noumanjavaid commited on
Commit
ada67da
·
verified ·
1 Parent(s): 8a2c015

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -72
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  from huggingface_hub import hf_hub_download
7
  import numpy as np
8
  import random
 
9
 
10
  os.system("git clone https://github.com/luost26/diffusion-point-cloud")
11
  sys.path.append("diffusion-point-cloud")
@@ -15,21 +16,38 @@ sys.path.append("diffusion-point-cloud")
15
  from models.vae_gaussian import *
16
  from models.vae_flow import *
17
 
18
- airplane=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
19
- chair="./GEN_chair.pt"
 
 
 
 
20
 
21
- device='cuda' if torch.cuda.is_available() else 'cpu'
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- ckpt_airplane = torch.load(airplane, map_location=torch.device(device), weights_only=False)
25
- ckpt_chair = torch.load(chair,map_location=torch.device(device))
26
 
27
  def seed_all(seed):
28
  torch.manual_seed(seed)
29
  np.random.seed(seed)
30
  random.seed(seed)
31
 
32
- def normalize_point_clouds(pcs,mode):
33
  if mode is None:
34
  return pcs
35
  for i in range(pcs.size(0)):
@@ -38,100 +56,158 @@ def normalize_point_clouds(pcs,mode):
38
  shift = pc.mean(dim=0).reshape(1, 3)
39
  scale = pc.flatten().std().reshape(1, 1)
40
  elif mode == 'shape_bbox':
41
- pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
42
- pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
43
  shift = ((pc_min + pc_max) / 2).view(1, 3)
44
  scale = (pc_max - pc_min).max().reshape(1, 1) / 2
 
 
 
 
 
 
 
 
45
  pc = (pc - shift) / scale
46
  pcs[i] = pc
47
  return pcs
48
 
49
-
50
-
51
-
52
- def predict(Seed,ckpt):
53
- if Seed==None:
54
- Seed=777
55
- seed_all(Seed)
56
-
57
- if ckpt['args'].model == 'gaussian':
58
- model = GaussianVAE(ckpt['args']).to(device)
59
- elif ckpt['args'].model == 'flow':
60
- model = FlowVAE(ckpt['args']).to(device)
61
-
62
- model.load_state_dict(ckpt['state_dict'])
63
- # Generate Point Clouds
64
- gen_pcs = []
65
- with torch.no_grad():
66
- z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
67
- x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
68
- gen_pcs.append(x.detach().cpu())
69
- gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
70
- gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
71
-
72
- return gen_pcs[0]
73
 
74
- def generate(seed,value):
75
- if value=="Airplane":
76
- ckpt=ckpt_airplane
77
- elif value=="Chair":
78
- ckpt=ckpt_chair
79
- else :
80
- ckpt=ckpt_airplane
81
-
82
- colors=(238, 75, 43)
83
- points=predict(seed,ckpt)
84
- num_points=points.shape[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
 
86
 
87
  fig = go.Figure(
88
  data=[
89
  go.Scatter3d(
90
- x=points[:,0], y=points[:,1], z=points[:,2],
91
  mode='markers',
92
- marker=dict(size=1, color=colors)
93
  )
94
  ],
95
  layout=dict(
96
  scene=dict(
97
- xaxis=dict(visible=False),
98
- yaxis=dict(visible=False),
99
- zaxis=dict(visible=False)
100
- )
 
 
 
101
  )
102
  )
103
  return fig
104
-
105
- markdown=f'''
106
- # Diffusion Probabilistic Models for 3D Point Cloud Generation
107
 
108
-
109
- [The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
110
-
111
- [For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
112
 
113
- ### Future Work based on interest
114
- - Adding new models for new type objects
115
- - New Customization
116
-
117
-
118
 
119
- It is running on {device}
120
-
121
 
 
 
 
 
 
122
  '''
123
- with gr.Blocks() as demo:
124
  with gr.Column():
125
  with gr.Row():
126
  gr.Markdown(markdown)
127
  with gr.Row():
128
- seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
129
- value=gr.Dropdown(choices=["Airplane","Chair"],label="Choose Model Type")
130
- #truncate_std = gr.Slider( minimum=1, maximum=2,label='Truncate Std')
131
 
132
- btn = gr.Button(value="Generate")
133
- point_cloud = gr.Plot()
134
- demo.load(generate, [seed,value], point_cloud)
135
- btn.click(generate, [seed,value], point_cloud)
136
 
137
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
6
  from huggingface_hub import hf_hub_download
7
  import numpy as np
8
  import random
9
+ # import argparse # Not strictly needed for weights_only=False, but good practice if dealing with argparse.Namespace
10
 
11
  os.system("git clone https://github.com/luost26/diffusion-point-cloud")
12
  sys.path.append("diffusion-point-cloud")
 
16
  from models.vae_gaussian import *
17
  from models.vae_flow import *
18
 
19
+ airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
20
+ # IMPORTANT: GEN_chair.pt must be present in the root directory where this script is run.
21
+ # This script does NOT download GEN_chair.pt. You need to manually place it there.
22
+ # The original repository (https://github.com/luost26/diffusion-point-cloud)
23
+ # mentions downloading checkpoints from Google Drive.
24
+ chair_model_path = "./GEN_chair.pt"
25
 
26
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
27
 
28
+ # --- Start of PyTorch 2.6+ loading considerations ---
29
+ # Option 1: Set weights_only=False for each load (Simpler, if you trust the source)
30
+ # This is the approach being applied here as per previous interactions.
31
+
32
+ ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)
33
+ ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device), weights_only=False) # <--- FIX APPLIED HERE
34
+
35
+ # Option 2: For a more robust/secure approach with PyTorch 2.6+ (if you have many models)
36
+ # You could do this at the top, after importing torch and argparse:
37
+ # import argparse
38
+ # torch.serialization.add_safe_globals([argparse.Namespace])
39
+ # Then, the torch.load calls below would not need weights_only=False (they'd use the default weights_only=True)
40
+ # ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device))
41
+ # ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device))
42
+ # --- End of PyTorch 2.6+ loading considerations ---
43
 
 
 
44
 
45
  def seed_all(seed):
46
  torch.manual_seed(seed)
47
  np.random.seed(seed)
48
  random.seed(seed)
49
 
50
+ def normalize_point_clouds(pcs, mode):
51
  if mode is None:
52
  return pcs
53
  for i in range(pcs.size(0)):
 
56
  shift = pc.mean(dim=0).reshape(1, 3)
57
  scale = pc.flatten().std().reshape(1, 1)
58
  elif mode == 'shape_bbox':
59
+ pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
60
+ pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
61
  shift = ((pc_min + pc_max) / 2).view(1, 3)
62
  scale = (pc_max - pc_min).max().reshape(1, 1) / 2
63
+ else: # Fallback if mode is not recognized, though your code doesn't use this branch with current inputs
64
+ shift = 0
65
+ scale = 1
66
+
67
+ # Prevent division by zero or very small scale
68
+ if scale < 1e-8:
69
+ scale = torch.tensor(1.0).reshape(1,1)
70
+
71
  pc = (pc - shift) / scale
72
  pcs[i] = pc
73
  return pcs
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ def predict(Seed, ckpt):
77
+ if Seed is None:
78
+ Seed = 777
79
+ seed_all(int(Seed)) # Ensure Seed is an integer
80
+
81
+ # Ensure args is accessible, provide a default if it's missing or not a Namespace
82
+ # This is a defensive measure, as the error was about loading argparse.Namespace
83
+ if not hasattr(ckpt, 'args') or not hasattr(ckpt['args'], 'model'):
84
+ # This case should ideally not happen if the checkpoint is valid
85
+ # but if it does, we need a fallback or error.
86
+ # For now, let's assume 'args' and 'args.model' exist based on the error.
87
+ print("Warning: Checkpoint 'args' or 'args.model' not found. Assuming 'gaussian'.")
88
+ model_type = 'gaussian'
89
+ latent_dim = ckpt.get('latent_dim', 128) # A common default
90
+ flexibility = ckpt.get('flexibility', 0.0) # A common default
91
+ else:
92
+ model_type = ckpt['args'].model
93
+ latent_dim = ckpt['args'].latent_dim
94
+ flexibility = ckpt['args'].flexibility
95
+
96
+
97
+ if model_type == 'gaussian':
98
+ # Pass necessary args to the constructor
99
+ # We need to mock an args object if ckpt['args'] wasn't a full argparse.Namespace
100
+ # or if some attributes are missing.
101
+ mock_args = type('Args', (), {'latent_dim': latent_dim, 'hyper': getattr(ckpt.get('args', {}), 'hyper', None)})() # Add other required args
102
+ model = GaussianVAE(mock_args).to(device)
103
+ elif model_type == 'flow':
104
+ mock_args = type('Args', (), {
105
+ 'latent_dim': latent_dim,
106
+ 'flow_depth': getattr(ckpt.get('args', {}), 'flow_depth', 10), # Example default
107
+ 'flow_hidden_dim': getattr(ckpt.get('args', {}), 'flow_hidden_dim', 256), # Example default
108
+ 'hyper': getattr(ckpt.get('args', {}), 'hyper', None)
109
+ })()
110
+ model = FlowVAE(mock_args).to(device)
111
+ else:
112
+ raise ValueError(f"Unknown model type: {model_type}")
113
+
114
+ model.load_state_dict(ckpt['state_dict'])
115
+ model.eval() # Set model to evaluation mode
116
+
117
+ # Generate Point Clouds
118
+ gen_pcs = []
119
+ with torch.no_grad():
120
+ z = torch.randn([1, latent_dim]).to(device)
121
+ # The sample method might also depend on args from the checkpoint
122
+ num_points_to_generate = getattr(ckpt.get('args', {}), 'num_points', 2048) # Default to 2048 if not in args
123
+ x = model.sample(z, num_points_to_generate, flexibility=flexibility)
124
+ gen_pcs.append(x.detach().cpu())
125
+
126
+ gen_pcs_tensor = torch.cat(gen_pcs, dim=0)[:1] # Ensure we take only one point cloud
127
+ gen_pcs_normalized = normalize_point_clouds(gen_pcs_tensor.clone(), mode="shape_bbox") # Use .clone() if normalize_point_clouds modifies inplace
128
+
129
+ return gen_pcs_normalized[0]
130
+
131
+ def generate(seed, value):
132
+ if value == "Airplane":
133
+ ckpt = ckpt_airplane
134
+ elif value == "Chair":
135
+ ckpt = ckpt_chair
136
+ else:
137
+ # Default case or handle error
138
+ # For now, defaulting to airplane if 'value' is unexpected
139
+ print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
140
+ ckpt = ckpt_airplane
141
+
142
+ colors = (238, 75, 43) # RGB tuple for plotly
143
+
144
+ # Ensure seed is not None and is an int for the predict function
145
+ current_seed = seed
146
+ if current_seed is None:
147
+ current_seed = random.randint(0, 2**16 -1) # Generate a random seed if None
148
+ current_seed = int(current_seed)
149
 
150
+ points = predict(current_seed, ckpt)
151
+ # num_points = points.shape[0] # Not used directly in fig
152
 
153
  fig = go.Figure(
154
  data=[
155
  go.Scatter3d(
156
+ x=points[:, 0], y=points[:, 1], z=points[:, 2],
157
  mode='markers',
158
+ marker=dict(size=2, color=f'rgb({colors[0]},{colors[1]},{colors[2]})') # plotly expects rgb string
159
  )
160
  ],
161
  layout=dict(
162
  scene=dict(
163
+ xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
164
+ yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
165
+ zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
166
+ aspectmode='data' # Ensures proportional axes
167
+ ),
168
+ margin=dict(l=0, r=0, b=0, t=40), # Adjust margins
169
+ title=f"Generated {value} (Seed: {current_seed})"
170
  )
171
  )
172
  return fig
 
 
 
173
 
174
+ markdown = f'''
175
+ # Diffusion Probabilistic Models for 3D Point Cloud Generation
176
+
177
+ [The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
178
 
179
+ [For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
180
+ ### Future Work based on interest
181
+ - Adding new models for new type objects
182
+ - New Customization
 
183
 
184
+ It is running on **{device.upper()}**
 
185
 
186
+ ---
187
+ **Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
188
+ It is not downloaded automatically by this script.
189
+ Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.
190
+ ---
191
  '''
192
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
193
  with gr.Column():
194
  with gr.Row():
195
  gr.Markdown(markdown)
196
  with gr.Row():
197
+ seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777) # Set initial value
198
+ model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane") # Set initial value
 
199
 
200
+ btn = gr.Button(value="Generate Point Cloud")
201
+ point_cloud_plot = gr.Plot() # Changed variable name for clarity
 
 
202
 
203
+ # demo.load(generate, [seed_slider, model_dropdown], point_cloud_plot) # demo.load usually runs on page load
204
+ btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)
205
+
206
+ if __name__ == "__main__":
207
+ # Ensure GEN_chair.pt exists if Chair model might be selected
208
+ if not os.path.exists(chair_model_path):
209
+ print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found.")
210
+ print(f"The 'Chair' option in the UI may not work unless this file is present.")
211
+ print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")
212
+
213
+ demo.launch()