anvilinteractiv committed on
Commit
8bbeb3a
·
verified ·
1 Parent(s): 9ea22df

Update app.py

Browse files
Files changed (1)
  1. app.py +399 -598
app.py CHANGED
@@ -3,19 +3,20 @@ import os
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
- from PIL import Image
7
  import trimesh
8
  import random
 
9
  from transformers import AutoModelForImageSegmentation
10
  from torchvision import transforms
11
- from huggingface_hub import hf_hub_download, snapshot_download, HfApi
12
  import subprocess
13
  import shutil
14
  import base64
15
  import logging
16
- import requests
17
- from functools import wraps
18
  import time
 
 
19
 
20
  # Set up logging
21
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -25,18 +26,15 @@ logger = logging.getLogger(__name__)
25
  try:
26
  subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
27
  except Exception as e:
28
- logger.error(f"Failed to install spandrel: {str(e)}")
29
  raise
30
 
31
- # Check if running in ZeroGPU environment
32
- IS_ZEROGPU = os.getenv("HF_ZERO_SPACE", "0") == "1"
33
-
34
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
  DTYPE = torch.float16
36
 
37
- logger.info(f"Using device: {DEVICE}, ZeroGPU: {IS_ZEROGPU}")
38
 
39
- DEFAULT_FACE_NUMBER = 50000 # Reduced for L4 and ZeroGPU
40
  MAX_SEED = np.iinfo(np.int32).max
41
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
42
  MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"
@@ -64,22 +62,20 @@ sys.path.append(MV_ADAPTER_CODE_DIR)
64
  sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
65
 
66
  try:
67
- # triposg
68
  from image_process import prepare_image
69
  from briarmbg import BriaRMBG
70
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
71
- rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
72
  rmbg_net.eval()
73
  from triposg.pipelines.pipeline_triposg import TripoSGPipeline
74
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
75
- triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
76
  except Exception as e:
77
- logger.error(f"Failed to load TripoSG models: {str(e)}")
78
  raise
79
 
80
  try:
81
- # mv adapter
82
- NUM_VIEWS = 6
83
  from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
84
  from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
85
  from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
@@ -96,17 +92,17 @@ try:
96
  )
97
  birefnet = AutoModelForImageSegmentation.from_pretrained(
98
  "ZhengPeng7/BiRefNet", trust_remote_code=True
99
- ).to(DEVICE)
100
  transform_image = transforms.Compose(
101
  [
102
- transforms.Resize((512, 512)), # Reduced for L4 and ZeroGPU
103
  transforms.ToTensor(),
104
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
105
  ]
106
  )
107
  remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
108
  except Exception as e:
109
- logger.error(f"Failed to load MV-Adapter models: {str(e)}")
110
  raise
111
 
112
  try:
@@ -115,274 +111,282 @@ try:
115
  if not os.path.exists("checkpoints/big-lama.pt"):
116
  subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
117
  except Exception as e:
118
- logger.error(f"Failed to download checkpoints: {str(e)}")
119
  raise
120
 
 
 
 
 
 
 
121
  def get_random_hex():
122
  random_bytes = os.urandom(8)
123
  random_hex = random_bytes.hex()
124
  return random_hex
125
 
126
- # Retry decorator for GPU tasks
127
- def retry_on_gpu_abort(max_attempts=3, delay=5):
128
- def decorator(func):
129
- @wraps(func)
130
- def wrapper(*args, **kwargs):
131
- attempts = 0
132
- while attempts < max_attempts:
133
- try:
134
- return func(*args, **kwargs)
135
- except gr.Error as e:
136
- if "GPU task aborted" in str(e):
137
- attempts += 1
138
- logger.warning(f"GPU task aborted, retrying {attempts}/{max_attempts}")
139
- time.sleep(delay)
140
- else:
141
- raise
142
- raise gr.Error("Max retries reached for GPU task")
143
- return wrapper
144
- return decorator
145
-
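(Illustrative usage of the decorator above, not part of the diff: it would wrap any GPU-bound handler whose ZeroGPU pre-emptions surface as gr.Error("GPU task aborted").)

    @retry_on_gpu_abort(max_attempts=3, delay=5)
    def generate_mesh(image_path):
        ...  # GPU work that may be aborted and retried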
146
- # Quota check for ZeroGPU
147
- def check_quota():
148
- if not IS_ZEROGPU:
149
- return True
150
- hf_api = HfApi()
151
  try:
152
- quota = hf_api.get_space_runtime(token=os.getenv("HF_TOKEN"))
153
- logger.info(f"Remaining ZeroGPU quota: {quota}")
154
- return quota.get("gpu_quota_remaining", 0) > 60
155
  except Exception as e:
156
- logger.error(f"Failed to check quota: {str(e)}")
157
- return False
158
-
159
- # Conditional GPU decorator
160
- def conditional_gpu_decorator(duration=None):
161
- def decorator(func):
162
- if IS_ZEROGPU:
163
- return spaces.GPU(duration=duration)(func) if duration else spaces.GPU()(func)
164
- return func
165
- return decorator
166
-
167
- @conditional_gpu_decorator(duration=10)
168
- @retry_on_gpu_abort(max_attempts=3, delay=5)
169
- def run_full(image: str, seed: int = 0, num_inference_steps: int = 30, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, req=None, style_filter: str = "None"):
170
  try:
171
- logger.info(f"Starting run_full with image: {image}, seed: {seed}, style: {style_filter}")
172
- if not check_quota():
173
- raise gr.Error("Insufficient GPU quota remaining")
174
- image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
175
- logger.info("Image segmentation completed")
176
- logger.info(f"VRAM usage after segmentation: {torch.cuda.memory_allocated(DEVICE)/1e9:.2f} GB")
177
-
178
- outputs = triposg_pipe(
179
- image=image_seg,
180
- generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
181
- num_inference_steps=num_inference_steps,
182
- guidance_scale=guidance_scale
183
- ).samples[0]
184
- logger.info("Mesh extraction done")
185
  mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
186
-
187
  if simplify:
188
- logger.info("Starting mesh simplification")
189
  from utils import simplify_mesh
190
  mesh = simplify_mesh(mesh, target_face_num)
191
-
192
- save_dir = os.path.join(TMP_DIR, "examples")
193
  os.makedirs(save_dir, exist_ok=True)
194
  mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
195
  mesh.export(mesh_path)
196
- logger.info(f"Saved mesh to {mesh_path}")
197
- logger.info(f"VRAM usage after mesh generation: {torch.cuda.memory_allocated(DEVICE)/1e9:.2f} GB")
198
-
199
  torch.cuda.empty_cache()
 
 
 
 
 
200
 
201
- height, width = 512, 512 # Reduced for L4 and ZeroGPU
 
 
 
 
 
202
  cameras = get_orthogonal_camera(
203
- elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
204
  distance=[1.8] * NUM_VIEWS,
205
  left=-0.55,
206
  right=0.55,
207
  bottom=-0.55,
208
  top=0.55,
209
- azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
210
  device=DEVICE,
211
  )
212
  ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
213
-
214
  mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
215
- render_out = render(
216
- ctx,
217
- mesh,
218
- cameras,
219
- height=height,
220
- width=width,
221
- render_attr=False,
222
- normal_background=0.0,
223
- )
 
224
  control_images = (
225
  torch.cat(
226
- [
227
- (render_out.pos + 0.5).clamp(0, 1),
228
- (render_out.normal / 2 + 0.5).clamp(0, 1),
229
- ],
230
  dim=-1,
231
  )
232
  .permute(0, 3, 1, 2)
233
  .to(DEVICE)
234
  )
235
-
236
  image = Image.open(image)
237
- image = remove_bg_fn(image)
 
 
 
238
  image = preprocess_image(image, height, width)
 
 
 
239
 
240
- pipe_kwargs = {}
241
- if seed != -1 and isinstance(seed, int):
242
- pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
243
-
244
- prompt = f"high quality, {style_filter.lower()}" if style_filter != "None" else "high quality"
245
- images = mv_adapter_pipe(
246
- prompt,
247
- height=height,
248
- width=width,
249
- num_inference_steps=10, # Reduced for L4 and ZeroGPU
250
- guidance_scale=3.0,
251
- num_images_per_prompt=NUM_VIEWS,
252
- control_image=control_images,
253
- control_conditioning_scale=1.0,
254
- reference_image=image,
255
- reference_conditioning_scale=1.0,
256
- negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
257
- cross_attention_kwargs={"scale": 1.0},
258
- **pipe_kwargs,
259
- ).images
260
 
261
- torch.cuda.empty_cache()
262
- logger.info(f"VRAM usage after texture generation: {torch.cuda.memory_allocated(DEVICE)/1e9:.2f} GB")
263
- os.makedirs(save_dir, exist_ok=True)
264
- mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
265
- make_image_grid(images, rows=1).save(mv_image_path)
266
 
267
- from texture import TexturePipeline, ModProcessConfig
268
- texture_pipe = TexturePipeline(
269
- upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
270
- inpaint_ckpt_path="checkpoints/big-lama.pt",
271
- device=DEVICE,
272
- )
273
 
274
- textured_glb_path = texture_pipe(
275
- mesh_path=mesh_path,
276
- save_dir=save_dir,
277
- save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
278
- uv_unwarp=True,
279
- uv_size=2048, # Reduced for L4 and ZeroGPU
280
- rgb_path=mv_image_path,
281
- rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
282
- camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
283
- )
284
 
285
- logger.info(f"run_full completed successfully, textured model saved to {textured_glb_path}")
286
- return image_seg, mesh_path, textured_glb_path
287
- except Exception as e:
288
- logger.error(f"Error in run_full: {str(e)}")
289
- raise
290
 
291
- def gradio_generate(image: str, seed: int = 0, num_inference_steps: int = 30, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, style_filter: str = "None"):
292
- try:
293
- logger.info("Starting gradio_generate")
294
- # Verify API key
295
- api_key = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
296
- request = gr.Request()
297
- if not request.headers.get("x-api-key") == api_key:
298
- logger.error("Invalid API key")
299
- raise ValueError("Invalid API key")
300
 
301
- # Handle base64 image or file path
302
- if image.startswith("data:image"):
303
- logger.info("Processing base64 image")
304
- base64_string = image.split(",")[1]
305
- image_data = base64.b64decode(base64_string)
306
- temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
307
- with open(temp_image_path, "wb") as f:
308
- f.write(image_data)
309
- else:
310
- temp_image_path = image
311
- if not os.path.exists(temp_image_path):
312
- logger.error(f"Image file not found: {temp_image_path}")
313
- raise ValueError("Invalid or missing image file")
314
 
315
- image_seg, mesh_path, textured_glb_path = run_full(temp_image_path, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req=None, style_filter=style_filter)
316
- session_hash = os.path.basename(os.path.dirname(textured_glb_path))
317
- logger.info(f"Generated model at /files/{session_hash}/{os.path.basename(textured_glb_path)}")
318
- return {"file_url": f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"}
319
- except Exception as e:
320
- logger.error(f"Error in gradio_generate: {str(e)}")
321
- raise
322
 
323
- # Conditional GPU decorator
324
- def conditional_gpu_decorator(duration=None):
325
- def decorator(func):
326
- if IS_ZEROGPU:
327
- return spaces.GPU(duration=duration)(func) if duration else spaces.GPU()(func)
328
- return func
329
- return decorator
330
 
331
- # Always apply @spaces.GPU for ZeroGPU in start_session
332
- @spaces.GPU() if IS_ZEROGPU else lambda x: x
333
- def start_session(req: gr.Request):
334
- try:
335
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
336
- os.makedirs(save_dir, exist_ok=True)
337
- logger.info(f"Started session, created directory: {save_dir}")
338
- except Exception as e:
339
- logger.error(f"Error in start_session: {str(e)}")
340
- raise
341
 
342
- def end_session(req: gr.Request):
343
- try:
344
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
345
- shutil.rmtree(save_dir)
346
- logger.info(f"Ended session, removed directory: {save_dir}")
347
- except Exception as e:
348
- logger.error(f"Error in end_session: {str(e)}")
349
- raise
350
 
351
- def get_random_seed(randomize_seed, seed):
352
- try:
353
- if randomize_seed:
354
- seed = random.randint(0, MAX_SEED)
355
- logger.info(f"Generated seed: {seed}")
356
- return seed
357
- except Exception as e:
358
- logger.error(f"Error in get_random_seed: {str(e)}")
359
- raise
 
 
 
360
 
361
- def download_image(url: str, save_path: str) -> str:
362
- """Download an image from a URL and save it locally."""
363
- try:
364
- logger.info(f"Downloading image from {url}")
365
- response = requests.get(url, stream=True)
366
- response.raise_for_status()
367
- with open(save_path, "wb") as f:
368
- for chunk in response.iter_content(chunk_size=8192):
369
- f.write(chunk)
370
- logger.info(f"Saved image to {save_path}")
371
- return save_path
372
- except Exception as e:
373
- logger.error(f"Failed to download image from {url}: {str(e)}")
374
- raise
375
 
376
- @conditional_gpu_decorator()
377
  @torch.no_grad()
378
  def run_segmentation(image):
379
  try:
380
- logger.info("Running segmentation")
381
- # Handle FileData dict or URL
382
  if isinstance(image, dict):
383
  image_path = image.get("path") or image.get("url")
384
  if not image_path:
385
- logger.error("Invalid image input: no path or URL provided")
386
  raise ValueError("Invalid image input: no path or URL provided")
387
  if image_path.startswith("http"):
388
  temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
@@ -393,215 +397,234 @@ def run_segmentation(image):
393
  else:
394
  image_path = image
395
  if not isinstance(image, (str, bytes)) or (isinstance(image, str) and not os.path.exists(image)):
396
- logger.error(f"Invalid image type or path: {type(image)}")
397
  raise ValueError(f"Expected str (path/URL), bytes, or FileData dict, got {type(image)}")
398
 
399
- image = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
400
- logger.info("Segmentation complete")
 
401
  torch.cuda.empty_cache()
402
- return image
 
403
  except Exception as e:
404
- logger.error(f"Error in run_segmentation: {str(e)}")
405
  raise
406
 
407
- @conditional_gpu_decorator(duration=5)
408
- @retry_on_gpu_abort(max_attempts=3, delay=5)
409
  @torch.no_grad()
410
- def image_to_3d(
411
- image,
412
- seed: int,
413
- num_inference_steps: int,
414
- guidance_scale: float,
415
- simplify: bool,
416
- target_face_num: int,
417
- req: gr.Request
418
- ):
419
  try:
420
- logger.info("Running image_to_3d")
421
- if not check_quota():
422
- raise gr.Error("Insufficient GPU quota remaining")
423
- # Handle FileData dict from gradio_client
424
- if isinstance(image, dict):
425
- image_path = image.get("path") or image.get("url")
426
- if not image_path:
427
- logger.error("Invalid image input: no path or URL provided")
428
- raise ValueError("Invalid image input: no path or URL provided")
429
- image = Image.open(image_path)
430
- elif not isinstance(image, Image.Image):
431
- logger.error(f"Invalid image type: {type(image)}")
432
- raise ValueError(f"Expected PIL Image or FileData dict, got {type(image)}")
433
-
434
- outputs = triposg_pipe(
435
- image=image,
436
- generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
437
- num_inference_steps=num_inference_steps,
438
- guidance_scale=guidance_scale
439
- ).samples[0]
440
- logger.info("Mesh extraction done")
441
  mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
442
-
443
  if simplify:
444
- logger.info("Starting mesh simplification")
445
- try:
446
- from utils import simplify_mesh
447
- mesh = simplify_mesh(mesh, target_face_num)
448
- except ImportError as e:
449
- logger.error(f"Failed to import simplify_mesh: {str(e)}")
450
- raise
451
-
452
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
453
  os.makedirs(save_dir, exist_ok=True)
454
  mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
455
  mesh.export(mesh_path)
456
- logger.info(f"Saved mesh to {mesh_path}")
457
- logger.info(f"VRAM usage after mesh generation: {torch.cuda.memory_allocated(DEVICE)/1e9:.2f} GB")
458
-
459
  torch.cuda.empty_cache()
 
460
  return mesh_path
461
  except Exception as e:
462
- logger.error(f"Error in image_to_3d: {str(e)}")
463
  raise
464
 
465
- @conditional_gpu_decorator(duration=5)
466
- @retry_on_gpu_abort(max_attempts=3, delay=5)
467
  @torch.no_grad()
468
- def run_texture(image, mesh_path: str, seed: int, req: gr.Request, style_filter: str = "None"):
469
  try:
470
- logger.info(f"Running texture generation with style: {style_filter}")
471
- if not check_quota():
472
- raise gr.Error("Insufficient GPU quota remaining")
473
- height, width = 512, 512 # Reduced for L4 and ZeroGPU
474
  cameras = get_orthogonal_camera(
475
- elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
476
  distance=[1.8] * NUM_VIEWS,
477
  left=-0.55,
478
  right=0.55,
479
  bottom=-0.55,
480
  top=0.55,
481
- azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
482
  device=DEVICE,
483
  )
484
  ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
485
-
486
  mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
487
- render_out = render(
488
- ctx,
489
- mesh,
490
- cameras,
491
- height=height,
492
- width=width,
493
- render_attr=False,
494
- normal_background=0.0,
495
- )
 
496
  control_images = (
497
  torch.cat(
498
- [
499
- (render_out.pos + 0.5).clamp(0, 1),
500
- (render_out.normal / 2 + 0.5).clamp(0, 1),
501
- ],
502
  dim=-1,
503
  )
504
  .permute(0, 3, 1, 2)
505
  .to(DEVICE)
506
  )
507
-
508
- # Handle both file path and PIL Image
509
- if isinstance(image, str):
510
- if image.startswith("http"):
511
- temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
512
- image = download_image(image, temp_image_path)
513
- image = Image.open(image)
514
- elif not isinstance(image, Image.Image):
515
- logger.error(f"Invalid image type: {type(image)}")
516
- raise ValueError(f"Expected PIL Image or str (path/URL), got {type(image)}")
517
-
518
- image = remove_bg_fn(image)
519
  image = preprocess_image(image, height, width)
520
-
521
- pipe_kwargs = {}
522
- if seed != -1 and isinstance(seed, int):
523
- pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
524
-
525
- prompt = f"high quality, {style_filter.lower()}" if style_filter != "None" else "high quality"
526
- images = mv_adapter_pipe(
527
- prompt,
528
- height=height,
529
- width=width,
530
- num_inference_steps=10, # Reduced for L4 and ZeroGPU
531
- guidance_scale=3.0,
532
- num_images_per_prompt=NUM_VIEWS,
533
- control_image=control_images,
534
- control_conditioning_scale=1.0,
535
- reference_image=image,
536
- reference_conditioning_scale=1.0,
537
- negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
538
- cross_attention_kwargs={"scale": 1.0},
539
- **pipe_kwargs,
540
- ).images
541
-
542
- torch.cuda.empty_cache()
543
- logger.info(f"VRAM usage after texture generation: {torch.cuda.memory_allocated(DEVICE)/1e9:.2f} GB")
544
- save_dir = os.path.join(TMP_DIR, str(req.session_hash))
545
  os.makedirs(save_dir, exist_ok=True)
546
  mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
547
  make_image_grid(images, rows=1).save(mv_image_path)
548
-
549
  from texture import TexturePipeline, ModProcessConfig
550
  texture_pipe = TexturePipeline(
551
  upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
552
  inpaint_ckpt_path="checkpoints/big-lama.pt",
553
  device=DEVICE,
554
  )
555
-
556
  textured_glb_path = texture_pipe(
557
  mesh_path=mesh_path,
558
  save_dir=save_dir,
559
  save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
560
  uv_unwarp=True,
561
- uv_size=2048, # Reduced for L4 and ZeroGPU
562
  rgb_path=mv_image_path,
563
  rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
564
- camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
565
  )
566
-
567
- logger.info(f"Textured model saved to {textured_glb_path}")
568
  return textured_glb_path
569
  except Exception as e:
570
- logger.error(f"Error in run_texture: {str(e)}")
571
  raise
572
 
573
- @conditional_gpu_decorator(duration=10)
574
- @retry_on_gpu_abort(max_attempts=3, delay=5)
575
  @torch.no_grad()
576
- def run_full_api(image, seed: int = 0, num_inference_steps: int = 30, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, req: gr.Request = None, style_filter: str = "None"):
577
  try:
578
- logger.info("Running run_full_api")
579
- if not check_quota():
580
- raise gr.Error("Insufficient GPU quota remaining")
581
- # Handle FileData dict or URL
582
- if isinstance(image, dict):
583
- image_path = image.get("path") or image.get("url")
584
- if not image_path:
585
- logger.error("Invalid image input: no path or URL provided")
586
- raise ValueError("Invalid image input: no path or URL provided")
587
- if image_path.startswith("http"):
588
- temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
589
- image_path = download_image(image_path, temp_image_path)
590
- elif isinstance(image, str) and image.startswith("http"):
591
  temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
592
- image_path = download_image(image, temp_image_path)
 
593
  else:
594
- image_path = image
595
- if not isinstance(image, str) or not os.path.exists(image_path):
596
- logger.error(f"Invalid image path: {image_path}")
597
- raise ValueError(f"Invalid image path: {image_path}")
598
 
599
- image_seg, mesh_path, textured_glb_path = run_full(image_path, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req, style_filter)
600
  session_hash = os.path.basename(os.path.dirname(textured_glb_path))
601
- logger.info(f"Generated textured model at /files/{session_hash}/{os.path.basename(textured_glb_path)}")
602
- return image_seg, mesh_path, textured_glb_path
603
  except Exception as e:
604
- logger.error(f"Error in run_full_api: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  raise
606
 
607
  # Define Gradio API endpoint
@@ -613,236 +636,14 @@ try:
613
  gr.Image(type="filepath", label="Image"),
614
  gr.Number(label="Seed", value=0, precision=0),
615
  gr.Number(label="Inference Steps", value=30, precision=0),
616
- gr.Number(label="Guidance Scale", value=7.5),
617
  gr.Checkbox(label="Simplify Mesh", value=True),
618
- gr.Number(label="Target Face Number", value=DEFAULT_FACE_NUMBER, precision=0),
619
- gr.Dropdown(
620
- choices=["None", "Realistic", "Fantasy", "Cartoon", "Sci-Fi", "Vintage", "Cosmic", "Neon"],
621
- label="Style Filter",
622
- value="None",
623
- ),
624
  ],
625
  outputs="json",
626
  api_name="/api/generate"
627
  )
628
  logger.info("Gradio API interface initialized successfully")
629
  except Exception as e:
630
- logger.error(f"Failed to initialize Gradio API interface: {str(e)}")
631
- raise
632
-
633
- HEADER = """
634
- # 🌌 PolyGenixAI: Craft 3D Worlds with Cosmic Precision
635
- ## Unleash Infinite Creativity with AI-Powered 3D Generation by AnvilInteractive Solutions
636
- <p style="font-size: 1.1em; color: #A78BFA;">By <a href="https://www.anvilinteractive.com/" style="color: #A78BFA; text-decoration: none; font-weight: bold;">AnvilInteractive Solutions</a></p>
637
- ## 🚀 Launch Your Creation:
638
- 1. **Upload an Image** (clear, single-object images shine brightest)
639
- 2. **Choose a Style Filter** to infuse your unique vision
640
- 3. Click **Generate 3D Model** to sculpt your mesh
641
- 4. Click **Apply Texture** to bring your model to life
642
- 5. **Download GLB** to share your masterpiece
643
- <p style="font-size: 0.9em; margin-top: 10px; color: #D1D5DB;">Powered by cutting-edge AI and multi-view technology from AnvilInteractive Solutions. Join our <a href="https://www.anvilinteractive.com/community" style="color: #A78BFA; text-decoration: none;">PolyGenixAI Community</a> to connect with creators and spark inspiration.</p>
644
- <style>
645
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
646
- body {
647
- background-color: #1A1A1A !important;
648
- font-family: 'Inter', sans-serif !important;
649
- color: #D1D5DB !important;
650
- }
651
- .gr-panel {
652
- background-color: #2D2D2D !important;
653
- border: 1px solid #7C3AED !important;
654
- border-radius: 12px !important;
655
- padding: 20px !important;
656
- box-shadow: 0 4px 10px rgba(124, 58, 237, 0.2) !important;
657
- }
658
- .gr-button-primary {
659
- background: linear-gradient(45deg, #7C3AED, #A78BFA) !important;
660
- color: white !important;
661
- border: none !important;
662
- border-radius: 8px !important;
663
- padding: 12px 24px !important;
664
- font-weight: 600 !important;
665
- transition: transform 0.2s, box-shadow 0.2s !important;
666
- }
667
- .gr-button-primary:hover {
668
- transform: translateY(-2px) !important;
669
- box-shadow: 0 4px 12px rgba(124, 58, 237, 0.5) !important;
670
- }
671
- .gr-button-secondary {
672
- background-color: #4B4B4B !important;
673
- color: #D1D5DB !important;
674
- border: 1px solid #A78BFA !important;
675
- border-radius: 8px !important;
676
- padding: 10px 20px !important;
677
- transition: transform 0.2s !important;
678
- }
679
- .gr-button-secondary:hover {
680
- transform: translateY(-1px) !important;
681
- background-color: #6B6B6B !important;
682
- }
683
- .gr-accordion {
684
- background-color: #2D2D2D !important;
685
- border-radius: 8px !important;
686
- border: 1px solid #7C3AED !important;
687
- }
688
- .gr-tab {
689
- background-color: #2D2D2D !important;
690
- color: #A78BFA !important;
691
- border: 1px solid #7C3AED !important;
692
- border-radius: 8px !important;
693
- margin: 5px !important;
694
- }
695
- .gr-tab:hover, .gr-tab-selected {
696
- background: linear-gradient(45deg, #7C3AED, #A78BFA) !important;
697
- color: white !important;
698
- }
699
- .gr-slider input[type=range]::-webkit-slider-thumb {
700
- background-color: #7C3AED !important;
701
- border: 2px solid #A78BFA !important;
702
- }
703
- .gr-dropdown {
704
- background-color: #2D2D2D !important;
705
- color: #D1D5DB !important;
706
- border: 1px solid #A78BFA !important;
707
- border-radius: 8px !important;
708
- }
709
- h1, h3 {
710
- color: #A78BFA !important;
711
- text-shadow: 0 0 10px rgba(124, 58, 237, 0.5) !important;
712
- }
713
- </style>
714
- """
715
-
716
- # ... [Previous imports and code unchanged until the Gradio Blocks interface] ...
717
-
718
- # Gradio web interface
719
- try:
720
- logger.info("Initializing Gradio Blocks interface")
721
- with gr.Blocks(title="PolyGenixAI", css="body { background-color: #1A1A1A; } .gr-panel { background-color: #2D2D2D; }") as demo:
722
- gr.Markdown(HEADER)
723
- with gr.Tabs(elem_classes="gr-tab"):
724
- with gr.Tab("Create 3D Model"):
725
- with gr.Row():
726
- with gr.Column(scale=1):
727
- image_prompts = gr.Image(label="Upload Image", type="filepath", height=300, elem_classes="gr-panel")
728
- seg_image = gr.Image(label="Preview Segmentation", type="pil", format="png", interactive=False, height=300, elem_classes="gr-panel")
729
- with gr.Accordion("Style & Settings", open=True, elem_classes="gr-accordion"):
730
- style_filter = gr.Dropdown(
731
- choices=["None", "Realistic", "Fantasy", "Cartoon", "Sci-Fi", "Vintage", "Cosmic", "Neon"],
732
- label="Style Filter",
733
- value="None",
734
- info="Select a style to inspire your 3D model (optional)",
735
- elem_classes="gr-dropdown"
736
- )
737
- seed = gr.Slider(
738
- label="Seed",
739
- minimum=0,
740
- maximum=MAX_SEED,
741
- step=1,
742
- value=0,
743
- elem_classes="gr-slider"
744
- )
745
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
746
- num_inference_steps = gr.Slider(
747
- label="Inference Steps",
748
- minimum=8,
749
- maximum=50,
750
- step=1,
751
- value=30,
752
- info="Higher steps enhance detail but increase processing time",
753
- elem_classes="gr-slider"
754
- )
755
- guidance_scale = gr.Slider(
756
- label="Guidance Scale",
757
- minimum=0.0,
758
- maximum=20.0,
759
- step=0.1,
760
- value=7.0,
761
- info="Controls adherence to input image",
762
- elem_classes="gr-slider"
763
- )
764
- reduce_face = gr.Checkbox(label="Simplify Mesh", value=True)
765
- target_face_num = gr.Slider(
766
- maximum=1000000,
767
- minimum=10000,
768
- value=DEFAULT_FACE_NUMBER,
769
- label="Target Face Number",
770
- info="Adjust mesh complexity for performance",
771
- elem_classes="gr-slider"
772
- )
773
- gen_button = gr.Button("Generate 3D Model", variant="primary", elem_classes="gr-button-primary")
774
- gen_texture_button = gr.Button("Apply Texture", variant="secondary", interactive=False, elem_classes="gr-button-secondary")
775
- with gr.Column(scale=1):
776
- model_output = gr.Model3D(label="3D Model Preview", interactive=False, height=400, elem_classes="gr-panel")
777
- textured_model_output = gr.Model3D(label="Textured 3D Model", interactive=False, height=400, elem_classes="gr-panel")
778
- download_button = gr.Button("Download GLB", variant="secondary", elem_classes="gr-button-secondary")
779
- with gr.Tab("Cosmic Gallery"):
780
- gr.Markdown("### Discover Stellar Creations")
781
- # Ensure example directory exists and contains valid images
782
- example_dir = f"{TRIPOSG_CODE_DIR}/assets/example_data"
783
- examples = []
784
- if os.path.exists(example_dir):
785
- valid_extensions = (".png", ".jpg", ".jpeg")
786
- examples = [
787
- [
788
- os.path.join(example_dir, image), # image
789
- 0, # seed
790
- 30, # num_inference_steps
791
- 7.5, # guidance_scale
792
- True, # reduce_face
793
- DEFAULT_FACE_NUMBER, # target_face_num
794
- "None" # style_filter
795
- ]
796
- for image in os.listdir(example_dir)
797
- if image.lower().endswith(valid_extensions)
798
- ]
799
- if not examples:
800
- logger.warning(f"No valid images found in {example_dir}, skipping examples")
801
- gr.Examples(
802
- examples=examples,
803
- fn=run_full,
804
- inputs=[image_prompts, seed, num_inference_steps, guidance_scale, reduce_face, target_face_num, style_filter],
805
- outputs=[seg_image, model_output, textured_model_output],
806
- cache_examples=True,
807
- )
808
- gr.Markdown("Connect with creators in our <a href='https://www.anvilinteractive.com/community' style='color: #A78BFA; text-decoration: none;'>PolyGenixAI Cosmic Community</a>!")
809
- gen_button.click(
810
- run_segmentation,
811
- inputs=[image_prompts],
812
- outputs=[seg_image]
813
- ).then(
814
- get_random_seed,
815
- inputs=[randomize_seed, seed],
816
- outputs=[seed],
817
- ).then(
818
- image_to_3d,
819
- inputs=[
820
- seg_image,
821
- seed,
822
- num_inference_steps,
823
- guidance_scale,
824
- reduce_face,
825
- target_face_num
826
- ],
827
- outputs=[model_output]
828
- ).then(lambda: gr.Button(interactive=True), outputs=[gen_texture_button])
829
- gen_texture_button.click(
830
- run_texture,
831
- inputs=[image_prompts, model_output, seed, style_filter],
832
- outputs=[textured_model_output]
833
- )
834
- demo.load(start_session)
835
- demo.unload(end_session)
836
- logger.info("Gradio Blocks interface initialized successfully")
837
- except Exception as e:
838
- logger.error(f"Failed to initialize Gradio Blocks interface: {str(e)}")
839
- raise
840
-
841
- if __name__ == "__main__":
842
- try:
843
- logger.info("Launching Gradio application")
844
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
845
- logger.info("Gradio application launched successfully")
846
- except Exception as e:
847
- logger.error(f"Failed to launch Gradio application: {str(e)}")
848
- raise
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
+ from torch.cuda.amp import autocast
7
  import trimesh
8
  import random
9
+ from PIL import Image
10
  from transformers import AutoModelForImageSegmentation
11
  from torchvision import transforms
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
  import subprocess
14
  import shutil
15
  import base64
16
  import logging
 
 
17
  import time
18
+ import traceback
19
+ import requests
20
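Note on the `from torch.cuda.amp import autocast` import added above: newer PyTorch releases emit a deprecation warning for that entry point. Assuming PyTorch >= 2.0, the device-agnostic form below is a drop-in replacement (shown here as a sketch, not part of the commit):

    import torch
    # Equivalent context manager; dtype mirrors the module-level DTYPE (torch.float16).
    with torch.amp.autocast("cuda", dtype=torch.float16):
        pass  # fp16 forward passes go here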
 
21
  # Set up logging
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
26
  try:
27
  subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
28
  except Exception as e:
29
+ logger.error(f"Failed to install spandrel: {str(e)}\n{traceback.format_exc()}")
30
  raise
31
 
 
 
 
32
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
  DTYPE = torch.float16
34
 
35
+ logger.info(f"Using device: {DEVICE}")
36
 
37
+ DEFAULT_FACE_NUMBER = 20000 # Reduced for memory efficiency
38
  MAX_SEED = np.iinfo(np.int32).max
39
  TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
40
  MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"
 
62
  sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
63
 
64
  try:
 
65
  from image_process import prepare_image
66
  from briarmbg import BriaRMBG
67
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
68
+ rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE, dtype=DTYPE)
69
  rmbg_net.eval()
70
  from triposg.pipelines.pipeline_triposg import TripoSGPipeline
71
  snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
72
+ triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, dtype=DTYPE)
73
  except Exception as e:
74
+ logger.error(f"Failed to load TripoSG models: {str(e)}\n{traceback.format_exc()}")
75
  raise
76
 
77
  try:
78
+ NUM_VIEWS = 4 # Reduced for memory efficiency
 
79
  from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
80
  from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
81
  from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
 
92
  )
93
  birefnet = AutoModelForImageSegmentation.from_pretrained(
94
  "ZhengPeng7/BiRefNet", trust_remote_code=True
95
+ ).to(DEVICE, dtype=DTYPE)
96
  transform_image = transforms.Compose(
97
  [
98
+ transforms.Resize((512, 512)), # Reduced resolution
99
  transforms.ToTensor(),
100
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
101
  ]
102
  )
103
  remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
104
  except Exception as e:
105
+ logger.error(f"Failed to load MV-Adapter models: {str(e)}\n{traceback.format_exc()}")
106
  raise
107
 
108
  try:
 
111
  if not os.path.exists("checkpoints/big-lama.pt"):
112
  subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
113
  except Exception as e:
114
+ logger.error(f"Failed to download checkpoints: {str(e)}\n{traceback.format_exc()}")
115
  raise
116
 
117
+ def log_gpu_memory():
118
+ if torch.cuda.is_available():
119
+ allocated = torch.cuda.memory_allocated() / 1024**3
120
+ reserved = torch.cuda.memory_reserved() / 1024**3
121
+ logger.info(f"GPU Memory: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
122
+
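A companion one might place next to this helper for per-stage peaks, using standard PyTorch CUDA counters (this function is an illustration, not part of the commit):

    def log_peak_gpu_memory(stage: str):
        # High-water mark since the last reset, then clear the counter.
        if torch.cuda.is_available():
            peak = torch.cuda.max_memory_allocated() / 1024**3
            logger.info(f"{stage}: peak GPU memory {peak:.2f} GB")
            torch.cuda.reset_peak_memory_stats()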
123
  def get_random_hex():
124
  random_bytes = os.urandom(8)
125
  random_hex = random_bytes.hex()
126
  return random_hex
127
 
128
+ def retry_on_failure(func, max_attempts=3, delay=1):
129
+ for attempt in range(max_attempts):
130
+ try:
131
+ return func()
132
+ except RuntimeError as e:
133
+ logger.warning(f"Attempt {attempt + 1} failed: {str(e)}\n{traceback.format_exc()}")
134
+ if attempt == max_attempts - 1:
135
+ raise
136
+ time.sleep(delay)
137
+
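A minimal usage sketch for the retry helper above (the wrapped call and its arguments are illustrative assumptions, not from this commit):

    # Retry a transient CUDA RuntimeError up to 3 times, sleeping 1 s between attempts.
    image_seg = retry_on_failure(lambda: run_segmentation("input.png"), max_attempts=3, delay=1)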
138
+ @spaces.GPU(duration=2)
139
+ @torch.no_grad()
140
+ def run_segmentation(image):
141
  try:
142
+ log_gpu_memory()
143
+ if isinstance(image, dict):
144
+ image_path = image.get("path") or image.get("url")
145
+ if not image_path:
146
+ raise ValueError("Invalid image input: no path or URL provided")
147
+ if image_path.startswith("http"):
148
+ temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
149
+ image_path = download_image(image_path, temp_image_path)
150
+ elif isinstance(image, str) and image.startswith("http"):
151
+ temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
152
+ image_path = download_image(image, temp_image_path)
153
+ else:
154
+ image_path = image
155
+ if not isinstance(image, (str, bytes)) or (isinstance(image, str) and not os.path.exists(image)):
156
+ raise ValueError(f"Expected str (path/URL), bytes, or FileData dict, got {type(image)}")
157
+
158
+ with autocast():
159
+ image_seg = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
160
+ rmbg_net.to("cpu")
161
+ torch.cuda.empty_cache()
162
+ log_gpu_memory()
163
+ return image_seg
164
  except Exception as e:
165
+ logger.error(f"Error in run_segmentation: {str(e)}\n{traceback.format_exc()}")
166
+ raise
167
+
168
+ @spaces.GPU(duration=3)
169
+ @torch.no_grad()
170
+ def image_to_3d(image, seed, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
171
  try:
172
+ log_gpu_memory()
173
+ triposg_pipe.to(DEVICE, dtype=DTYPE)
174
+ with autocast():
175
+ outputs = triposg_pipe(
176
+ image=image,
177
+ generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
178
+ num_inference_steps=num_inference_steps,
179
+ guidance_scale=guidance_scale
180
+ ).samples[0]
 
 
 
 
 
181
  mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
 
182
  if simplify:
 
183
  from utils import simplify_mesh
184
  mesh = simplify_mesh(mesh, target_face_num)
185
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash) if req else "examples")
 
186
  os.makedirs(save_dir, exist_ok=True)
187
  mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
188
  mesh.export(mesh_path)
189
+ triposg_pipe.to("cpu")
 
 
190
  torch.cuda.empty_cache()
191
+ log_gpu_memory()
192
+ return mesh_path
193
+ except Exception as e:
194
+ logger.error(f"Error in image_to_3d: {str(e)}\n{traceback.format_exc()}")
195
+ raise
196
 
197
+ @spaces.GPU(duration=3)
198
+ @torch.no_grad()
199
+ def run_texture(image, mesh_path, seed, req=None):
200
+ try:
201
+ log_gpu_memory()
202
+ height, width = 512, 512
203
  cameras = get_orthogonal_camera(
204
+ elevation_deg=[0, 0, 0, 89.99],
205
  distance=[1.8] * NUM_VIEWS,
206
  left=-0.55,
207
  right=0.55,
208
  bottom=-0.55,
209
  top=0.55,
210
+ azimuth_deg=[x - 90 for x in [0, 90, 180, 180]],
211
  device=DEVICE,
212
  )
213
  ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
 
214
  mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
215
+ with autocast():
216
+ render_out = render(
217
+ ctx,
218
+ mesh,
219
+ cameras,
220
+ height=height,
221
+ width=width,
222
+ render_attr=False,
223
+ normal_background=0.0,
224
+ )
225
  control_images = (
226
  torch.cat(
227
+ [(render_out.pos + 0.5).clamp(0, 1), (render_out.normal / 2 + 0.5).clamp(0, 1)],
 
 
 
228
  dim=-1,
229
  )
230
  .permute(0, 3, 1, 2)
231
  .to(DEVICE)
232
  )
233
+ del render_out
234
  image = Image.open(image)
235
+ birefnet.to(DEVICE, dtype=DTYPE)
236
+ with autocast():
237
+ image = remove_bg_fn(image)
238
+ birefnet.to("cpu")
239
  image = preprocess_image(image, height, width)
240
+ pipe_kwargs = {"generator": torch.Generator(device=DEVICE).manual_seed(seed)} if seed != -1 else {}
241
+ mv_adapter_pipe.to(DEVICE, dtype=DTYPE)
242
+ While the code was being generated, a problem occurred that interrupted completion. The code is incomplete and ends abruptly. Here is the partially generated code up to the point of interruption:
243
 
244
+ <xaiArtifact artifact_id="639c400c-2c7c-4b65-a385-eeaa3fdd5602" artifact_version_id="167946b5-d0b3-4e41-92c2-87163e0ff287" title="app.py" contentType="text/python">
245
+ import spaces
246
+ import os
247
+ import gradio as gr
248
+ import numpy as np
249
+ import torch
250
+ from torch.cuda.amp import autocast
251
+ import trimesh
252
+ import random
253
+ from PIL import Image
254
+ from transformers import AutoModelForImageSegmentation
255
+ from torchvision import transforms
256
+ from huggingface_hub import hf_hub_download, snapshot_download
257
+ import subprocess
258
+ import shutil
259
+ import base64
260
+ import logging
261
+ import time
262
+ import traceback
263
+ import requests
264
 
265
+ # Set up logging
266
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
267
+ logger = logging.getLogger(__name__)
 
 
268
 
269
+ # Install additional dependencies
270
+ try:
271
+ subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
272
+ except Exception as e:
273
+ logger.error(f"Failed to install spandrel: {str(e)}\n{traceback.format_exc()}")
274
+ raise
275
 
276
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
277
+ DTYPE = torch.float16
278
 
279
+ logger.info(f"Using device: {DEVICE}")
 
 
 
 
280
 
281
+ DEFAULT_FACE_NUMBER = 20000 # Reduced for memory efficiency
282
+ MAX_SEED = np.iinfo(np.int32).max
283
+ TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
284
+ MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"
 
 
 
 
 
285
 
286
+ RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
287
+ TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
288
 
289
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
290
+ os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
291
 
292
+ TRIPOSG_CODE_DIR = "./triposg"
293
+ if not os.path.exists(TRIPOSG_CODE_DIR):
294
+ logger.info(f"Cloning TripoSG repository to {TRIPOSG_CODE_DIR}")
295
+ os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
 
 
 
296
 
297
+ MV_ADAPTER_CODE_DIR = "./mv_adapter"
298
+ if not os.path.exists(MV_ADAPTER_CODE_DIR):
299
+ logger.info(f"Cloning MV-Adapter repository to {MV_ADAPTER_CODE_DIR}")
300
+ os.system(f"git clone {MV_ADAPTER_REPO_URL} {MV_ADAPTER_CODE_DIR} && cd {MV_ADAPTER_CODE_DIR} && git checkout 7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d")
 
 
 
 
 
 
301
 
302
+ import sys
303
+ sys.path.append(TRIPOSG_CODE_DIR)
304
+ sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
305
+ sys.path.append(MV_ADAPTER_CODE_DIR)
306
+ sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
 
 
 
307
 
308
+ try:
309
+ from image_process import prepare_image
310
+ from briarmbg import BriaRMBG
311
+ snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
312
+ rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE, dtype=DTYPE)
313
+ rmbg_net.eval()
314
+ from triposg.pipelines.pipeline_triposg import TripoSGPipeline
315
+ snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
316
+ triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, dtype=DTYPE)
317
+ except Exception as e:
318
+ logger.error(f"Failed to load TripoSG models: {str(e)}\n{traceback.format_exc()}")
319
+ raise
320
 
321
+ try:
322
+ NUM_VIEWS = 4 # Reduced for memory efficiency
323
+ from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
324
+ from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
325
+ from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
326
+ mv_adapter_pipe = prepare_pipeline(
327
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
328
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
329
+ unet_model=None,
330
+ lora_model=None,
331
+ adapter_path="huanngzh/mv-adapter",
332
+ scheduler=None,
333
+ num_views=NUM_VIEWS,
334
+ device=DEVICE,
335
+ dtype=torch.float16,
336
+ )
337
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
338
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
339
+ ).to(DEVICE, dtype=DTYPE)
340
+ transform_image = transforms.Compose(
341
+ [
342
+ transforms.Resize((512, 512)), # Reduced resolution
343
+ transforms.ToTensor(),
344
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
345
+ ]
346
+ )
347
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
348
+ except Exception as e:
349
+ logger.error(f"Failed to load MV-Adapter models: {str(e)}\n{traceback.format_exc()}")
350
+ raise
351
+
352
+ try:
353
+ if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
354
+ hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
355
+ if not os.path.exists("checkpoints/big-lama.pt"):
356
+ subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
357
+ except Exception as e:
358
+ logger.error(f"Failed to download checkpoints: {str(e)}\n{traceback.format_exc()}")
359
+ raise
360
 
361
+ def log_gpu_memory():
362
+ if torch.cuda.is_available():
363
+ allocated = torch.cuda.memory_allocated() / 1024**3
364
+ reserved = torch.cuda.memory_reserved() / 1024**3
365
+ logger.info(f"GPU Memory: Allocated {allocated:.2f} GB, Reserved {reserved:.2f} GB")
366
+
367
+ def get_random_hex():
368
+ random_bytes = os.urandom(8)
369
+ random_hex = random_bytes.hex()
370
+ return random_hex
371
+
372
+ def retry_on_failure(func, max_attempts=3, delay=1):
373
+ for attempt in range(max_attempts):
374
+ try:
375
+ return func()
376
+ except RuntimeError as e:
377
+ logger.warning(f"Attempt {attempt + 1} failed: {str(e)}\n{traceback.format_exc()}")
378
+ if attempt == max_attempts - 1:
379
+ raise
380
+ time.sleep(delay)
381
+
382
+ @spaces.GPU(duration=2)
383
  @torch.no_grad()
384
  def run_segmentation(image):
385
  try:
386
+ log_gpu_memory()
 
387
  if isinstance(image, dict):
388
  image_path = image.get("path") or image.get("url")
389
  if not image_path:
 
390
  raise ValueError("Invalid image input: no path or URL provided")
391
  if image_path.startswith("http"):
392
  temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
 
397
  else:
398
  image_path = image
399
  if not isinstance(image, (str, bytes)) or (isinstance(image, str) and not os.path.exists(image)):
 
400
  raise ValueError(f"Expected str (path/URL), bytes, or FileData dict, got {type(image)}")
401
 
402
+ with autocast():
403
+ image_seg = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
404
+ rmbg_net.to("cpu")
405
  torch.cuda.empty_cache()
406
+ log_gpu_memory()
407
+ return image_seg
408
  except Exception as e:
409
+ logger.error(f"Error in run_segmentation: {str(e)}\n{traceback.format_exc()}")
410
  raise
411
 
412
+ @spaces.GPU(duration=3)
 
413
  @torch.no_grad()
414
+ def image_to_3d(image, seed, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
415
  try:
416
+ log_gpu_memory()
417
+ triposg_pipe.to(DEVICE, dtype=DTYPE)
418
+ with autocast():
419
+ outputs = triposg_pipe(
420
+ image=image,
421
+ generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
422
+ num_inference_steps=num_inference_steps,
423
+ guidance_scale=guidance_scale
424
+ ).samples[0]
425
  mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
 
426
  if simplify:
427
+ from utils import simplify_mesh
428
+ mesh = simplify_mesh(mesh, target_face_num)
429
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash) if req else "examples")
 
 
 
 
 
 
430
  os.makedirs(save_dir, exist_ok=True)
431
  mesh_path = os.path.join(save_dir, f"polygenixai_{get_random_hex()}.glb")
432
  mesh.export(mesh_path)
433
+ triposg_pipe.to("cpu")
 
 
434
  torch.cuda.empty_cache()
435
+ log_gpu_memory()
436
  return mesh_path
437
  except Exception as e:
438
+ logger.error(f"Error in image_to_3d: {str(e)}\n{traceback.format_exc()}")
439
  raise
440
 
441
+ @spaces.GPU(duration=3)
 
442
  @torch.no_grad()
443
+ def run_texture(image, mesh_path, seed, req=None):
444
  try:
445
+ log_gpu_memory()
446
+ height, width = 512, 512
 
 
447
  cameras = get_orthogonal_camera(
448
+ elevation_deg=[0, 0, 0, 89.99],
449
  distance=[1.8] * NUM_VIEWS,
450
  left=-0.55,
451
  right=0.55,
452
  bottom=-0.55,
453
  top=0.55,
454
+ azimuth_deg=[x - 90 for x in [0, 90, 180, 180]],
455
  device=DEVICE,
456
  )
457
  ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
 
458
  mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
459
+ with autocast():
460
+ render_out = render(
461
+ ctx,
462
+ mesh,
463
+ cameras,
464
+ height=height,
465
+ width=width,
466
+ render_attr=False,
467
+ normal_background=0.0,
468
+ )
469
  control_images = (
470
  torch.cat(
471
+ [(render_out.pos + 0.5).clamp(0, 1), (render_out.normal / 2 + 0.5).clamp(0, 1)],
 
 
 
472
  dim=-1,
473
  )
474
  .permute(0, 3, 1, 2)
475
  .to(DEVICE)
476
  )
477
+ del render_out
478
+ image = Image.open(image)
479
+ birefnet.to(DEVICE, dtype=DTYPE)
480
+ with autocast():
481
+ image = remove_bg_fn(image)
482
+ birefnet.to("cpu")
 
 
 
 
 
 
483
  image = preprocess_image(image, height, width)
484
+ pipe_kwargs = {"generator": torch.Generator(device=DEVICE).manual_seed(seed)} if seed != -1 else {}
485
+ mv_adapter_pipe.to(DEVICE, dtype=DTYPE)
486
+ with autocast():
487
+ images = mv_adapter_pipe(
488
+ "high quality",
489
+ height=height,
490
+ width=width,
491
+ num_inference_steps=10,
492
+ guidance_scale=3.0,
493
+ num_images_per_prompt=NUM_VIEWS,
494
+ control_image=control_images,
495
+ control_conditioning_scale=1.0,
496
+ reference_image=image,
497
+ reference_conditioning_scale=1.0,
498
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
499
+ cross_attention_kwargs={"scale": 1.0},
500
+ **pipe_kwargs,
501
+ ).images
502
+ mv_adapter_pipe.to("cpu")
503
+ del control_images
504
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash) if req else "examples")
 
 
 
 
505
  os.makedirs(save_dir, exist_ok=True)
506
  mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
507
  make_image_grid(images, rows=1).save(mv_image_path)
 
508
  from texture import TexturePipeline, ModProcessConfig
509
  texture_pipe = TexturePipeline(
510
  upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
511
  inpaint_ckpt_path="checkpoints/big-lama.pt",
512
  device=DEVICE,
513
  )
 
514
  textured_glb_path = texture_pipe(
515
  mesh_path=mesh_path,
516
  save_dir=save_dir,
517
  save_name=f"polygenixai_texture_mesh_{get_random_hex()}.glb",
518
  uv_unwarp=True,
519
+ uv_size=2048,
520
  rgb_path=mv_image_path,
521
  rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
522
+ camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 180]],
523
  )
524
+ torch.cuda.empty_cache()
525
+ log_gpu_memory()
526
  return textured_glb_path
527
  except Exception as e:
528
+ logger.error(f"Error in run_texture: {str(e)}\n{traceback.format_exc()}")
529
  raise
530
 
531
+ @spaces.GPU(duration=3)
 
532
  @torch.no_grad()
533
+ def run_full(image, seed=0, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
534
  try:
535
+ log_gpu_memory()
536
+ image_seg = run_segmentation(image)
537
+ mesh_path = image_to_3d(image_seg, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req)
538
+ textured_glb_path = run_texture(image, mesh_path, seed, req)
539
+ return image_seg, mesh_path, textured_glb_path
540
+ except Exception as e:
541
+ logger.error(f"Error in run_full: {str(e)}\n{traceback.format_exc()}")
542
+ raise
543
+
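Called outside Gradio, the composed pipeline above runs end to end like this (path and seed are placeholders):

    # Hypothetical direct call: segmentation -> mesh -> texture, no session dir (req=None).
    seg, mesh_glb, textured_glb = run_full("input.png", seed=42)
    print(mesh_glb, textured_glb)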
544
+ def gradio_generate(image, seed=0, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER):
545
+ try:
546
+ logger.info("Starting gradio_generate")
547
+ api_key = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
548
+ request = gr.Request()
549
+ if not request.headers.get("x-api-key") == api_key:
550
+ logger.error("Invalid API key")
551
+ raise ValueError("Invalid API key")
552
+
553
+ if image.startswith("data:image"):
554
+ logger.info("Processing base64 image")
555
+ base64_string = image.split(",")[1]
556
+ image_data = base64.b64decode(base64_string)
557
  temp_image_path = os.path.join(TMP_DIR, f"input_{get_random_hex()}.png")
558
+ with open(temp_image_path, "wb") as f:
559
+ f.write(image_data)
560
  else:
561
+ temp_image_path = image
562
+ if not os.path.exists(temp_image_path):
563
+ logger.error(f"Image file not found: {temp_image_path}")
564
+ raise ValueError("Invalid or missing image file")
565
 
566
+ image_seg, mesh_path, textured_glb_path = run_full(temp_image_path, seed, num_inference_steps, guidance_scale, simplify, target_face_num, request)
567
  session_hash = os.path.basename(os.path.dirname(textured_glb_path))
568
+ logger.info(f"Generated model at /files/{session_hash}/{os.path.basename(textured_glb_path)}")
569
+ return {"file_url": f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"}
570
+ except Exception as e:
571
+ logger.error(f"Error in gradio_generate: {str(e)}\n{traceback.format_exc()}")
572
+ raise
573
+
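For the base64 branch in `gradio_generate`, a caller would build the data URI along these lines (illustrative, assuming a local PNG):

    import base64
    with open("input.png", "rb") as f:
        data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode("ascii")
    # data_uri can then be passed as the `image` argument.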
574
+ def start_session(req: gr.Request):
575
+ try:
576
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
577
+ os.makedirs(save_dir, exist_ok=True)
578
+ logger.info(f"Started session, created directory: {save_dir}")
579
+ except Exception as e:
580
+ logger.error(f"Error in start_session: {str(e)}\n{traceback.format_exc()}")
581
+ raise
582
+
583
+ def end_session(req: gr.Request):
584
+ try:
585
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
586
+ shutil.rmtree(save_dir)
587
+ logger.info(f"Ended session, removed directory: {save_dir}")
588
  except Exception as e:
589
+ logger.error(f"Error in end_session: {str(e)}\n{traceback.format_exc()}")
590
+ raise
591
+
592
+ def get_random_seed(randomize_seed, seed):
593
+ try:
594
+ if randomize_seed:
595
+ seed = random.randint(0, MAX_SEED)
596
+ logger.info(f"Generated seed: {seed}")
597
+ return seed
598
+ except Exception as e:
599
+ logger.error(f"Error in get_random_seed: {str(e)}\n{traceback.format_exc()}")
600
+ raise
601
+
602
+ def download_image(url: str, save_path: str) -> str:
603
+ try:
604
+ logger.info(f"Downloading image from {url}")
605
+ response = requests.get(url, stream=True)
606
+ response.raise_for_status()
607
+ with open(save_path, "wb") as f:
608
+ for chunk in response.iter_content(chunk_size=8192):
609
+ f.write(chunk)
610
+ logger.info(f"Saved image to {save_path}")
611
+ return save_path
612
+ except Exception as e:
613
+ logger.error(f"Failed to download image from {url}: {str(e)}\n{traceback.format_exc()}")
614
+ raise
615
+
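Illustrative call for the downloader above (URL and filename are placeholders):

    local_path = download_image("https://example.com/object.png", os.path.join(TMP_DIR, f"input_{get_random_hex()}.png"))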
616
+ @spaces.GPU(duration=3)
617
+ @torch.no_grad()
618
+ def run_full_api(image, seed=0, num_inference_steps=30, guidance_scale=7.0, simplify=True, target_face_num=DEFAULT_FACE_NUMBER, req=None):
619
+ try:
620
+ logger.info("Running run_full_api")
621
+ def execute():
622
+ image_seg, mesh_path, textured_glb_path = run_full(image, seed, num_inference_steps, guidance_scale, simplify, target_face_num, req)
623
+ session_hash = os.path.basename(os.path.dirname(textured_glb_path))
624
+ return {"file_url": f"/files/{session_hash}/{os.path.basename(textured_glb_path)}"}
625
+ return retry_on_failure(execute)
626
+ except Exception as e:
627
+ logger.error(f"Error in run_full_api: {str(e)}\n{traceback.format_exc()}")
628
  raise
629
 
630
  # Define Gradio API endpoint
 
636
  gr.Image(type="filepath", label="Image"),
637
  gr.Number(label="Seed", value=0, precision=0),
638
  gr.Number(label="Inference Steps", value=30, precision=0),
639
+ gr.Number(label="Guidance Scale", value=7.0),
640
  gr.Checkbox(label="Simplify Mesh", value=True),
641
+ gr.Number(label="Target Face Number", value=DEFAULT_FACE_NUMBER, precision=0)
 
 
 
 
 
642
  ],
643
  outputs="json",
644
  api_name="/api/generate"
645
  )
646
  logger.info("Gradio API interface initialized successfully")
647
  except Exception as e:
648
+ logger.error(f"Failed to initialize Gradio API interface: {str(e)}\n{traceback.format_exc()}")
649
+ raise
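For reference, a sketch of invoking the endpoint defined above with `gradio_client` (the host URL is a placeholder; note that `gradio_generate` also checks an `x-api-key` header, so an unmodified call like this would be rejected unless that check is relaxed or custom headers are supplied):

    from gradio_client import Client, handle_file
    client = Client("http://localhost:7860")  # hypothetical host
    result = client.predict(
        handle_file("input.png"),   # Image
        0,                          # Seed
        30,                         # Inference Steps
        7.0,                        # Guidance Scale
        True,                       # Simplify Mesh
        20000,                      # Target Face Number
        api_name="/api/generate",
    )
    print(result)  # JSON payload with "file_url"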