lionelgarnier commited on
Commit
593768d
·
1 Parent(s): d6da646

refactor interface to improve session management and state handling for 3D model extraction

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -219,7 +219,6 @@ def infer(prompt, seed=DEFAULT_SEED,
219
  return None, f"Error generating image: {str(e)}"
220
 
221
 
222
- # Format: [prompt, system_prompt]
223
  examples = [
224
  "a backpack for kids, flower style",
225
  "medieval flip flops",
@@ -420,7 +419,14 @@ def create_interface():
420
  model_status = "ℹ️ Models will be loaded on demand"
421
 
422
  with gr.Blocks(css=css) as demo:
 
 
 
 
423
  gr.Info(model_status)
 
 
 
424
 
425
  with gr.Column(elem_id="col-container"):
426
  gr.Markdown("# Text to Product\nUsing Mistral-7B-Instruct-v0.3 + FLUX.1-dev + Trellis")
@@ -511,6 +517,13 @@ def create_interface():
511
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
512
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
513
 
 
 
 
 
 
 
 
514
  output_buf = gr.State()
515
 
516
  # Examples section - simplified version that only updates the prompt fields
@@ -537,17 +550,41 @@ def create_interface():
537
  outputs=[generated_image, message_box]
538
  )
539
 
 
540
  gr.on(
541
  triggers=[gen3d_button.click],
542
  fn=image_to_3d,
543
  inputs=[generated_image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
544
- outputs=[output_buf, video_output],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  )
546
 
547
- # Handlers
548
- demo.load(start_session)
549
- demo.unload(end_session)
550
-
551
  return demo
552
 
553
 
 
219
  return None, f"Error generating image: {str(e)}"
220
 
221
 
 
222
  examples = [
223
  "a backpack for kids, flower style",
224
  "medieval flip flops",
 
419
  model_status = "ℹ️ Models will be loaded on demand"
420
 
421
  with gr.Blocks(css=css) as demo:
422
+ # Set up session management
423
+ demo.load(start_session)
424
+ demo.unload(end_session)
425
+
426
  gr.Info(model_status)
427
+
428
+ # State for storing 3D model data - moved to the top level inside Blocks context
429
+ output_state = gr.State(None)
430
 
431
  with gr.Column(elem_id="col-container"):
432
  gr.Markdown("# Text to Product\nUsing Mistral-7B-Instruct-v0.3 + FLUX.1-dev + Trellis")
 
517
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
518
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
519
 
520
+ with gr.Row():
521
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
522
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
523
+ gr.Markdown("""
524
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
525
+ """)
526
+
527
  output_buf = gr.State()
528
 
529
  # Examples section - simplified version that only updates the prompt fields
 
550
  outputs=[generated_image, message_box]
551
  )
552
 
553
+ # Updated to use output_state
554
  gr.on(
555
  triggers=[gen3d_button.click],
556
  fn=image_to_3d,
557
  inputs=[generated_image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
558
+ outputs=[output_state, video_output],
559
+ ).success(
560
+ # Update button states after successful 3D generation
561
+ lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True)),
562
+ outputs=[extract_glb_btn, extract_gs_btn]
563
+ )
564
+
565
+ # Add handlers for GLB and Gaussian extraction
566
+ gr.on(
567
+ triggers=[extract_glb_btn.click],
568
+ fn=extract_glb,
569
+ inputs=[output_state, mesh_simplify, texture_size],
570
+ outputs=[model_output, download_glb]
571
+ ).success(
572
+ lambda path: gr.DownloadButton.update(interactive=True, value=path),
573
+ inputs=[model_output],
574
+ outputs=[download_glb]
575
+ )
576
+
577
+ gr.on(
578
+ triggers=[extract_gs_btn.click],
579
+ fn=extract_gaussian,
580
+ inputs=[output_state],
581
+ outputs=[model_output, download_gs]
582
+ ).success(
583
+ lambda path: gr.DownloadButton.update(interactive=True, value=path),
584
+ inputs=[model_output],
585
+ outputs=[download_gs]
586
  )
587
 
 
 
 
 
588
  return demo
589
 
590