rgndgn committed
Commit a22b6d9 · verified · 1 Parent(s): efccc85
Files changed (1):
  gradio_app.py  +43 -84
gradio_app.py CHANGED
@@ -8,11 +8,11 @@ from PIL import Image
 import gradio as gr
 import trimesh
 from transparent_background import Remover
-
+from pathlib import Path
 import subprocess
+import uuid
 
 def install_cuda_toolkit():
-    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
@@ -25,24 +25,22 @@ def install_cuda_toolkit():
         os.environ["CUDA_HOME"],
         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
     )
-    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
 
 install_cuda_toolkit()
 
-# Import and setup SPAR3D
 os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
 import spar3d.utils as spar3d_utils
 from spar3d.system import SPAR3D
 
-# Constants
 COND_WIDTH = 512
 COND_HEIGHT = 512
 COND_DISTANCE = 2.2
 COND_FOVY = 0.591627
 BACKGROUND_COLOR = [0.5, 0.5, 0.5]
+OUTPUT_DIR = "output"
+os.makedirs(OUTPUT_DIR, exist_ok=True)
 
-# Initialize models
 device = spar3d_utils.get_device()
 bg_remover = Remover()
 spar3d_model = SPAR3D.from_pretrained(
@@ -51,17 +49,14 @@ spar3d_model = SPAR3D.from_pretrained(
     weight_name="model.safetensors"
 ).eval().to(device)
 
-# Initialize camera parameters
 c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
 intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
     COND_FOVY, COND_HEIGHT, COND_WIDTH
 )
 
 def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.Image:
-    """Create an RGBA image from RGB image and optional mask."""
     rgba_image = rgb_image.convert('RGBA')
     if mask is not None:
-        # Ensure mask is 2D before converting to alpha
         if len(mask.shape) > 2:
             mask = mask.squeeze()
         alpha = Image.fromarray((mask * 255).astype(np.uint8))
@@ -69,55 +64,37 @@ def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.
     return rgba_image
 
 def create_batch(input_image: Image.Image) -> dict[str, Any]:
-    """Prepare image batch for model input."""
-    # Resize and convert input image to numpy array
     resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
     img_array = np.array(resized_image).astype(np.float32) / 255.0
 
-    # Extract RGB and alpha channels
-    if img_array.shape[-1] == 4: # RGBA
+    if img_array.shape[-1] == 4:
         rgb = img_array[..., :3]
         mask = img_array[..., 3:4]
-    else: # RGB
+    else:
         rgb = img_array
         mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)
 
-    # Convert to tensors while keeping channel-last format
-    rgb = torch.from_numpy(rgb).float() # [H, W, 3]
-    mask = torch.from_numpy(mask).float() # [H, W, 1]
-
-    # Create background blend (match channel-last format)
-    bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3) # [1, 1, 3]
-
-    # Blend RGB with background using mask (all in channel-last format)
-    rgb_cond = torch.lerp(bg_tensor, rgb, mask) # [H, W, 3]
-
-    # Move channels to correct dimension and add batch dimension
-    # Important: For SPAR3D image tokenizer, we need [B, H, W, C] format
-    rgb_cond = rgb_cond.unsqueeze(0) # [1, H, W, 3]
-    mask = mask.unsqueeze(0) # [1, H, W, 1]
+    rgb = torch.from_numpy(rgb).float()
+    mask = torch.from_numpy(mask).float()
+    bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3)
+    rgb_cond = torch.lerp(bg_tensor, rgb, mask)
+    rgb_cond = rgb_cond.unsqueeze(0)
+    mask = mask.unsqueeze(0)
 
-    # Create the batch dictionary
     batch = {
-        "rgb_cond": rgb_cond, # [1, H, W, 3]
-        "mask_cond": mask, # [1, H, W, 1]
-        "c2w_cond": c2w_cond.unsqueeze(0), # [1, 4, 4]
-        "intrinsic_cond": intrinsic.unsqueeze(0), # [1, 3, 3]
-        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0), # [1, 3, 3]
+        "rgb_cond": rgb_cond,
+        "mask_cond": mask,
+        "c2w_cond": c2w_cond.unsqueeze(0),
+        "intrinsic_cond": intrinsic.unsqueeze(0),
+        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
     }
 
-    for k, v in batch.items():
-        print(f"[debug] {k} final shape:", v.shape)
-
     return batch
 
 def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
-    """Process batch through model and generate point cloud."""
-
     batch_size = batch["rgb_cond"].shape[0]
     assert batch_size == 1, f"Expected batch size 1, got {batch_size}"
 
-    # Generate point cloud tokens
     try:
         cond_tokens = system.forward_pdiff_cond(batch)
     except Exception as e:
@@ -129,7 +106,6 @@ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
         print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
         raise
 
-    # Sample points
     sample_iter = system.sampler.sample_batch_progressive(
         batch_size,
         cond_tokens,
@@ -137,38 +113,23 @@ def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
         device=device
     )
 
-    # Get final samples
     for x in sample_iter:
         samples = x["xstart"]
 
     pc_cond = samples.permute(0, 2, 1).float()
-
-    # Normalize point cloud
     pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
-
-    # Subsample to 512 points
     pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
-
     return pc_cond
 
 @spaces.GPU
 @torch.inference_mode()
-def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image | None]:
-    """Generate image from prompt and convert to 3D model."""
-
-    # Generate random seed
+def generate_and_process_3d(image: Image.Image) -> str:
     seed = np.random.randint(0, np.iinfo(np.int32).max)
 
     try:
         rgb_image = image.convert('RGB')
-
-        # bg_remover returns a PIL Image already, no need to convert
         no_bg_image = bg_remover.process(rgb_image)
-        print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")
-
-        # Convert to RGBA if not already
         rgba_image = no_bg_image.convert('RGBA')
-        print(f"[debug] rgba_image mode: {rgba_image.mode}")
 
         processed_image = spar3d_utils.foreground_crop(
             rgba_image,
@@ -177,15 +138,8 @@ def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image
             no_crop=False
         )
 
-        # Show the processed image alpha channel for debugging
-        alpha = np.array(processed_image)[:, :, 3]
-        print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")
-
-        # Prepare batch for processing
         batch = create_batch(processed_image)
         batch = {k: v.to(device) for k, v in batch.items()}
-
-        # Generate point cloud
         pc_cond = forward_model(
             batch,
             spar3d_model,
@@ -195,25 +149,24 @@ def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image
         )
         batch["pc_cond"] = pc_cond
 
-        # Generate mesh
         with torch.no_grad():
             with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
                 trimesh_mesh, _ = spar3d_model.generate_mesh(
                     batch,
-                    1024, # texture_resolution
+                    1024,
                     remesh="none",
                     vertex_count=-1,
                     estimate_illumination=True
                 )
         trimesh_mesh = trimesh_mesh[0]
 
-        # Export to GLB
-        temp_dir = tempfile.mkdtemp()
-        output_path = os.path.join(temp_dir, 'mesh.glb')
-
+        unique_id = str(uuid.uuid4())
+        filename = f'model_{unique_id}.glb'
+        output_path = os.path.join(OUTPUT_DIR, filename)
         trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
+        public_url = f"https://john6666-image-to-3d-test.hf.space/file={output_path}"
 
-        return output_path
+        return public_url
 
     except Exception as e:
         print(f"Error during generation: {str(e)}")
@@ -221,28 +174,34 @@ def generate_and_process_3d(image: Image.Image) -> tuple[str | None, Image.Image
         traceback.print_exc()
         return None
 
-# Create Gradio app using Blocks
+# Create Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("This space is based on [Stable Point-Aware 3D](https://huggingface.co/spaces/stabilityai/stable-point-aware-3d) by Stability AI, [Text to 3D](https://huggingface.co/spaces/jbilcke-hf/text-to-3d) by jbilcke-hf.")
-
     with gr.Row():
         input_img = gr.Image(
-            type="pil", label="Input Image", sources="upload", image_mode="RGBA"
-        )
-
-    with gr.Row():
-        model_output = gr.Model3D(
-            label="Generated .GLB model",
-            clear_color=[0.0, 0.0, 0.0, 0.0],
+            type="pil",
+            label=None, # Remove the label
+            show_label=False, # Further remove label
+            sources="upload",
+            image_mode="RGBA",
+            elem_id="hidden-upload" # Add an ID for CSS targeting
         )
 
-    # Event handler
+    # Make all output components invisible
+    with gr.Row(visible=False):
+        model_url = gr.Textbox(label="Model URL")
+
     input_img.upload(
        fn=generate_and_process_3d,
        inputs=[input_img],
-        outputs=[model_output],
+        outputs=[model_url],
        api_name="generate"
     )
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True,
+        ssr_mode=False,
+        allowed_paths=[Path(OUTPUT_DIR).resolve()]
+    )
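
The change replaces the visible Model3D viewer with a hidden Textbox wired to api_name="generate", so the Space is now meant to be driven programmatically: the endpoint returns the /file= URL of the exported GLB, which Gradio can serve because OUTPUT_DIR is listed in allowed_paths. A minimal client sketch follows; it assumes a recent gradio_client and that the Space id matches the host hard-coded in public_url above (neither assumption comes from this commit), so adjust both if they differ.

# Minimal sketch: call the new /generate endpoint from outside the Space.
# Assumptions (not part of this commit): gradio_client is installed and new
# enough to provide handle_file, and the Space id "john6666/image-to-3d-test"
# matches the host hard-coded in public_url above.
from gradio_client import Client, handle_file

client = Client("john6666/image-to-3d-test")
glb_url = client.predict(
    handle_file("input.png"),  # local image to convert to a mesh
    api_name="/generate",
)
print(glb_url)  # e.g. https://john6666-image-to-3d-test.hf.space/file=output/model_<uuid>.glb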