Spaces:
Runtime error
Runtime error
lionelgarnier
commited on
Commit
·
d6da646
1
Parent(s):
5ed71f8
add Trellis pipeline integration for 3D model generation and improve error handling
Browse files
app.py
CHANGED
@@ -50,6 +50,7 @@ os.makedirs(TMP_DIR, exist_ok=True)
|
|
50 |
|
51 |
_text_gen_pipeline = None
|
52 |
_image_gen_pipeline = None
|
|
|
53 |
|
54 |
|
55 |
def start_session(req: gr.Request):
|
@@ -107,6 +108,25 @@ def get_text_gen_pipeline():
|
|
107 |
return None
|
108 |
return _text_gen_pipeline
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
@spaces.GPU()
|
111 |
def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
|
112 |
text_gen = get_text_gen_pipeline()
|
@@ -217,11 +237,29 @@ def preload_models():
|
|
217 |
print("Preloading models...")
|
218 |
text_success = get_text_gen_pipeline() is not None
|
219 |
image_success = get_image_gen_pipeline() is not None
|
220 |
-
|
|
|
|
|
221 |
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
print(status)
|
224 |
-
return success
|
225 |
|
226 |
|
227 |
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
@@ -274,46 +312,40 @@ def image_to_3d(
|
|
274 |
slat_sampling_steps: int,
|
275 |
req: gr.Request,
|
276 |
) -> Tuple[dict, str]:
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
sparse_structure_sampler_params={
|
300 |
-
"steps": ss_sampling_steps,
|
301 |
-
"cfg_strength": ss_guidance_strength,
|
302 |
-
},
|
303 |
-
slat_sampler_params={
|
304 |
-
"steps": slat_sampling_steps,
|
305 |
-
"cfg_strength": slat_guidance_strength,
|
306 |
-
},
|
307 |
-
)
|
308 |
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
|
|
|
|
|
|
317 |
|
318 |
|
319 |
@spaces.GPU(duration=90)
|
@@ -382,8 +414,8 @@ def process_example_pipeline(example_prompt, system_prompt=DEFAULT_SYSTEM_PROMPT
|
|
382 |
def create_interface():
|
383 |
# Preload models if needed
|
384 |
if PRELOAD_MODELS:
|
385 |
-
|
386 |
-
model_status = "✅
|
387 |
else:
|
388 |
model_status = "ℹ️ Models will be loaded on demand"
|
389 |
|
@@ -520,14 +552,10 @@ def create_interface():
|
|
520 |
|
521 |
|
522 |
if __name__ == "__main__":
|
523 |
-
# Initialize
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
# Preload rembg
|
528 |
-
pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
|
529 |
-
except Exception as e:
|
530 |
-
print(f"Warning when preloading rembg: {e}")
|
531 |
|
532 |
demo = create_interface()
|
533 |
demo.launch()
|
|
|
50 |
|
51 |
_text_gen_pipeline = None
|
52 |
_image_gen_pipeline = None
|
53 |
+
_trellis_pipeline = None
|
54 |
|
55 |
|
56 |
def start_session(req: gr.Request):
|
|
|
108 |
return None
|
109 |
return _text_gen_pipeline
|
110 |
|
111 |
+
@spaces.GPU()
|
112 |
+
def get_trellis_pipeline():
|
113 |
+
global _trellis_pipeline
|
114 |
+
if _trellis_pipeline is None:
|
115 |
+
try:
|
116 |
+
print("Loading Trellis pipeline...")
|
117 |
+
_trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
|
118 |
+
_trellis_pipeline.cuda()
|
119 |
+
|
120 |
+
# Preload rembg by processing a small test image
|
121 |
+
try:
|
122 |
+
_trellis_pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
|
123 |
+
except Exception as e:
|
124 |
+
print(f"Warning when preloading rembg: {e}")
|
125 |
+
except Exception as e:
|
126 |
+
print(f"Error loading Trellis pipeline: {e}")
|
127 |
+
return None
|
128 |
+
return _trellis_pipeline
|
129 |
+
|
130 |
@spaces.GPU()
|
131 |
def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
|
132 |
text_gen = get_text_gen_pipeline()
|
|
|
237 |
print("Preloading models...")
|
238 |
text_success = get_text_gen_pipeline() is not None
|
239 |
image_success = get_image_gen_pipeline() is not None
|
240 |
+
trellis_success = get_trellis_pipeline() is not None
|
241 |
+
|
242 |
+
success = text_success and image_success and trellis_success
|
243 |
|
244 |
+
status_parts = []
|
245 |
+
if text_success:
|
246 |
+
status_parts.append("Mistral ✓")
|
247 |
+
else:
|
248 |
+
status_parts.append("Mistral ✗")
|
249 |
+
|
250 |
+
if image_success:
|
251 |
+
status_parts.append("Flux ✓")
|
252 |
+
else:
|
253 |
+
status_parts.append("Flux ✗")
|
254 |
+
|
255 |
+
if trellis_success:
|
256 |
+
status_parts.append("Trellis ✓")
|
257 |
+
else:
|
258 |
+
status_parts.append("Trellis ✗")
|
259 |
+
|
260 |
+
status = f"Models loaded: {', '.join(status_parts)}"
|
261 |
print(status)
|
262 |
+
return success, status
|
263 |
|
264 |
|
265 |
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
|
|
|
312 |
slat_sampling_steps: int,
|
313 |
req: gr.Request,
|
314 |
) -> Tuple[dict, str]:
|
315 |
+
try:
|
316 |
+
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
317 |
+
|
318 |
+
# Get the pipeline using the getter function
|
319 |
+
pipeline = get_trellis_pipeline()
|
320 |
+
if pipeline is None:
|
321 |
+
return None, "Trellis pipeline is unavailable."
|
322 |
+
|
323 |
+
outputs = pipeline.run(
|
324 |
+
image,
|
325 |
+
seed=seed,
|
326 |
+
formats=["gaussian", "mesh"],
|
327 |
+
preprocess_image=False,
|
328 |
+
sparse_structure_sampler_params={
|
329 |
+
"steps": ss_sampling_steps,
|
330 |
+
"cfg_strength": ss_guidance_strength,
|
331 |
+
},
|
332 |
+
slat_sampler_params={
|
333 |
+
"steps": slat_sampling_steps,
|
334 |
+
"cfg_strength": slat_guidance_strength,
|
335 |
+
},
|
336 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
|
338 |
+
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
339 |
+
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
340 |
+
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
341 |
+
video_path = os.path.join(user_dir, 'sample.mp4')
|
342 |
+
imageio.mimsave(video_path, video, fps=15)
|
343 |
+
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
344 |
+
torch.cuda.empty_cache()
|
345 |
+
return state, video_path
|
346 |
+
except Exception as e:
|
347 |
+
print(f"Error in image_to_3d: {str(e)}")
|
348 |
+
return None, f"Error generating 3D model: {str(e)}"
|
349 |
|
350 |
|
351 |
@spaces.GPU(duration=90)
|
|
|
414 |
def create_interface():
|
415 |
# Preload models if needed
|
416 |
if PRELOAD_MODELS:
|
417 |
+
model_success, model_status_details = preload_models()
|
418 |
+
model_status = f"✅ {model_status_details}" if model_success else f"⚠️ {model_status_details}"
|
419 |
else:
|
420 |
model_status = "ℹ️ Models will be loaded on demand"
|
421 |
|
|
|
552 |
|
553 |
|
554 |
if __name__ == "__main__":
|
555 |
+
# Initialize models if PRELOAD_MODELS is True
|
556 |
+
if PRELOAD_MODELS:
|
557 |
+
success, status = preload_models()
|
558 |
+
print(status)
|
|
|
|
|
|
|
|
|
559 |
|
560 |
demo = create_interface()
|
561 |
demo.launch()
|