lionelgarnier commited on
Commit
d6da646
·
1 Parent(s): 5ed71f8

add Trellis pipeline integration for 3D model generation and improve error handling

Browse files
Files changed (1) hide show
  1. app.py +80 -52
app.py CHANGED
@@ -50,6 +50,7 @@ os.makedirs(TMP_DIR, exist_ok=True)
50
 
51
  _text_gen_pipeline = None
52
  _image_gen_pipeline = None
 
53
 
54
 
55
  def start_session(req: gr.Request):
@@ -107,6 +108,25 @@ def get_text_gen_pipeline():
107
  return None
108
  return _text_gen_pipeline
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  @spaces.GPU()
111
  def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
112
  text_gen = get_text_gen_pipeline()
@@ -217,11 +237,29 @@ def preload_models():
217
  print("Preloading models...")
218
  text_success = get_text_gen_pipeline() is not None
219
  image_success = get_image_gen_pipeline() is not None
220
- success = text_success and image_success
 
 
221
 
222
- status = "Models preloaded successfully!" if success else "Error preloading models"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  print(status)
224
- return success
225
 
226
 
227
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
@@ -274,46 +312,40 @@ def image_to_3d(
274
  slat_sampling_steps: int,
275
  req: gr.Request,
276
  ) -> Tuple[dict, str]:
277
- """
278
- Convert an image to a 3D model.
279
-
280
- Args:
281
- image (Image.Image): The input image.
282
- seed (int): The random seed.
283
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
284
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
285
- slat_guidance_strength (float): The guidance strength for structured latent generation.
286
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
287
-
288
- Returns:
289
- dict: The information of the generated 3D model.
290
- str: The path to the video of the 3D model.
291
- """
292
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
293
-
294
- outputs = pipeline.run(
295
- image,
296
- seed=seed,
297
- formats=["gaussian", "mesh"],
298
- preprocess_image=False,
299
- sparse_structure_sampler_params={
300
- "steps": ss_sampling_steps,
301
- "cfg_strength": ss_guidance_strength,
302
- },
303
- slat_sampler_params={
304
- "steps": slat_sampling_steps,
305
- "cfg_strength": slat_guidance_strength,
306
- },
307
- )
308
 
309
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
310
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
311
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
312
- video_path = os.path.join(user_dir, 'sample.mp4')
313
- imageio.mimsave(video_path, video, fps=15)
314
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
315
- torch.cuda.empty_cache()
316
- return state, video_path
 
 
 
317
 
318
 
319
  @spaces.GPU(duration=90)
@@ -382,8 +414,8 @@ def process_example_pipeline(example_prompt, system_prompt=DEFAULT_SYSTEM_PROMPT
382
  def create_interface():
383
  # Preload models if needed
384
  if PRELOAD_MODELS:
385
- models_loaded = preload_models()
386
- model_status = "✅ Models loaded successfully!" if models_loaded else "⚠️ Error loading models"
387
  else:
388
  model_status = "ℹ️ Models will be loaded on demand"
389
 
@@ -520,14 +552,10 @@ def create_interface():
520
 
521
 
522
  if __name__ == "__main__":
523
- # Initialize the Trellis pipeline before creating the interface
524
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
525
- pipeline.cuda()
526
- try:
527
- # Preload rembg
528
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
529
- except Exception as e:
530
- print(f"Warning when preloading rembg: {e}")
531
 
532
  demo = create_interface()
533
  demo.launch()
 
50
 
51
  _text_gen_pipeline = None
52
  _image_gen_pipeline = None
53
+ _trellis_pipeline = None
54
 
55
 
56
  def start_session(req: gr.Request):
 
108
  return None
109
  return _text_gen_pipeline
110
 
111
+ @spaces.GPU()
112
+ def get_trellis_pipeline():
113
+ global _trellis_pipeline
114
+ if _trellis_pipeline is None:
115
+ try:
116
+ print("Loading Trellis pipeline...")
117
+ _trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
118
+ _trellis_pipeline.cuda()
119
+
120
+ # Preload rembg by processing a small test image
121
+ try:
122
+ _trellis_pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
123
+ except Exception as e:
124
+ print(f"Warning when preloading rembg: {e}")
125
+ except Exception as e:
126
+ print(f"Error loading Trellis pipeline: {e}")
127
+ return None
128
+ return _trellis_pipeline
129
+
130
  @spaces.GPU()
131
  def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
132
  text_gen = get_text_gen_pipeline()
 
237
  print("Preloading models...")
238
  text_success = get_text_gen_pipeline() is not None
239
  image_success = get_image_gen_pipeline() is not None
240
+ trellis_success = get_trellis_pipeline() is not None
241
+
242
+ success = text_success and image_success and trellis_success
243
 
244
+ status_parts = []
245
+ if text_success:
246
+ status_parts.append("Mistral ✓")
247
+ else:
248
+ status_parts.append("Mistral ✗")
249
+
250
+ if image_success:
251
+ status_parts.append("Flux ✓")
252
+ else:
253
+ status_parts.append("Flux ✗")
254
+
255
+ if trellis_success:
256
+ status_parts.append("Trellis ✓")
257
+ else:
258
+ status_parts.append("Trellis ✗")
259
+
260
+ status = f"Models loaded: {', '.join(status_parts)}"
261
  print(status)
262
+ return success, status
263
 
264
 
265
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
 
312
  slat_sampling_steps: int,
313
  req: gr.Request,
314
  ) -> Tuple[dict, str]:
315
+ try:
316
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
317
+
318
+ # Get the pipeline using the getter function
319
+ pipeline = get_trellis_pipeline()
320
+ if pipeline is None:
321
+ return None, "Trellis pipeline is unavailable."
322
+
323
+ outputs = pipeline.run(
324
+ image,
325
+ seed=seed,
326
+ formats=["gaussian", "mesh"],
327
+ preprocess_image=False,
328
+ sparse_structure_sampler_params={
329
+ "steps": ss_sampling_steps,
330
+ "cfg_strength": ss_guidance_strength,
331
+ },
332
+ slat_sampler_params={
333
+ "steps": slat_sampling_steps,
334
+ "cfg_strength": slat_guidance_strength,
335
+ },
336
+ )
 
 
 
 
 
 
 
 
 
337
 
338
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
339
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
340
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
341
+ video_path = os.path.join(user_dir, 'sample.mp4')
342
+ imageio.mimsave(video_path, video, fps=15)
343
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
344
+ torch.cuda.empty_cache()
345
+ return state, video_path
346
+ except Exception as e:
347
+ print(f"Error in image_to_3d: {str(e)}")
348
+ return None, f"Error generating 3D model: {str(e)}"
349
 
350
 
351
  @spaces.GPU(duration=90)
 
414
  def create_interface():
415
  # Preload models if needed
416
  if PRELOAD_MODELS:
417
+ model_success, model_status_details = preload_models()
418
+ model_status = f"✅ {model_status_details}" if model_success else f"⚠️ {model_status_details}"
419
  else:
420
  model_status = "ℹ️ Models will be loaded on demand"
421
 
 
552
 
553
 
554
  if __name__ == "__main__":
555
+ # Initialize models if PRELOAD_MODELS is True
556
+ if PRELOAD_MODELS:
557
+ success, status = preload_models()
558
+ print(status)
 
 
 
 
559
 
560
  demo = create_interface()
561
  demo.launch()