alexnasa commited on
Commit
2375b22
·
verified ·
1 Parent(s): d49f3be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -74
app.py CHANGED
@@ -13,6 +13,7 @@ import tempfile
13
  import time
14
  from datetime import datetime
15
  from pathlib import Path
 
16
 
17
  import cv2
18
  import gradio as gr
@@ -30,9 +31,9 @@ from src.utils.image import process_image
30
  os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"
31
 
32
 
33
- # 1) Core model inference
34
  def get_reconstructed_scene(outdir, model, device):
35
- # Load Images
36
  image_files = sorted(
37
  [
38
  os.path.join(outdir, "images", f)
@@ -45,10 +46,9 @@ def get_reconstructed_scene(outdir, model, device):
45
 
46
  assert c == 3, "Images must have 3 channels"
47
 
48
- # Run Inference
49
  gaussians, pred_context_pose = model.inference((images + 1) * 0.5)
50
 
51
- # Save the results
52
  pred_all_extrinsic = pred_context_pose["extrinsic"]
53
  pred_all_intrinsic = pred_context_pose["intrinsic"]
54
  video, depth_colored = save_interpolated_video(
@@ -79,7 +79,7 @@ def get_reconstructed_scene(outdir, model, device):
79
 
80
 
81
  # 2) Handle uploaded video/images --> produce target_dir + images
82
- def handle_uploads(input_video, input_images, session_id):
83
  """
84
  Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
85
  images or extracted frames from video into it. Return (target_dir, image_paths).
@@ -100,18 +100,6 @@ def handle_uploads(input_video, input_images, session_id):
100
 
101
  image_paths = []
102
 
103
- # --- Handle images ---
104
- if input_images is not None:
105
- for file_data in input_images:
106
- if isinstance(file_data, dict) and "name" in file_data:
107
- file_path = file_data["name"]
108
- else:
109
- file_path = file_data
110
- dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
111
- shutil.copy(file_path, dst_path)
112
- image_paths.append(dst_path)
113
-
114
- # --- Handle video ---
115
  if input_video is not None:
116
  if isinstance(input_video, dict) and "name" in input_video:
117
  video_path = input_video["name"]
@@ -147,8 +135,7 @@ def handle_uploads(input_video, input_images, session_id):
147
  return target_dir, image_paths
148
 
149
 
150
- # 3) Update gallery on upload
151
- def update_gallery_on_upload(input_video, input_images, session_id):
152
  """
153
  Whenever user uploads or changes files, immediately handle them
154
  and show in the gallery. Return (target_dir, image_paths).
@@ -157,13 +144,25 @@ def update_gallery_on_upload(input_video, input_images, session_id):
157
  if not input_video and not input_images:
158
  return None, None, None
159
 
160
- target_dir, image_paths = handle_uploads(input_video, input_images, session_id)
161
  return None, target_dir, image_paths
162
 
 
 
 
 
 
163
 
 
 
 
 
164
  @spaces.GPU()
165
- def generate_splat(images_folder, session_id=None):
166
 
 
 
 
167
  start_time = time.time()
168
  gc.collect()
169
  torch.cuda.empty_cache()
@@ -186,7 +185,22 @@ def generate_splat(images_folder, session_id=None):
186
 
187
  return plyfile, video, depth_colored
188
 
189
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  def start_session(request: gr.Request):
191
  """
192
  Initialize a new user session and return the session identifier.
@@ -227,43 +241,6 @@ if __name__ == "__main__":
227
  margin: 0 auto;
228
  max-width: 1024px;
229
  }
230
-
231
- .custom-log * {
232
- font-style: italic;
233
- font-size: 22px !important;
234
- background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
235
- -webkit-background-clip: text;
236
- background-clip: text;
237
- font-weight: bold !important;
238
- color: transparent !important;
239
- text-align: center !important;
240
- }
241
-
242
- .example-log * {
243
- font-style: italic;
244
- font-size: 16px !important;
245
- background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
246
- -webkit-background-clip: text;
247
- background-clip: text;
248
- color: transparent !important;
249
- }
250
-
251
- #my_radio .wrap {
252
- display: flex;
253
- flex-wrap: nowrap;
254
- justify-content: center;
255
- align-items: center;
256
- }
257
-
258
- #my_radio .wrap label {
259
- display: flex;
260
- width: 50%;
261
- justify-content: center;
262
- align-items: center;
263
- margin: 0;
264
- padding: 10px 0;
265
- box-sizing: border-box;
266
- }
267
  """
268
  with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
269
  session_state = gr.State()
@@ -287,12 +264,7 @@ if __name__ == "__main__":
287
  with gr.Row():
288
  with gr.Column():
289
  input_video = gr.Video(label="Upload Video", interactive=True, height=512)
290
- input_images = gr.File(
291
- file_count="multiple",
292
- label="Upload Images",
293
- interactive=True,
294
- visible=False
295
- )
296
  submit_btn = gr.Button(
297
  "Reconstruct", scale=1, variant="primary"
298
  )
@@ -306,7 +278,6 @@ if __name__ == "__main__":
306
  preview=True,
307
  )
308
 
309
-
310
  with gr.Column():
311
  with gr.Column():
312
  reconstruction_output = gr.Model3D(
@@ -373,18 +344,16 @@ if __name__ == "__main__":
373
  # )
374
 
375
  submit_btn.click(
376
- fn=generate_splat,
377
  inputs=[target_dir_output, session_state],
378
  outputs=[reconstruction_output, rgb_video, depth_video])
379
 
380
  input_video.change(
381
  fn=update_gallery_on_upload,
382
- inputs=[input_video, input_images, session_state],
383
  outputs=[reconstruction_output, target_dir_output, image_gallery],
384
  )
385
- input_images.change(
386
- fn=update_gallery_on_upload,
387
- inputs=[input_video, input_images, session_state],
388
- outputs=[reconstruction_output, target_dir_output, image_gallery],
389
- )
390
- demo.queue().launch(show_error=True, share=True)
 
13
  import time
14
  from datetime import datetime
15
  from pathlib import Path
16
+ import uuid
17
 
18
  import cv2
19
  import gradio as gr
 
31
  os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"
32
 
33
 
34
+
35
  def get_reconstructed_scene(outdir, model, device):
36
+
37
  image_files = sorted(
38
  [
39
  os.path.join(outdir, "images", f)
 
46
 
47
  assert c == 3, "Images must have 3 channels"
48
 
49
+
50
  gaussians, pred_context_pose = model.inference((images + 1) * 0.5)
51
 
 
52
  pred_all_extrinsic = pred_context_pose["extrinsic"]
53
  pred_all_intrinsic = pred_context_pose["intrinsic"]
54
  video, depth_colored = save_interpolated_video(
 
79
 
80
 
81
  # 2) Handle uploaded video/images --> produce target_dir + images
82
+ def extract_frames(input_video, session_id):
83
  """
84
  Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
85
  images or extracted frames from video into it. Return (target_dir, image_paths).
 
100
 
101
  image_paths = []
102
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if input_video is not None:
104
  if isinstance(input_video, dict) and "name" in input_video:
105
  video_path = input_video["name"]
 
135
  return target_dir, image_paths
136
 
137
 
138
+ def update_gallery_on_upload(input_video, session_id):
 
139
  """
140
  Whenever user uploads or changes files, immediately handle them
141
  and show in the gallery. Return (target_dir, image_paths).
 
144
  if not input_video and not input_images:
145
  return None, None, None
146
 
147
+ target_dir, image_paths = extract_frames(input_video, session_id)
148
  return None, target_dir, image_paths
149
 
150
+ @spaces.GPU()
151
+ def generate_splat_from_video(video_path, session_id=None):
152
+
153
+ if session_id is None:
154
+ session_id = uuid.uuid4().hex
155
 
156
+ target_dir, image_paths = extract_frames(input_video, session_id)
157
+
158
+ return generate_splat_from_images(images_folder, session_id), image_paths
159
+
160
  @spaces.GPU()
161
+ def generate_splat_from_images(images_folder, session_id=None):
162
 
163
+ if session_id is None:
164
+ session_id = uuid.uuid4().hex
165
+
166
  start_time = time.time()
167
  gc.collect()
168
  torch.cuda.empty_cache()
 
185
 
186
  return plyfile, video, depth_colored
187
 
188
+ def cleanup(request: gr.Request):
189
+ """
190
+ Clean up session-specific directories and temporary files when the user session ends.
191
+
192
+ This function is triggered when the Gradio demo is unloaded (e.g., when the user
193
+ closes the browser tab or navigates away). It removes all temporary files and
194
+ directories created during the user's session to free up storage space.
195
+
196
+ Args:
197
+ request (gr.Request): Gradio request object containing session information
198
+ """
199
+ sid = request.session_hash
200
+ if sid:
201
+ d1 = os.path.join(os.environ["ANYSPLAT_PROCESSED"], sid)
202
+ shutil.rmtree(d1, ignore_errors=True)
203
+
204
  def start_session(request: gr.Request):
205
  """
206
  Initialize a new user session and return the session identifier.
 
241
  margin: 0 auto;
242
  max-width: 1024px;
243
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  """
245
  with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
246
  session_state = gr.State()
 
264
  with gr.Row():
265
  with gr.Column():
266
  input_video = gr.Video(label="Upload Video", interactive=True, height=512)
267
+
 
 
 
 
 
268
  submit_btn = gr.Button(
269
  "Reconstruct", scale=1, variant="primary"
270
  )
 
278
  preview=True,
279
  )
280
 
 
281
  with gr.Column():
282
  with gr.Column():
283
  reconstruction_output = gr.Model3D(
 
344
  # )
345
 
346
  submit_btn.click(
347
+ fn=generate_splat_from_images,
348
  inputs=[target_dir_output, session_state],
349
  outputs=[reconstruction_output, rgb_video, depth_video])
350
 
351
  input_video.change(
352
  fn=update_gallery_on_upload,
353
+ inputs=[input_video, session_state],
354
  outputs=[reconstruction_output, target_dir_output, image_gallery],
355
  )
356
+
357
+ demo.unload(cleanup)
358
+ demo.queue()
359
+ demo.launch(show_error=True, share=True)