alexnasa commited on
Commit
66a9013
·
verified ·
1 Parent(s): 8d48dde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -38
app.py CHANGED
@@ -129,13 +129,46 @@ def get_reconstructed_scene(outdir, model, device):
129
  torch.cuda.empty_cache()
130
  return splatfile, video, depth_colored
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- # 2) Handle uploaded video/images --> produce target_dir + images
134
  def extract_frames(input_video, session_id):
135
- """
136
- Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
137
- images or extracted frames from video into it. Return (target_dir, image_paths).
138
- """
139
  start_time = time.time()
140
  gc.collect()
141
  torch.cuda.empty_cache()
@@ -144,9 +177,9 @@ def extract_frames(input_video, session_id):
144
  target_dir = base_dir
145
  target_dir_images = os.path.join(target_dir, "images")
146
 
147
- # Clean up if somehow that folder already exists
148
  if os.path.exists(target_dir):
149
  shutil.rmtree(target_dir)
 
150
  os.makedirs(target_dir)
151
  os.makedirs(target_dir_images)
152
 
@@ -187,18 +220,22 @@ def extract_frames(input_video, session_id):
187
  return target_dir, image_paths
188
 
189
 
190
- def update_gallery_on_upload(input_video, session_id):
191
- """
192
- Whenever user uploads or changes files, immediately handle them
193
- and show in the gallery. Return (target_dir, image_paths).
194
- If nothing is uploaded, returns "None" and empty list.
195
- """
196
- if not input_video and not input_images:
197
  return None, None, None
198
 
199
  target_dir, image_paths = extract_frames(input_video, session_id)
200
  return None, target_dir, image_paths
201
 
 
 
 
 
 
 
 
 
202
  @spaces.GPU()
203
  def generate_splats_from_video(video_path, session_id=None):
204
 
@@ -239,35 +276,14 @@ def generate_splats_from_images(images_folder, session_id=None):
239
  return plyfile, video, depth_colored
240
 
241
  def cleanup(request: gr.Request):
242
- """
243
- Clean up session-specific directories and temporary files when the user session ends.
244
-
245
- This function is triggered when the Gradio demo is unloaded (e.g., when the user
246
- closes the browser tab or navigates away). It removes all temporary files and
247
- directories created during the user's session to free up storage space.
248
-
249
- Args:
250
- request (gr.Request): Gradio request object containing session information
251
- """
252
  sid = request.session_hash
253
  if sid:
254
  d1 = os.path.join(os.environ["ANYSPLAT_PROCESSED"], sid)
255
  shutil.rmtree(d1, ignore_errors=True)
256
 
257
  def start_session(request: gr.Request):
258
- """
259
- Initialize a new user session and return the session identifier.
260
-
261
- This function is triggered when the Gradio demo loads and creates a unique
262
- session hash that will be used to organize outputs and temporary files
263
- for this specific user session.
264
-
265
- Args:
266
- request (gr.Request): Gradio request object containing session information
267
-
268
- Returns:
269
- str: Unique session hash identifier
270
- """
271
  return request.session_hash
272
 
273
 
@@ -322,7 +338,7 @@ if __name__ == "__main__":
322
  with gr.Tab("Video"):
323
  input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
324
  with gr.Tab("Images"):
325
- input_images = gr.File(label="Upload Files", height=512)
326
 
327
  submit_btn = gr.Button(
328
  "Generate Gaussian Splat", scale=1, variant="primary"
@@ -397,11 +413,17 @@ if __name__ == "__main__":
397
  outputs=[reconstruction_output, rgb_video, depth_video])
398
 
399
  input_video.upload(
400
- fn=update_gallery_on_upload,
401
  inputs=[input_video, session_state],
402
  outputs=[reconstruction_output, target_dir_output, image_gallery],
403
  )
404
 
 
 
 
 
 
 
405
  demo.unload(cleanup)
406
  demo.queue()
407
  demo.launch(show_error=True, share=True)
 
