lionelgarnier commited on
Commit
0aaba44
·
1 Parent(s): cc9aa28

refactor session handling to use a fixed temporary directory and remove session management functions

Browse files
Files changed (1) hide show
  1. app.py +10 -35
app.py CHANGED
@@ -53,27 +53,6 @@ _image_gen_pipeline = None
53
  _trellis_pipeline = None
54
 
55
 
56
- def start_session(req: gr.Request):
57
- """Create a temporary directory for the user session"""
58
- try:
59
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
60
- os.makedirs(user_dir, exist_ok=True)
61
- print(f"Session started: {req.session_hash}")
62
- except Exception as e:
63
- print(f"Error starting session: {str(e)}")
64
-
65
-
66
- def end_session(req: gr.Request):
67
- """Clean up the temporary directory when the session ends"""
68
- try:
69
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
70
- if os.path.exists(user_dir):
71
- shutil.rmtree(user_dir)
72
- print(f"Session ended: {req.session_hash}")
73
- except Exception as e:
74
- print(f"Error ending session: {str(e)}")
75
-
76
-
77
  @spaces.GPU()
78
  def get_image_gen_pipeline():
79
  global _image_gen_pipeline
@@ -320,10 +299,11 @@ def image_to_3d(
320
  ss_sampling_steps: int,
321
  slat_guidance_strength: float,
322
  slat_sampling_steps: int,
323
- req: gr.Request,
324
  ) -> Tuple[dict, str]:
325
  try:
326
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
 
327
 
328
  # Get the pipeline using the getter function
329
  pipeline = get_trellis_pipeline()
@@ -348,7 +328,7 @@ def image_to_3d(
348
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
349
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
350
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
351
- video_path = os.path.join(user_dir, 'sample.mp4')
352
  imageio.mimsave(video_path, video, fps=15)
353
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
354
  torch.cuda.empty_cache()
@@ -363,7 +343,6 @@ def extract_glb(
363
  state: dict,
364
  mesh_simplify: float,
365
  texture_size: int,
366
- req: gr.Request,
367
  ) -> Tuple[str, str]:
368
  """
369
  Extract a GLB file from the 3D model.
@@ -376,17 +355,17 @@ def extract_glb(
376
  Returns:
377
  str: The path to the extracted GLB file.
378
  """
379
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
380
  gs, mesh = unpack_state(state)
381
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
382
- glb_path = os.path.join(user_dir, 'sample.glb')
383
  glb.export(glb_path)
384
  torch.cuda.empty_cache()
385
  return glb_path, glb_path
386
 
387
 
388
  @spaces.GPU
389
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
390
  """
391
  Extract a Gaussian file from the 3D model.
392
 
@@ -396,9 +375,9 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
396
  Returns:
397
  str: The path to the extracted Gaussian file.
398
  """
399
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
400
  gs, _ = unpack_state(state)
401
- gaussian_path = os.path.join(user_dir, 'sample.ply')
402
  gs.save_ply(gaussian_path)
403
  torch.cuda.empty_cache()
404
  return gaussian_path, gaussian_path
@@ -430,10 +409,6 @@ def create_interface():
430
  model_status = "ℹ️ Models will be loaded on demand"
431
 
432
  with gr.Blocks(css=css) as demo:
433
- # Set up session management - COMMENT THESE OUT FOR TESTING
434
- # demo.load(start_session)
435
- # demo.unload(end_session)
436
-
437
  gr.Info(model_status)
438
 
439
  # State for storing 3D model data
@@ -604,4 +579,4 @@ if __name__ == "__main__":
604
  print(status)
605
 
606
  demo = create_interface()
607
- demo.launch()
 
53
  _trellis_pipeline = None
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @spaces.GPU()
57
  def get_image_gen_pipeline():
58
  global _image_gen_pipeline
 
299
  ss_sampling_steps: int,
300
  slat_guidance_strength: float,
301
  slat_sampling_steps: int,
 
302
  ) -> Tuple[dict, str]:
303
  try:
304
+ # Use a fixed temp directory instead of user-specific
305
+ temp_dir = os.path.join(TMP_DIR, "temp_output")
306
+ os.makedirs(temp_dir, exist_ok=True)
307
 
308
  # Get the pipeline using the getter function
309
  pipeline = get_trellis_pipeline()
 
328
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
329
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
330
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
331
+ video_path = os.path.join(temp_dir, 'sample.mp4')
332
  imageio.mimsave(video_path, video, fps=15)
333
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
334
  torch.cuda.empty_cache()
 
343
  state: dict,
344
  mesh_simplify: float,
345
  texture_size: int,
 
346
  ) -> Tuple[str, str]:
347
  """
348
  Extract a GLB file from the 3D model.
 
355
  Returns:
356
  str: The path to the extracted GLB file.
357
  """
358
+ temp_dir = os.path.join(TMP_DIR, "temp_output")
359
  gs, mesh = unpack_state(state)
360
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
361
+ glb_path = os.path.join(temp_dir, 'sample.glb')
362
  glb.export(glb_path)
363
  torch.cuda.empty_cache()
364
  return glb_path, glb_path
365
 
366
 
367
  @spaces.GPU
368
+ def extract_gaussian(state: dict) -> Tuple[str, str]:
369
  """
370
  Extract a Gaussian file from the 3D model.
371
 
 
375
  Returns:
376
  str: The path to the extracted Gaussian file.
377
  """
378
+ temp_dir = os.path.join(TMP_DIR, "temp_output")
379
  gs, _ = unpack_state(state)
380
+ gaussian_path = os.path.join(temp_dir, 'sample.ply')
381
  gs.save_ply(gaussian_path)
382
  torch.cuda.empty_cache()
383
  return gaussian_path, gaussian_path
 
409
  model_status = "ℹ️ Models will be loaded on demand"
410
 
411
  with gr.Blocks(css=css) as demo:
 
 
 
 
412
  gr.Info(model_status)
413
 
414
  # State for storing 3D model data
 
579
  print(status)
580
 
581
  demo = create_interface()
582
+ demo.launch()