anvilinteractiv commited on
Commit
e0d7b74
·
verified ·
1 Parent(s): c484979

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -26
app.py CHANGED
@@ -13,6 +13,7 @@ import subprocess
13
  import shutil
14
  import base64
15
  import logging
 
16
 
17
  # Set up logging
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -58,7 +59,6 @@ sys.path.append(MV_ADAPTER_CODE_DIR)
58
  sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
59
 
60
  try:
61
- # triposg
62
  from image_process import prepare_image
63
  from briarmbg import BriaRMBG
64
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
@@ -72,7 +72,6 @@ except Exception as e:
72
  raise
73
 
74
  try:
75
- # mv adapter
76
  NUM_VIEWS = 6
77
  from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
78
  from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
@@ -144,7 +143,7 @@ def run_full(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_
144
 
145
  torch.cuda.empty_cache()
146
 
147
- height, width = 768, 768
148
  cameras = get_orthogonal_camera(
149
  elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
150
  distance=[1.8] * NUM_VIEWS,
@@ -168,13 +167,7 @@ def run_full(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_
168
  normal_background=0.0,
169
  )
170
  control_images = (
171
- torch.cat(
172
- [
173
- (render_out.pos + 0.5).clamp(0, 1),
174
- (render_out.normal / 2 + 0.5).clamp(0, 1),
175
- ],
176
- dim=-1,
177
- )
178
  .permute(0, 3, 1, 2)
179
  .to(DEVICE)
180
  )
@@ -234,14 +227,12 @@ def run_full(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_
234
  def gradio_generate(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER):
235
  try:
236
  logger.info("Starting gradio_generate")
237
- # Verify API key
238
  api_key = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
239
  request = gr.Request()
240
  if not request.headers.get("x-api-key") == api_key:
241
  logger.error("Invalid API key")
242
  raise ValueError("Invalid API key")
243
 
244
- # Handle base64 image or file path
245
  if image.startswith("data:image"):
246
  logger.info("Processing base64 image")
247
  base64_string = image.split(",")[1]
@@ -291,9 +282,7 @@ def get_random_seed(randomize_seed, seed):
291
  logger.error(f"Error in get_random_seed: {str(e)}")
292
  raise
293
 
294
-
295
  def download_image(url: str, save_path: str) -> str:
296
- """Download an image from a URL and save it locally."""
297
  try:
298
  logger.info(f"Downloading image from {url}")
299
  response = requests.get(url, stream=True)
@@ -312,7 +301,6 @@ def download_image(url: str, save_path: str) -> str:
312
  def run_segmentation(image):
313
  try:
314
  logger.info("Running segmentation")
315
- # Handle FileData dict or URL
316
  if isinstance(image, dict):
317
  image_path = image.get("path") or image.get("url")
318
  if not image_path:
@@ -340,7 +328,7 @@ def run_segmentation(image):
340
  @spaces.GPU(duration=5)
341
  @torch.no_grad()
342
  def image_to_3d(
343
- image, # Changed to accept FileData dict or PIL Image
344
  seed: int,
345
  num_inference_steps: int,
346
  guidance_scale: float,
@@ -350,7 +338,6 @@ def image_to_3d(
350
  ):
351
  try:
352
  logger.info("Running image_to_3d")
353
- # Handle FileData dict from gradio_client
354
  if isinstance(image, dict):
355
  image_path = image.get("path") or image.get("url")
356
  if not image_path:
@@ -396,7 +383,7 @@ def image_to_3d(
396
  def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
397
  try:
398
  logger.info("Running texture generation")
399
- height, width = 768, 768
400
  cameras = get_orthogonal_camera(
401
  elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
402
  distance=[1.8] * NUM_VIEWS,
@@ -420,13 +407,7 @@ def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
420
  normal_background=0.0,
421
  )
422
  control_images = (
423
- torch.cat(
424
- [
425
- (render_out.pos + 0.5).clamp(0, 1),
426
- (render_out.normal / 2 + 0.5).clamp(0, 1),
427
- ],
428
- dim=-1,
429
- )
430
  .permute(0, 3, 1, 2)
431
  .to(DEVICE)
432
  )
@@ -490,7 +471,6 @@ def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
490
  def run_full_api(image, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, req: gr.Request = None):
491
  try:
492
  logger.info("Running run_full_api")
493
- # Handle FileData dict or URL
494
  if isinstance(image, dict):
495
  image_path = image.get("path") or image.get("url")
496
  if not image_path:
 
13
  import shutil
14
  import base64
15
  import logging
16
+ import requests
17
 
18
  # Set up logging
19
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
59
  sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
60
 
61
  try:
 
62
  from image_process import prepare_image
63
  from briarmbg import BriaRMBG
64
  snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
 
72
  raise
73
 
74
  try:
 
75
  NUM_VIEWS = 6
76
  from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
77
  from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
 
143
 
144
  torch.cuda.empty_cache()
145
 
146
+ height, width = 1920, 1080 # Set resolution for YouTube Shorts, TikTok, Reels
147
  cameras = get_orthogonal_camera(
148
  elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
149
  distance=[1.8] * NUM_VIEWS,
 
167
  normal_background=0.0,
168
  )
169
  control_images = (
170
+ (render_out.pos + 0.5).clamp(0, 1) # Use only position map, remove normal map
 
 
 
 
 
 
171
  .permute(0, 3, 1, 2)
172
  .to(DEVICE)
173
  )
 
227
  def gradio_generate(image: str, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER):
228
  try:
229
  logger.info("Starting gradio_generate")
 
230
  api_key = os.getenv("POLYGENIX_API_KEY", "your-secret-api-key")
231
  request = gr.Request()
232
  if not request.headers.get("x-api-key") == api_key:
233
  logger.error("Invalid API key")
234
  raise ValueError("Invalid API key")
235
 
 
236
  if image.startswith("data:image"):
237
  logger.info("Processing base64 image")
238
  base64_string = image.split(",")[1]
 
282
  logger.error(f"Error in get_random_seed: {str(e)}")
283
  raise
284
 
 
285
  def download_image(url: str, save_path: str) -> str:
 
286
  try:
287
  logger.info(f"Downloading image from {url}")
288
  response = requests.get(url, stream=True)
 
301
  def run_segmentation(image):
302
  try:
303
  logger.info("Running segmentation")
 
304
  if isinstance(image, dict):
305
  image_path = image.get("path") or image.get("url")
306
  if not image_path:
 
328
  @spaces.GPU(duration=5)
329
  @torch.no_grad()
330
  def image_to_3d(
331
+ image,
332
  seed: int,
333
  num_inference_steps: int,
334
  guidance_scale: float,
 
338
  ):
339
  try:
340
  logger.info("Running image_to_3d")
 
341
  if isinstance(image, dict):
342
  image_path = image.get("path") or image.get("url")
343
  if not image_path:
 
383
  def run_texture(image: Image, mesh_path: str, seed: int, req: gr.Request):
384
  try:
385
  logger.info("Running texture generation")
386
+ height, width = 1920, 1080 # Set resolution for YouTube Shorts, TikTok, Reels
387
  cameras = get_orthogonal_camera(
388
  elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
389
  distance=[1.8] * NUM_VIEWS,
 
407
  normal_background=0.0,
408
  )
409
  control_images = (
410
+ (render_out.pos + 0.5).clamp(0, 1) # Use only position map, remove normal map
 
 
 
 
 
 
411
  .permute(0, 3, 1, 2)
412
  .to(DEVICE)
413
  )
 
471
  def run_full_api(image, seed: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, simplify: bool = True, target_face_num: int = DEFAULT_FACE_NUMBER, req: gr.Request = None):
472
  try:
473
  logger.info("Running run_full_api")
 
474
  if isinstance(image, dict):
475
  image_path = image.get("path") or image.get("url")
476
  if not image_path: