alexnasa committed on
Commit
79cc590
·
verified ·
1 Parent(s): e1b3a3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -212,7 +212,6 @@ def extract_frames(input_video, session_id):
212
 
213
 
214
  def update_gallery_on_video_upload(input_video, session_id):
215
-
216
  if not input_video:
217
  return None, None, None
218
 
@@ -229,6 +228,17 @@ def update_gallery_on_images_upload(input_images, session_id):
229
 
230
  @spaces.GPU()
231
  def generate_splats_from_video(video_path, session_id=None):
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  if session_id is None:
234
  session_id = uuid.uuid4().hex
@@ -240,7 +250,16 @@ def generate_splats_from_video(video_path, session_id=None):
240
 
241
  @spaces.GPU()
242
  def generate_splats_from_images(image_paths, session_id=None):
243
-
 
 
 
 
 
 
 
 
 
244
  processed_image_paths = []
245
 
246
  for file_data in image_paths:
@@ -267,12 +286,12 @@ def generate_splats_from_images(image_paths, session_id=None):
267
 
268
  print("Running run_model...")
269
  with torch.no_grad():
270
- plyfile, video, depth_colored = get_reconstructed_scene(base_dir, image_paths, model, device)
271
 
272
  end_time = time.time()
273
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
274
 
275
- return plyfile, video, depth_colored
276
 
277
  def cleanup(request: gr.Request):
278
 
@@ -422,14 +441,16 @@ if __name__ == "__main__":
422
  fn=update_gallery_on_video_upload,
423
  inputs=[input_video, session_state],
424
  outputs=[reconstruction_output, target_dir_output, image_gallery],
 
425
  )
426
 
427
  input_images.upload(
428
  fn=update_gallery_on_images_upload,
429
  inputs=[input_images, session_state],
430
  outputs=[reconstruction_output, target_dir_output, image_gallery],
 
431
  )
432
 
433
  demo.unload(cleanup)
434
  demo.queue()
435
- demo.launch(show_error=True, share=True)
 
212
 
213
 
214
  def update_gallery_on_video_upload(input_video, session_id):
 
215
  if not input_video:
216
  return None, None, None
217
 
 
228
 
229
  @spaces.GPU()
230
  def generate_splats_from_video(video_path, session_id=None):
231
+ """
232
+ Perform Gaussian Splatting from Unconstrained Views of a Given Video, using a Feed-forward model.
233
+
234
+ Args:
235
+ video_path (str): Path to the input video file on disk.
236
+ Returns:
237
+ plyfile: Path to the reconstructed 3D object from the given video.
238
+ rgb_vid: Path to the interpolated RGB video, increasing the frame rate using Gaussian splatting and interpolation of frames.
239
+ depth_vid: Path to the interpolated depth video, increasing the frame rate using Gaussian splatting and interpolation of frames.
240
+ image_paths: A list of paths to the frames extracted from the video that are used for training Gaussian Splatting.
241
+ """
242
 
243
  if session_id is None:
244
  session_id = uuid.uuid4().hex
 
250
 
251
  @spaces.GPU()
252
  def generate_splats_from_images(image_paths, session_id=None):
253
+ """
254
+ Perform Gaussian Splatting from Unconstrained Views of Given Images, using a Feed-forward model.
255
+
256
+ Args:
257
+ image_paths (list[str]): Paths to the input image files on disk.
258
+ Returns:
259
+ plyfile: Path to the reconstructed 3D object from the given image files.
260
+ rgb_vid: Path to the interpolated RGB video, increasing the frame rate using Gaussian splatting and interpolation of frames.
261
+ depth_vid: Path to the interpolated depth video, increasing the frame rate using Gaussian splatting and interpolation of frames.
262
+ """
263
  processed_image_paths = []
264
 
265
  for file_data in image_paths:
 
286
 
287
  print("Running run_model...")
288
  with torch.no_grad():
289
+ plyfile, rgb_vid, depth_vid = get_reconstructed_scene(base_dir, image_paths, model, device)
290
 
291
  end_time = time.time()
292
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
293
 
294
+ return plyfile, rgb_vid, depth_vid
295
 
296
  def cleanup(request: gr.Request):
297
 
 
441
  fn=update_gallery_on_video_upload,
442
  inputs=[input_video, session_state],
443
  outputs=[reconstruction_output, target_dir_output, image_gallery],
444
+ show_api=False
445
  )
446
 
447
  input_images.upload(
448
  fn=update_gallery_on_images_upload,
449
  inputs=[input_images, session_state],
450
  outputs=[reconstruction_output, target_dir_output, image_gallery],
451
+ show_api=False
452
  )
453
 
454
  demo.unload(cleanup)
455
  demo.queue()
456
+ demo.launch(show_error=True, share=True, mcp_server=True)