lionelgarnier committed
Commit 74f1ee7 · Parent: a5cc2b7

switch to deepseek

Files changed (1):
  1. app.py +6 -40
app.py CHANGED
@@ -92,9 +92,6 @@ def get_image_gen_pipeline():
             torch_dtype=dtype,
         ).to(device)
 
-        # Comment these out for now to match the working example
-        # _image_gen_pipeline.enable_model_cpu_offload()
-        # _image_gen_pipeline.enable_vae_slicing()
     except Exception as e:
         print(f"Error loading image generation model: {e}")
         return None
@@ -107,14 +104,16 @@ def get_text_gen_pipeline():
     try:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         tokenizer = AutoTokenizer.from_pretrained(
-            "mistralai/Mistral-7B-Instruct-v0.3",
+            # "mistralai/Mistral-7B-Instruct-v0.3",
+            "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
             use_fast=True
         )
         tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
 
         _text_gen_pipeline = pipeline(
-            "text-generation",
-            model="mistralai/Mistral-7B-Instruct-v0.3",
+            # "text-generation",
+            # model="mistralai/Mistral-7B-Instruct-v0.3",
+            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
             tokenizer=tokenizer,
             max_new_tokens=2048,
             device=device,
@@ -375,7 +374,6 @@ def create_interface():
 
         gen3d_button = gr.Button("Create 3D visual with Trellis")
         video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
-        # model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
 
         message_box = gr.Textbox(
             label="Status Messages",
@@ -432,9 +430,6 @@ def create_interface():
         slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
         slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
-
-        # output_buf = gr.State()
-
         gr.Examples(
             examples=examples,
             fn=process_example_pipeline,
@@ -463,36 +458,7 @@ def create_interface():
             inputs=[generated_image, trellis_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
             outputs=[output_state, video_output],
         )
-        # .then(
-        #     # Update button states after successful 3D generation
-        #     lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True), "3D model generated successfully"),
-        #     outputs=[extract_glb_btn, extract_gs_btn, message_box]
-        # )
-
-        # # Add handlers for GLB and Gaussian extraction
-        # gr.on(
-        #     triggers=[extract_glb_btn.click],
-        #     fn=extract_glb,
-        #     inputs=[output_state, mesh_simplify, texture_size],
-        #     outputs=[model_output, download_glb]
-        # ).then(
-        #     lambda path: (gr.DownloadButton.update(interactive=True, value=path), "GLB extraction completed"),
-        #     inputs=[model_output],
-        #     outputs=[download_glb, message_box]
-        # )
-
-        # gr.on(
-        #     triggers=[extract_gs_btn.click],
-        #     fn=extract_gaussian,
-        #     inputs=[output_state],
-        #     outputs=[model_output, download_gs]
-        # ).then(
-        #     lambda path: (gr.DownloadButton.update(interactive=True, value=path), "Gaussian extraction completed"),
-        #     inputs=[model_output],
-        #     outputs=[download_gs, message_box]
-        # )
-
-        # Don't put any demo.* method calls here outside the Blocks context
+
     return demo
 
 if __name__ == "__main__":
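For reference, here is a minimal sketch of what the text-generation loader amounts to once this hunk is applied. Only the model id, the use_fast and pad_token handling, max_new_tokens=2048, and the device selection come from the diff; the imports, the module-level caching guard, and the error message are assumptions about surrounding app.py code that this commit does not touch.

import torch
from transformers import AutoTokenizer, pipeline

_text_gen_pipeline = None  # module-level cache (name appears in the diff; the guard below is assumed)

def get_text_gen_pipeline():
    global _text_gen_pipeline
    if _text_gen_pipeline is None:  # assumed lazy-load guard
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            tokenizer = AutoTokenizer.from_pretrained(
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
                use_fast=True,
            )
            # Qwen-based tokenizers ship without a pad token; fall back to EOS.
            tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

            _text_gen_pipeline = pipeline(
                # The explicit "text-generation" task string is commented out in
                # this commit; transformers infers the task from the model id.
                model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
                tokenizer=tokenizer,
                max_new_tokens=2048,
                device=device,
            )
        except Exception as e:  # assumed, mirrors the image-loader error handling
            print(f"Error loading text generation model: {e}")
            return None
    return _text_gen_pipeline

The lazy global keeps the 32B checkpoint from being reloaded on every request; note that at fp16/bf16 it needs roughly 64 GB of accelerator memory unless it is quantized or sharded.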
 
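The prompt-handling side of app.py is not part of this diff, but a typical call against the cached pipeline would look like the sketch below; the message content is a made-up illustration.

text_gen = get_text_gen_pipeline()
if text_gen is not None:
    # Chat-style input: the pipeline applies the model's chat template.
    messages = [{"role": "user", "content": "Describe a 3D asset: a red toy car."}]
    outputs = text_gen(messages)
    # With chat input, generated_text is the conversation including the new
    # assistant turn; DeepSeek-R1 distills also emit <think>...</think> reasoning
    # that the app may want to strip before display.
    reply = outputs[0]["generated_text"][-1]["content"]
    print(reply)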
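Besides the model swap, the first hunk also drops the commented-out diffusers memory options (enable_model_cpu_offload and enable_vae_slicing). Should they ever need to be re-enabled, they are applied as in the sketch below; the checkpoint id is a placeholder, since the actual image model used by app.py is not visible in this diff.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  # placeholder checkpoint, not the app's model
    torch_dtype=torch.float16,
)

# Streams submodules between CPU and GPU on demand (requires accelerate);
# lower peak VRAM at the cost of speed. Do not also call .to("cuda") with this on.
pipe.enable_model_cpu_offload()

# Decode the latents in slices to reduce peak memory during VAE decoding.
pipe.enable_vae_slicing()

image = pipe("a red toy car, studio lighting").images[0]
image.save("car.png")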