anvilinteractiv committed on
Commit c484979 · verified · 1 Parent(s): 15f252d

Update app.py

Files changed (1): app.py +300 -168
app.py CHANGED
@@ -3,10 +3,9 @@ import os
 import gradio as gr
 import numpy as np
 import torch
-from torch.cuda.amp import autocast
+from PIL import Image
 import trimesh
 import random
-from PIL import Image
 from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from huggingface_hub import hf_hub_download, snapshot_download
@@ -14,9 +13,6 @@ import subprocess
 import shutil
 import base64
 import logging
-import time
-import traceback
-import requests
 
 # Set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -26,7 +22,7 @@ logger = logging.getLogger(__name__)
 try:
     subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
 except Exception as e:
-    logger.error(f"Failed to install spandrel: {str(e)}\n{traceback.format_exc()}")
+    logger.error(f"Failed to install spandrel: {str(e)}")
     raise
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -34,7 +30,7 @@ DTYPE = torch.float16
 
 logger.info(f"Using device: {DEVICE}")
 
-DEFAULT_FACE_NUMBER = 20000 # Reduced for memory efficiency
+DEFAULT_FACE_NUMBER = 100000
 MAX_SEED = np.iinfo(np.int32).max
 TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
 MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"
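
MAX_SEED caps randomized seeds at the int32 limit. A sketch of the randomize-seed pattern that get_random_seed (further down) appears to implement; most of that function's body sits outside the hunks shown, so this is an assumption:

    import random
    import numpy as np

    MAX_SEED = np.iinfo(np.int32).max

    def pick_seed(randomize_seed: bool, seed: int) -> int:
        # Assumed to mirror get_random_seed: draw a fresh seed on request.
        return random.randint(0, MAX_SEED) if randomize_seed else seed

    print(pick_seed(True, 0), pick_seed(False, 42))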
@@ -62,20 +58,22 @@ sys.path.append(MV_ADAPTER_CODE_DIR)
 sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
 
 try:
+    # triposg
     from image_process import prepare_image
     from briarmbg import BriaRMBG
     snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
-    rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE, dtype=DTYPE)
+    rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
     rmbg_net.eval()
     from triposg.pipelines.pipeline_triposg import TripoSGPipeline
     snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
-    triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, dtype=DTYPE)
+    triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
 except Exception as e:
-    logger.error(f"Failed to load TripoSG models: {str(e)}\n{traceback.format_exc()}")
+    logger.error(f"Failed to load TripoSG models: {str(e)}")
     raise
 
 try:
-    NUM_VIEWS = 4 # Reduced for memory efficiency
+    # mv adapter
+    NUM_VIEWS = 6
     from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
     from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
     from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
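
The hunk above drops the fp16 cast for the RMBG net (.to(DEVICE)) while keeping it for the TripoSG pipeline (.to(DEVICE, DTYPE)). A minimal sketch of the torch semantics involved, using a stand-in nn.Linear rather than the real models:

    import torch

    # Module.to() accepts a device, a dtype, or both.
    fp32_net = torch.nn.Linear(4, 4).to("cpu")                 # move only; weights stay float32
    fp16_net = torch.nn.Linear(4, 4).to("cpu", torch.float16)  # move and cast
    print(fp32_net.weight.dtype, fp16_net.weight.dtype)        # torch.float32 torch.float16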
@@ -92,17 +90,17 @@ try:
     )
     birefnet = AutoModelForImageSegmentation.from_pretrained(
         "ZhengPeng7/BiRefNet", trust_remote_code=True
-    ).to(DEVICE, dtype=DTYPE)
+    ).to(DEVICE)
     transform_image = transforms.Compose(
         [
-            transforms.Resize((512, 512)), # Reduced resolution
+            transforms.Resize((1024, 1024)),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ]
     )
     remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
 except Exception as e:
-    logger.error(f"Failed to load MV-Adapter models: {str(e)}\n{traceback.format_exc()}")
+    logger.error(f"Failed to load MV-Adapter models: {str(e)}")
     raise
 
 try:
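
For reference, a self-contained sketch of the 1024x1024 ImageNet-normalized preprocessing this hunk switches to; the blank image is a stand-in input:

    from PIL import Image
    from torchvision import transforms

    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),                 # HWC uint8 -> CHW float in [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img = Image.new("RGB", (800, 600), "white")
    batch = transform_image(img).unsqueeze(0)  # shape: (1, 3, 1024, 1024)
    print(batch.shape)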
@@ -111,201 +109,139 @@ try:
     if not os.path.exists("checkpoints/big-lama.pt"):
         subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
 except Exception as e:
-    logger.error(f"Failed to download checkpoints: {str(e)}\n{traceback.format_exc()}")
+    logger.error(f"Failed to download checkpoints: {str(e)}")
     raise
 
-def log_gpu_memory():
-    if torch.cuda.is_available():
-        allocated = torch.cuda.memory_allocated() / 1024**3
-        reserved = torch.cuda.memory_reserved() / 1024**3
-        logger.info(f"GPU Memory: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
-
 def get_random_hex():
     random_bytes = os.urandom(8)
     random_hex = random_bytes.hex()
     return random_hex
 
-def retry_on_failure(func, max_attempts=3, delay=1):
-    for attempt in range(max_attempts):
-        try:
-            return func()
-        except RuntimeError as e:
-            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}\n{traceback.format_exc()}")
-            if attempt == max_attempts - 1:
-                raise
-            time.sleep(delay)
-
-@spaces.GPU(duration=2)
-@torch.no_grad()
-def run_segmentation(image):
+@spaces.GPU(duration=5)
+def run_full(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, req=None):
     try:
-        log_gpu_memory()
-        if isinstance(image, dict):
-            image_path = image.get("path") or image.get("url")
-            if not image_path:
-                raise ValueError("Invalid image input: no path or URL provided")
-            if image_path.startswith("http"):
-                temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
-                image_path = download_image(image_path, temp_image_path)
-        elif isinstance(image, str) and image.startswith("http"):
-            temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
-            image_path = download_image(image, temp_image_path)
-        else:
-            image_path = image
-            if not isinstance(image, (str, bytes)) or (isinstance(image, str) and not os.path.exists(image)):
-                raise ValueError(f"Expected str (path/URL), bytes, or FileData dict, got {type(image)}")
-
-        with autocast():
-            image_seg = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
-        rmbg_net.to("cpu")
-        torch.cuda.empty_cache()
-        log_gpu_memory()
-        return image_seg
-    except Exception as e:
-        logger.error(f"Error in run_segmentation: {str(e)}\n{traceback.format_exc()}")
-        raise
-
-@spaces.GPU(duration=3)
-@torch.no_grad()
-def image_to_3d(image, seed, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
-    try:
-        log_gpu_memory()
-        triposg_pipe.to(DEVICE, dtype=DTYPE)
-        with autocast():
-            outputs = triposg_pipe(
-                image=image,
-                generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
-                num_inference_steps=num_inference_steps,
-                guidance_scale=guidance_scale
-            ).samples[0]
+        image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
+
+        outputs = triposg_pipe(
+            image=image_seg,
+            generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale
+        ).samples[0]
+        logger.info("Mesh extraction done")
         mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
+
         if simplify:
+            logger.info("Starting mesh simplification")
             from utils import simplify_mesh
             mesh = simplify_mesh(mesh, target_face_num)
-        save_dir = os.path.join(TMP_DIR, str(req.session_hash) if req else "examples")
+
+        save_dir = os.path.join(TMP_DIR, "examples")
         os.makedirs(save_dir, exist_ok=True)
         mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
         mesh.export(mesh_path)
-        triposg_pipe.to("cpu")
+        logger.info(f"Saved mesh to {mesh_path}")
+
         torch.cuda.empty_cache()
-        log_gpu_memory()
-        return mesh_path
-    except Exception as e:
-        logger.error(f"Error in image_to_3d: {str(e)}\n{traceback.format_exc()}")
-        raise
 
-@spaces.GPU(duration=3)
-@torch.no_grad()
-def run_texture(image, mesh_path, seed, req=None):
-    try:
-        log_gpu_memory()
-        height, width = 512, 512
+        height, width = 768, 768
         cameras = get_orthogonal_camera(
-            elevation_deg=[0, 0, 0, 89.99],
+            elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
             distance=[1.8] * NUM_VIEWS,
             left=-0.55,
             right=0.55,
             bottom=-0.55,
             top=0.55,
-            azimuth_deg=[x - 90 for x in [0, 90, 180, 180]],
+            azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
             device=DEVICE,
         )
         ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
+
         mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
-        with autocast():
-            render_out = render(
-                ctx,
-                mesh,
-                cameras,
-                height=height,
-                width=width,
-                render_attr=False,
-                normal_background=0.0,
-            )
+        render_out = render(
+            ctx,
+            mesh,
+            cameras,
+            height=height,
+            width=width,
+            render_attr=False,
+            normal_background=0.0,
+        )
         control_images = (
             torch.cat(
-                [(render_out.pos + 0.5).clamp(0, 1), (render_out.normal / 2 + 0.5).clamp(0, 1)],
+                [
+                    (render_out.pos + 0.5).clamp(0, 1),
+                    (render_out.normal / 2 + 0.5).clamp(0, 1),
+                ],
                 dim=-1,
             )
             .permute(0, 3, 1, 2)
             .to(DEVICE)
         )
-        del render_out
+
        image = Image.open(image)
-        birefnet.to(DEVICE, dtype=DTYPE)
-        with autocast():
-            image = remove_bg_fn(image)
-        birefnet.to("cpu")
+        image = remove_bg_fn(image)
         image = preprocess_image(image, height, width)
-        pipe_kwargs = {"generator": torch.Generator(device=DEVICE).manual_seed(seed)} if seed != -1 else {}
-        mv_adapter_pipe.to(DEVICE, dtype=DTYPE)
-        with autocast():
-            images = mv_adapter_pipe(
-                "high quality",
-                height=height,
-                width=width,
-                num_inference_steps=10,
-                guidance_scale=3.0,
-                num_images_per_prompt=NUM_VIEWS,
-                control_image=control_images,
-                control_conditioning_scale=1.0,
-                reference_image=image,
-                reference_conditioning_scale=1.0,
-                negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
-                cross_attention_kwargs={"scale": 1.0},
-                **pipe_kwargs,
-            ).images
-        mv_adapter_pipe.to("cpu")
-        del control_images
-        save_dir = os.path.join(TMP_DIR, str(req.session_hash) if req else "examples")
+
+        pipe_kwargs = {}
+        if seed != -1 and isinstance(seed, int):
+            pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
+
+        images = mv_adapter_pipe(
+            "high quality",
+            height=height,
+            width=width,
+            num_inference_steps=15,
+            guidance_scale=3.0,
+            num_images_per_prompt=NUM_VIEWS,
+            control_image=control_images,
+            control_conditioning_scale=1.0,
+            reference_image=image,
+            reference_conditioning_scale=1.0,
+            negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
+            cross_attention_kwargs={"scale": 1.0},
+            **pipe_kwargs,
+        ).images
+
+        torch.cuda.empty_cache()
         os.makedirs(save_dir, exist_ok=True)
         mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
         make_image_grid(images, rows=1).save(mv_image_path)
+
         from texture import TexturePipeline, ModProcessConfig
         texture_pipe = TexturePipeline(
             upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
             inpaint_ckpt_path="checkpoints/big-lama.pt",
             device=DEVICE,
         )
+
         textured_glb_path = texture_pipe(
             mesh_path=mesh_path,
             save_dir=save_dir,
             save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
             uv_unwarp=True,
-            uv_size=2048,
+            uv_size=4096,
             rgb_path=mv_image_path,
             rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
-            camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 180]],
+            camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
         )
-        torch.cuda.empty_cache()
-        log_gpu_memory()
-        return textured_glb_path
-    except Exception as e:
-        logger.error(f"Error in run_texture: {str(e)}\n{traceback.format_exc()}")
-        raise
 
-@spaces.GPU(duration=3)
-@torch.no_grad()
-def run_full(image, seed=0, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
-    try:
-        log_gpu_memory()
-        image_seg = run_segmentation(image)
-        mesh_path = image_to_3d(image_seg, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req)
-        textured_glb_path = run_texture(image, mesh_path, seed, req)
         return image_seg, mesh_path, textured_glb_path
     except Exception as e:
-        logger.error(f"Error in run_full: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Error in run_full: {str(e)}")
         raise
 
-def gradio_generate(image, seed=0, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER):
+def gradio_generate(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER):
     try:
         logger.info("Starting gradio_generate")
+        # Verify API key
         api_key = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
         request = gr.Request()
         if not request.headers.get("x-api-key") == api_key:
             logger.error("Invalid API key")
             raise ValueError("Invalid API key")
 
+        # Handle base64 image or file path
         if image.startswith("data:image"):
             logger.info("Processing base64 image")
             base64_string = image.split(",")[1]
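
The rewritten run_full above goes from four views to six: four side views 90 degrees apart plus near-vertical top and bottom views, all shifted by the -90 offset used in the call. A small standalone sketch of what the two lists enumerate:

    NUM_VIEWS = 6
    elevation_deg = [0, 0, 0, 0, 89.99, -89.99]
    azimuth_deg = [x - 90 for x in [0, 90, 180, 270, 180, 180]]
    for i, (elev, azim) in enumerate(zip(elevation_deg, azimuth_deg)):
        print(f"view {i}: elevation {elev:+7.2f} deg, azimuth {azim:+7.2f} deg")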
@@ -319,12 +255,12 @@ def gradio_generate(image, seed=0, num_inference_steps=30, guidance_scale=7.0, s
             logger.error(f"Image file not found: {temp_image_path}")
             raise ValueError("Invalid or missing image file")
 
-        image_seg, mesh_path, textured_glb_path = run_full(temp_image_path, seed, num_inference_steps, guidance_scale, simplify, target_face_num, request)
+        image_seg, mesh_path, textured_glb_path = run_full(temp_image_path, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req=None)
         session_hash = os.path.basename(os.path.dirname(textured_glb_path))
         logger.info(f"Generated model at /files/{session_hash}/{os.path.basename(textured_glb_path)}")
         return {"file_url": f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"}
     except Exception as e:
-        logger.error(f"Error in gradio_generate: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Error in gradio_generate: {str(e)}")
         raise
 
 def start_session(req: gr.Request):
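
The data-URL branch in gradio_generate above boils down to the following; the byte string here is a stand-in payload, not a valid image:

    import base64

    data_url = "data:image/png;base64," + base64.b64encode(b"stand-in bytes").decode()
    payload = data_url.split(",")[1]   # strip the "data:image/...;base64," prefix
    with open("input_example.png", "wb") as f:
        f.write(base64.b64decode(payload))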
@@ -333,7 +269,7 @@ def start_session(req: gr.Request):
         os.makedirs(save_dir, exist_ok=True)
         logger.info(f"Started session, created directory: {save_dir}")
     except Exception as e:
-        logger.error(f"Error in start_session: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Error in start_session: {str(e)}")
         raise
 
 def end_session(req: gr.Request):
@@ -342,7 +278,7 @@ def end_session(req: gr.Request):
         shutil.rmtree(save_dir)
         logger.info(f"Ended session, removed directory: {save_dir}")
     except Exception as e:
-        logger.error(f"Error in end_session: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Error in end_session: {str(e)}")
         raise
 
 def get_random_seed(randomize_seed, seed):
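
start_session and end_session above pair a per-session scratch directory with the Gradio session hash. A compact sketch of that lifecycle; TMP_DIR and the hash value are stand-ins:

    import os
    import shutil

    TMP_DIR = "/tmp/polygenix"     # stand-in for the app's TMP_DIR
    session_hash = "abc123"        # stand-in for req.session_hash

    save_dir = os.path.join(TMP_DIR, session_hash)
    os.makedirs(save_dir, exist_ok=True)   # start_session
    # ... generated files are written here during the session ...
    shutil.rmtree(save_dir)                # end_session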
@@ -352,10 +288,12 @@ def get_random_seed(randomize_seed, seed):
         logger.info(f"Generated seed: {seed}")
         return seed
     except Exception as e:
-        logger.error(f"Error in get_random_seed: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Error in get_random_seed: {str(e)}")
         raise
 
+
 def download_image(url: str, save_path: str) -> str:
+    """Download an image from a URL and save it locally."""
     try:
         logger.info(f"Downloading image from {url}")
         response = requests.get(url, stream=True)
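
A hedged variant of download_image with a timeout and explicit HTTP status checking; raise_for_status() and iter_content() are standard requests APIs. Note that the import hunk above removes the import requests line this helper still relies on:

    import requests

    def download_image_safe(url: str, save_path: str, timeout: float = 30.0) -> str:
        response = requests.get(url, stream=True, timeout=timeout)
        response.raise_for_status()    # surface HTTP errors instead of saving an error page
        with open(save_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return save_path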
@@ -366,21 +304,216 @@ def download_image(url: str, save_path: str) -> str:
         logger.info(f"Saved image to {save_path}")
         return save_path
     except Exception as e:
-        logger.error(f"Failed to download image from {url}: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Failed to download image from {url}: {str(e)}")
         raise
 
-@spaces.GPU(duration=3)
-@torch.no_grad()
-def run_full_api(image, seed=0, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
+@spaces.GPU()
+@torch.no_grad()
+def run_segmentation(image):
+    try:
+        logger.info("Running segmentation")
+        # Handle FileData dict or URL
+        if isinstance(image, dict):
+            image_path = image.get("path") or image.get("url")
+            if not image_path:
+                logger.error("Invalid image input: no path or URL provided")
+                raise ValueError("Invalid image input: no path or URL provided")
+            if image_path.startswith("http"):
+                temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
+                image_path = download_image(image_path, temp_image_path)
+        elif isinstance(image, str) and image.startswith("http"):
+            temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
+            image_path = download_image(image, temp_image_path)
+        else:
+            image_path = image
+            if not isinstance(image, (str, bytes)) or (isinstance(image, str) and not os.path.exists(image)):
+                logger.error(f"Invalid image type or path: {type(image)}")
+                raise ValueError(f"Expected str (path/URL), bytes, or FileData dict, got {type(image)}")
+
+        image = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
+        logger.info("Segmentation complete")
+        return image
+    except Exception as e:
+        logger.error(f"Error in run_segmentation: {str(e)}")
+        raise
+
+@spaces.GPU(duration=5)
+@torch.no_grad()
+def image_to_3d(
+    image,  # Changed to accept FileData dict or PIL Image
+    seed: int,
+    num_inference_steps: int,
+    guidance_scale: float,
+    simplify: bool,
+    target_face_num: int,
+    req: gr.Request
+):
+    try:
+        logger.info("Running image_to_3d")
+        # Handle FileData dict from gradio_client
+        if isinstance(image, dict):
+            image_path = image.get("path") or image.get("url")
+            if not image_path:
+                logger.error("Invalid image input: no path or URL provided")
+                raise ValueError("Invalid image input: no path or URL provided")
+            image = Image.open(image_path)
+        elif not isinstance(image, Image.Image):
+            logger.error(f"Invalid image type: {type(image)}")
+            raise ValueError(f"Expected PIL Image or FileData dict, got {type(image)}")
+
+        outputs = triposg_pipe(
+            image=image,
+            generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale
+        ).samples[0]
+        logger.info("Mesh extraction done")
+        mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
+
+        if simplify:
+            logger.info("Starting mesh simplification")
+            try:
+                from utils import simplify_mesh
+                mesh = simplify_mesh(mesh, target_face_num)
+            except ImportError as e:
+                logger.error(f"Failed to import simplify_mesh: {str(e)}")
+                raise
+
+        save_dir = os.path.join(TMP_DIR, str(req.session_hash))
+        os.makedirs(save_dir, exist_ok=True)
+        mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
+        mesh.export(mesh_path)
+        logger.info(f"Saved mesh to {mesh_path}")
+
+        torch.cuda.empty_cache()
+        return mesh_path
+    except Exception as e:
+        logger.error(f"Error in image_to_3d: {str(e)}")
+        raise
+
+@spaces.GPU(duration=5)
+@torch.no_grad()
+def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
+    try:
+        logger.info("Running texture generation")
+        height, width = 768, 768
+        cameras = get_orthogonal_camera(
+            elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
+            distance=[1.8] * NUM_VIEWS,
+            left=-0.55,
+            right=0.55,
+            bottom=-0.55,
+            top=0.55,
+            azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
+            device=DEVICE,
+        )
+        ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
+
+        mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
+        render_out = render(
+            ctx,
+            mesh,
+            cameras,
+            height=height,
+            width=width,
+            render_attr=False,
+            normal_background=0.0,
+        )
+        control_images = (
+            torch.cat(
+                [
+                    (render_out.pos + 0.5).clamp(0, 1),
+                    (render_out.normal / 2 + 0.5).clamp(0, 1),
+                ],
+                dim=-1,
+            )
+            .permute(0, 3, 1, 2)
+            .to(DEVICE)
+        )
+
+        image = Image.open(image)
+        image = remove_bg_fn(image)
+        image = preprocess_image(image, height, width)
+
+        pipe_kwargs = {}
+        if seed != -1 and isinstance(seed, int):
+            pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
+
+        images = mv_adapter_pipe(
+            "high quality",
+            height=height,
+            width=width,
+            num_inference_steps=15,
+            guidance_scale=3.0,
+            num_images_per_prompt=NUM_VIEWS,
+            control_image=control_images,
+            control_conditioning_scale=1.0,
+            reference_image=image,
+            reference_conditioning_scale=1.0,
+            negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
+            cross_attention_kwargs={"scale": 1.0},
+            **pipe_kwargs,
+        ).images
+
+        torch.cuda.empty_cache()
+        save_dir = os.path.join(TMP_DIR, str(req.session_hash))
+        os.makedirs(save_dir, exist_ok=True)
+        mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
+        make_image_grid(images, rows=1).save(mv_image_path)
+
+        from texture import TexturePipeline, ModProcessConfig
+        texture_pipe = TexturePipeline(
+            upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
+            inpaint_ckpt_path="checkpoints/big-lama.pt",
+            device=DEVICE,
+        )
+
+        textured_glb_path = texture_pipe(
+            mesh_path=mesh_path,
+            save_dir=save_dir,
+            save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
+            uv_unwarp=True,
+            uv_size=4096,
+            rgb_path=mv_image_path,
+            rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
+            camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
+        )
+
+        logger.info(f"Textured model saved to {textured_glb_path}")
+        return textured_glb_path
+    except Exception as e:
+        logger.error(f"Error in run_texture: {str(e)}")
+        raise
+
+@spaces.GPU(duration=5)
+@torch.no_grad()
+def run_full_api(image, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, req: gr.Request = None):
     try:
         logger.info("Running run_full_api")
-        def execute():
-            image_seg, mesh_path, textured_glb_path = run_full(image, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req)
-            session_hash = os.path.basename(os.path.dirname(textured_glb_path))
-            return {"file_url": f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"}
-        return retry_on_failure(execute)
+        # Handle FileData dict or URL
+        if isinstance(image, dict):
+            image_path = image.get("path") or image.get("url")
+            if not image_path:
+                logger.error("Invalid image input: no path or URL provided")
+                raise ValueError("Invalid image input: no path or URL provided")
+            if image_path.startswith("http"):
+                temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
+                image_path = download_image(image_path, temp_image_path)
+        elif isinstance(image, str) and image.startswith("http"):
+            temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
+            image_path = download_image(image, temp_image_path)
+        else:
+            image_path = image
+            if not isinstance(image, str) or not os.path.exists(image_path):
+                logger.error(f"Invalid image path: {image_path}")
+                raise ValueError(f"Invalid image path: {image_path}")
+
+        image_seg, mesh_path, textured_glb_path = run_full(image_path, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req)
+        session_hash = os.path.basename(os.path.dirname(textured_glb_path))
+        logger.info(f"Generated textured model at /files/{session_hash}/{os.path.basename(textured_glb_path)}")
+        return {"file_url": f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"}
     except Exception as e:
-        logger.error(f"Error in run_full_api: {str(e)}\n{traceback.format_exc()}")
+        logger.error(f"Error in run_full_api: {str(e)}")
         raise
 
 # Define Gradio API endpoint
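
A sketch of calling the resulting endpoint from Python with gradio_client; the Space URL and the api_name are assumptions, since the gr.Interface(...) call itself sits outside the hunks shown:

    from gradio_client import Client

    client = Client("http://localhost:7860")   # assumed URL of the running app
    result = client.predict(
        "input.png",   # image (filepath)
        0,             # seed
        50,            # inference steps
        7.5,           # guidance scale
        True,          # simplify mesh
        100000,        # target face number
        api_name="/predict",  # assumed; depends on how the Interface is registered
    )
    print(result)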
@@ -391,8 +524,8 @@ try:
         inputs=[
             gr.Image(type="filepath", label="Image"),
             gr.Number(label="Seed", value=0, precision=0),
-            gr.Number(label="Inference Steps", value=30, precision=0),
-            gr.Number(label="Guidance Scale", value=7.0),
+            gr.Number(label="Inference Steps", value=50, precision=0),
+            gr.Number(label="Guidance Scale", value=7.5),
             gr.Checkbox(label="Simplify Mesh", value=True),
             gr.Number(label="Target Face Number", value=DEFAULT_FACE_NUMBER, precision=0)
         ],
@@ -401,7 +534,7 @@
     )
     logger.info("Gradio API interface initialized successfully")
 except Exception as e:
-    logger.error(f"Failed to initialize Gradio API interface: {str(e)}\n{traceback.format_exc()}")
+    logger.error(f"Failed to initialize Gradio API interface: {str(e)}")
     raise
 
 HEADER = """
@@ -487,6 +620,7 @@ HEADER = """
 </style>
 """
 
+# Gradio web interface
 try:
     logger.info("Initializing Gradio Blocks interface")
     with gr.Blocks(title="PolyGenixAI", css="body { background-color: #1A1A1A; } .gr-panel { background-color: #2D2D2D; }") as demo:
@@ -519,7 +653,7 @@ try:
                     minimum=8,
                     maximum=50,
                     step=1,
-                    value=30,
+                    value=50,
                     info="Higher steps enhance detail but increase processing time",
                     elem_classes="gr-slider"
                 )
@@ -534,7 +668,7 @@
                 )
                 reduce_face = gr.Checkbox(label="Simplify Mesh", value=True)
                 target_face_num = gr.Slider(
-                    maximum=100000,
+                    maximum=1000000,
                     minimum=10000,
                     value=DEFAULT_FACE_NUMBER,
                     label="Target Face Number",
@@ -554,7 +688,7 @@
                     f"{TRIPOSG_CODE_DIR}/assets/example_data/{image}"
                     for image in os.listdir(f"{TRIPOSG_CODE_DIR}/assets/example_data")
                 ],
-                fn=run_full_api,
+                fn=run_full,
                 inputs=[image_prompts],
                 outputs=[seg_image, model_output, textured_model_output],
                 cache_examples=True,
@@ -579,9 +713,7 @@ try:
                     target_face_num
                 ],
                 outputs=[model_output]
-            ).then(
-                lambda: gr.Button(interactive=True), outputs=[gen_texture_button]
-            )
+            ).then(lambda: gr.Button(interactive=True), outputs=[gen_texture_button])
             gen_texture_button.click(
                 run_texture,
                 inputs=[image_prompts, model_output, seed],
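
The .then() chain collapsed to one line above follows Gradio's standard event-chaining pattern: the second callback fires after the first completes. A minimal self-contained sketch:

    import gradio as gr

    with gr.Blocks() as demo:
        run_btn = gr.Button("Generate")
        texture_btn = gr.Button("Texture", interactive=False)
        out = gr.Textbox()
        # Re-enable the follow-up button once generation finishes.
        run_btn.click(lambda: "done", outputs=[out]).then(
            lambda: gr.Button(interactive=True), outputs=[texture_btn]
        )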
@@ -591,7 +723,7 @@
         demo.unload(end_session)
     logger.info("Gradio Blocks interface initialized successfully")
 except Exception as e:
-    logger.error(f"Failed to initialize Gradio Blocks interface: {str(e)}\n{traceback.format_exc()}")
+    logger.error(f"Failed to initialize Gradio Blocks interface: {str(e)}")
     raise
 
 if __name__ == "__main__":
  if __name__ == "__main__":
@@ -600,5 +732,5 @@ if __name__ == "__main__":
600
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
601
  logger.info("Gradio application launched successfully")
602
  except Exception as e:
603
- logger.error(f"Failed to launch Gradio application: {str(e)}\n{traceback.format_exc()}")
604
  raise
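
Since gradio_generate checks an x-api-key header, a caller has to supply it explicitly. A hedged sketch using the default key from the code above; the REST route (/call/{api_name} in recent Gradio releases) and the api_name are assumptions:

    import requests

    resp = requests.post(
        "http://localhost:7860/call/predict",  # assumed route and api_name
        headers={"x-api-key": "your-secret-api-key"},
        json={"data": ["input.png", 0, 50, 7.5, True, 100000]},
        timeout=60,
    )
    print(resp.status_code, resp.text)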
 