129
  torch.cuda.empty_cache()
130
  return splatfile, video, depth_colored
131
 
132
+ def extract_images(input_video, session_id):
133
+
134
+ start_time = time.time()
135
+ gc.collect()
136
+ torch.cuda.empty_cache()
137
+
138
+ base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
139
+ target_dir = base_dir
140
+ target_dir_images = os.path.join(target_dir, "images")
141
+
142
+ if os.path.exists(target_dir):
143
+ shutil.rmtree(target_dir)
144
+
145
+ os.makedirs(target_dir)
146
+ os.makedirs(target_dir_images)
147
+
148
+ image_paths = []
149
+
150
+ if input_images is not None:
151
+ for file_data in input_images:
152
+ if isinstance(file_data, dict) and "name" in file_data:
153
+ file_path = file_data["name"]
154
+ else:
155
+ file_path = file_data
156
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
157
+ shutil.copy(file_path, dst_path)
158
+ image_paths.append(dst_path)
159
+
160
+ # Sort final images for gallery
161
+ image_paths = sorted(image_paths)
162
+
163
+ end_time = time.time()
164
+ print(
165
+ f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
166
+ )
167
+ return target_dir, image_paths
168
+
169
 
 
170
  def extract_frames(input_video, session_id):
171
+
 
 
 
172
  start_time = time.time()
173
  gc.collect()
174
  torch.cuda.empty_cache()
 
177
  target_dir = base_dir
178
  target_dir_images = os.path.join(target_dir, "images")
179
 
 
180
  if os.path.exists(target_dir):
181
  shutil.rmtree(target_dir)
182
+
183
  os.makedirs(target_dir)
184
  os.makedirs(target_dir_images)
185
 
 
220
  return target_dir, image_paths
221
 
222
 
223
+ def update_gallery_on_video_upload(input_video, session_id):
224
+
225
+ if not input_video:
 
 
 
 
226
  return None, None, None
227
 
228
  target_dir, image_paths = extract_frames(input_video, session_id)
229
  return None, target_dir, image_paths
230
 
231
+ def update_gallery_on_images_upload(input_images, session_id):
232
+
233
+ if not input_images:
234
+ return None, None, None
235
+
236
+ target_dir, image_paths = extract_images(input_images, session_id)
237
+ return None, target_dir, image_paths
238
+
239
  @spaces.GPU()
240
  def generate_splats_from_video(video_path, session_id=None):
241
 
 
276
  return plyfile, video, depth_colored
277
 
278
  def cleanup(request: gr.Request):
279
+
 
 
 
 
 
 
 
 
 
280
  sid = request.session_hash
281
  if sid:
282
  d1 = os.path.join(os.environ["ANYSPLAT_PROCESSED"], sid)
283
  shutil.rmtree(d1, ignore_errors=True)
284
 
285
  def start_session(request: gr.Request):
286
+
 
 
 
 
 
 
 
 
 
 
 
 
287
  return request.session_hash
288
 
289
 
 
338
  with gr.Tab("Video"):
339
  input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
340
  with gr.Tab("Images"):
341
+ input_images = gr.File(file_count="multiple", label="Upload Files", height=512)
342
 
343
  submit_btn = gr.Button(
344
  "Generate Gaussian Splat", scale=1, variant="primary"
 
413
  outputs=[reconstruction_output, rgb_video, depth_video])
414
 
415
  input_video.upload(
416
+ fn=update_gallery_on_video_upload,
417
  inputs=[input_video, session_state],
418
  outputs=[reconstruction_output, target_dir_output, image_gallery],
419
  )
420
 
421
+ input_images.upload(
422
+ fn=update_gallery_on_images_upload,
423
+ inputs=[input_images, session_state],
424
+ outputs=[reconstruction_output, target_dir_output, image_gallery],
425
+ )
426
+
427
  demo.unload(cleanup)
428
  demo.queue()
429
  demo.launch(show_error=True, share=True)