Spaces:
Runtime error
Runtime error
lionelgarnier
commited on
Commit
·
0aaba44
1
Parent(s):
cc9aa28
refactor session handling to use a fixed temporary directory and remove session management functions
Browse files
app.py
CHANGED
@@ -53,27 +53,6 @@ _image_gen_pipeline = None
|
|
53 |
_trellis_pipeline = None
|
54 |
|
55 |
|
56 |
-
def start_session(req: gr.Request):
|
57 |
-
"""Create a temporary directory for the user session"""
|
58 |
-
try:
|
59 |
-
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
60 |
-
os.makedirs(user_dir, exist_ok=True)
|
61 |
-
print(f"Session started: {req.session_hash}")
|
62 |
-
except Exception as e:
|
63 |
-
print(f"Error starting session: {str(e)}")
|
64 |
-
|
65 |
-
|
66 |
-
def end_session(req: gr.Request):
|
67 |
-
"""Clean up the temporary directory when the session ends"""
|
68 |
-
try:
|
69 |
-
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
70 |
-
if os.path.exists(user_dir):
|
71 |
-
shutil.rmtree(user_dir)
|
72 |
-
print(f"Session ended: {req.session_hash}")
|
73 |
-
except Exception as e:
|
74 |
-
print(f"Error ending session: {str(e)}")
|
75 |
-
|
76 |
-
|
77 |
@spaces.GPU()
|
78 |
def get_image_gen_pipeline():
|
79 |
global _image_gen_pipeline
|
@@ -320,10 +299,11 @@ def image_to_3d(
|
|
320 |
ss_sampling_steps: int,
|
321 |
slat_guidance_strength: float,
|
322 |
slat_sampling_steps: int,
|
323 |
-
req: gr.Request,
|
324 |
) -> Tuple[dict, str]:
|
325 |
try:
|
326 |
-
|
|
|
|
|
327 |
|
328 |
# Get the pipeline using the getter function
|
329 |
pipeline = get_trellis_pipeline()
|
@@ -348,7 +328,7 @@ def image_to_3d(
|
|
348 |
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
349 |
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
350 |
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
351 |
-
video_path = os.path.join(
|
352 |
imageio.mimsave(video_path, video, fps=15)
|
353 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
354 |
torch.cuda.empty_cache()
|
@@ -363,7 +343,6 @@ def extract_glb(
|
|
363 |
state: dict,
|
364 |
mesh_simplify: float,
|
365 |
texture_size: int,
|
366 |
-
req: gr.Request,
|
367 |
) -> Tuple[str, str]:
|
368 |
"""
|
369 |
Extract a GLB file from the 3D model.
|
@@ -376,17 +355,17 @@ def extract_glb(
|
|
376 |
Returns:
|
377 |
str: The path to the extracted GLB file.
|
378 |
"""
|
379 |
-
|
380 |
gs, mesh = unpack_state(state)
|
381 |
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
382 |
-
glb_path = os.path.join(
|
383 |
glb.export(glb_path)
|
384 |
torch.cuda.empty_cache()
|
385 |
return glb_path, glb_path
|
386 |
|
387 |
|
388 |
@spaces.GPU
|
389 |
-
def extract_gaussian(state: dict
|
390 |
"""
|
391 |
Extract a Gaussian file from the 3D model.
|
392 |
|
@@ -396,9 +375,9 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
|
|
396 |
Returns:
|
397 |
str: The path to the extracted Gaussian file.
|
398 |
"""
|
399 |
-
|
400 |
gs, _ = unpack_state(state)
|
401 |
-
gaussian_path = os.path.join(
|
402 |
gs.save_ply(gaussian_path)
|
403 |
torch.cuda.empty_cache()
|
404 |
return gaussian_path, gaussian_path
|
@@ -430,10 +409,6 @@ def create_interface():
|
|
430 |
model_status = "ℹ️ Models will be loaded on demand"
|
431 |
|
432 |
with gr.Blocks(css=css) as demo:
|
433 |
-
# Set up session management - COMMENT THESE OUT FOR TESTING
|
434 |
-
# demo.load(start_session)
|
435 |
-
# demo.unload(end_session)
|
436 |
-
|
437 |
gr.Info(model_status)
|
438 |
|
439 |
# State for storing 3D model data
|
@@ -604,4 +579,4 @@ if __name__ == "__main__":
|
|
604 |
print(status)
|
605 |
|
606 |
demo = create_interface()
|
607 |
-
demo.launch()
|
|
|
53 |
_trellis_pipeline = None
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
@spaces.GPU()
|
57 |
def get_image_gen_pipeline():
|
58 |
global _image_gen_pipeline
|
|
|
299 |
ss_sampling_steps: int,
|
300 |
slat_guidance_strength: float,
|
301 |
slat_sampling_steps: int,
|
|
|
302 |
) -> Tuple[dict, str]:
|
303 |
try:
|
304 |
+
# Use a fixed temp directory instead of user-specific
|
305 |
+
temp_dir = os.path.join(TMP_DIR, "temp_output")
|
306 |
+
os.makedirs(temp_dir, exist_ok=True)
|
307 |
|
308 |
# Get the pipeline using the getter function
|
309 |
pipeline = get_trellis_pipeline()
|
|
|
328 |
video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
|
329 |
video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
|
330 |
video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
|
331 |
+
video_path = os.path.join(temp_dir, 'sample.mp4')
|
332 |
imageio.mimsave(video_path, video, fps=15)
|
333 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
|
334 |
torch.cuda.empty_cache()
|
|
|
343 |
state: dict,
|
344 |
mesh_simplify: float,
|
345 |
texture_size: int,
|
|
|
346 |
) -> Tuple[str, str]:
|
347 |
"""
|
348 |
Extract a GLB file from the 3D model.
|
|
|
355 |
Returns:
|
356 |
str: The path to the extracted GLB file.
|
357 |
"""
|
358 |
+
temp_dir = os.path.join(TMP_DIR, "temp_output")
|
359 |
gs, mesh = unpack_state(state)
|
360 |
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
361 |
+
glb_path = os.path.join(temp_dir, 'sample.glb')
|
362 |
glb.export(glb_path)
|
363 |
torch.cuda.empty_cache()
|
364 |
return glb_path, glb_path
|
365 |
|
366 |
|
367 |
@spaces.GPU
|
368 |
+
def extract_gaussian(state: dict) -> Tuple[str, str]:
|
369 |
"""
|
370 |
Extract a Gaussian file from the 3D model.
|
371 |
|
|
|
375 |
Returns:
|
376 |
str: The path to the extracted Gaussian file.
|
377 |
"""
|
378 |
+
temp_dir = os.path.join(TMP_DIR, "temp_output")
|
379 |
gs, _ = unpack_state(state)
|
380 |
+
gaussian_path = os.path.join(temp_dir, 'sample.ply')
|
381 |
gs.save_ply(gaussian_path)
|
382 |
torch.cuda.empty_cache()
|
383 |
return gaussian_path, gaussian_path
|
|
|
409 |
model_status = "ℹ️ Models will be loaded on demand"
|
410 |
|
411 |
with gr.Blocks(css=css) as demo:
|
|
|
|
|
|
|
|
|
412 |
gr.Info(model_status)
|
413 |
|
414 |
# State for storing 3D model data
|
|
|
579 |
print(status)
|
580 |
|
581 |
demo = create_interface()
|
582 |
+
demo.launch()
|