lionelgarnier commited on
Commit
f17d614
·
1 Parent(s): d5ea645

refactor session handling to use session-specific directories and remove unused extraction functions

Browse files
Files changed (1) hide show
  1. app.py +9 -76
app.py CHANGED
@@ -55,13 +55,13 @@ _trellis_pipeline = None
55
 
56
 
57
  def start_session(req: gr.Request):
58
- user_dir = os.path.join(TMP_DIR, "temp_output")
59
- # user_dir = os.path.join(TMP_DIR, str(req.session_hash))
60
  os.makedirs(user_dir, exist_ok=True)
61
 
62
  def end_session(req: gr.Request):
63
- user_dir = os.path.join(TMP_DIR, "temp_output")
64
- # user_dir = os.path.join(TMP_DIR, str(req.session_hash))
65
  shutil.rmtree(user_dir)
66
 
67
  def preprocess_image(image: Image.Image) -> Image.Image:
@@ -321,6 +321,7 @@ def image_to_3d(
321
  ss_sampling_steps: int,
322
  slat_guidance_strength: float,
323
  slat_sampling_steps: int,
 
324
  ) -> Tuple[dict, str]:
325
  try:
326
  # Load the Trellis pipeline
@@ -347,8 +348,9 @@ def image_to_3d(
347
  "cfg_strength": slat_guidance_strength,
348
  },
349
  )
350
- temp_dir = os.path.join(TMP_DIR, "temp_output")
351
-
 
352
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
353
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
354
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -364,51 +366,6 @@ def image_to_3d(
364
  return None, f"Error generating 3D model: {str(e)}"
365
 
366
 
367
- @spaces.GPU(duration=90)
368
- def extract_glb(
369
- state: dict,
370
- mesh_simplify: float,
371
- texture_size: int,
372
- ) -> Tuple[str, str]:
373
- """
374
- Extract a GLB file from the 3D model.
375
-
376
- Args:
377
- state (dict): The state of the generated 3D model.
378
- mesh_simplify (float): The mesh simplification factor.
379
- texture_size (int): The texture resolution.
380
-
381
- Returns:
382
- str: The path to the extracted GLB file.
383
- """
384
- temp_dir = os.path.join(TMP_DIR, "temp_output")
385
- gs, mesh = unpack_state(state)
386
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
387
- glb_path = os.path.join(temp_dir, 'sample.glb')
388
- glb.export(glb_path)
389
- torch.cuda.empty_cache()
390
- return glb_path, glb_path
391
-
392
-
393
- @spaces.GPU
394
- def extract_gaussian(state: dict) -> Tuple[str, str]:
395
- """
396
- Extract a Gaussian file from the 3D model.
397
-
398
- Args:
399
- state (dict): The state of the generated 3D model.
400
-
401
- Returns:
402
- str: The path to the extracted Gaussian file.
403
- """
404
- temp_dir = os.path.join(TMP_DIR, "temp_output")
405
- gs, _ = unpack_state(state)
406
- gaussian_path = os.path.join(temp_dir, 'sample.ply')
407
- gs.save_ply(gaussian_path)
408
- torch.cuda.empty_cache()
409
- return gaussian_path, gaussian_path
410
-
411
-
412
  # Create a combined function that handles the whole pipeline from example to image
413
  # This version gets the parameters from the UI components
414
  @spaces.GPU()
@@ -465,17 +422,10 @@ def create_interface():
465
  visual_button = gr.Button("Create visual with Flux")
466
  generated_image = gr.Image(show_label=False, format="png", image_mode="RGBA", type="pil", height=300)
467
 
468
- preprocessed_button = gr.Button("Preprocess image")
469
- preprocessed_image = gr.Image(show_label=False)
470
-
471
  gen3d_button = gr.Button("Create 3D visual with Trellis")
472
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
473
  # model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
474
 
475
- with gr.Row():
476
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
477
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
478
-
479
  message_box = gr.Textbox(
480
  label="Status Messages",
481
  interactive=False,
@@ -531,18 +481,8 @@ def create_interface():
531
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
532
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
533
 
534
- with gr.Tab("GLB Extraction Settings"):
535
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
536
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
537
-
538
- with gr.Row():
539
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
540
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
541
- gr.Markdown("""
542
- *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
543
- """)
544
 
545
- output_buf = gr.State()
546
 
547
  gr.Examples(
548
  examples=examples,
@@ -566,13 +506,6 @@ def create_interface():
566
  outputs=[generated_image, message_box]
567
  )
568
 
569
- gr.on(
570
- triggers=[preprocessed_button.click],
571
- fn=preprocess_image,
572
- inputs=[generated_image],
573
- outputs=[preprocessed_image]
574
- )
575
-
576
  gr.on(
577
  triggers=[gen3d_button.click],
578
  fn=image_to_3d,
 
55
 
56
 
57
  def start_session(req: gr.Request):
58
+ # user_dir = os.path.join(TMP_DIR, "temp_output")
59
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
60
  os.makedirs(user_dir, exist_ok=True)
61
 
62
  def end_session(req: gr.Request):
63
+ # user_dir = os.path.join(TMP_DIR, "temp_output")
64
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
65
  shutil.rmtree(user_dir)
66
 
67
  def preprocess_image(image: Image.Image) -> Image.Image:
 
321
  ss_sampling_steps: int,
322
  slat_guidance_strength: float,
323
  slat_sampling_steps: int,
324
+ req: gr.Request,
325
  ) -> Tuple[dict, str]:
326
  try:
327
  # Load the Trellis pipeline
 
348
  "cfg_strength": slat_guidance_strength,
349
  },
350
  )
351
+ # temp_dir = os.path.join(TMP_DIR, "temp_output")
352
+ temp_dir = os.path.join(TMP_DIR, str(req.session_hash))
353
+
354
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
355
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
356
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
366
  return None, f"Error generating 3D model: {str(e)}"
367
 
368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  # Create a combined function that handles the whole pipeline from example to image
370
  # This version gets the parameters from the UI components
371
  @spaces.GPU()
 
422
  visual_button = gr.Button("Create visual with Flux")
423
  generated_image = gr.Image(show_label=False, format="png", image_mode="RGBA", type="pil", height=300)
424
 
 
 
 
425
  gen3d_button = gr.Button("Create 3D visual with Trellis")
426
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
427
  # model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
428
 
 
 
 
 
429
  message_box = gr.Textbox(
430
  label="Status Messages",
431
  interactive=False,
 
481
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
482
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
483
 
 
 
 
 
 
 
 
 
 
 
484
 
485
+ # output_buf = gr.State()
486
 
487
  gr.Examples(
488
  examples=examples,
 
506
  outputs=[generated_image, message_box]
507
  )
508
 
 
 
 
 
 
 
 
509
  gr.on(
510
  triggers=[gen3d_button.click],
511
  fn=image_to_3d,