alexnasa commited on
Commit
2f4530b
·
verified ·
1 Parent(s): 90a8eed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -82,8 +82,14 @@ def save_splat_file(splat_data, output_path):
82
  with open(output_path, "wb") as f:
83
  f.write(splat_data)
84
 
85
- def get_reconstructed_scene(outdir, image_files, model, device):
86
 
 
 
 
 
 
 
87
  images = [process_image(img_path) for img_path in image_files]
88
  images = torch.stack(images, dim=0).unsqueeze(0).to(device) # [1, K, 3, 448, 448]
89
  b, v, c, h, w = images.shape
@@ -234,12 +240,12 @@ def generate_splats_from_video(video_path, session_id=None):
234
  session_id = uuid.uuid4().hex
235
 
236
  images_folder, image_paths = extract_frames(video_path, session_id)
237
- plyfile, rgb_vid, depth_vid = generate_splats_from_images(image_paths, session_id)
238
 
239
  return plyfile, rgb_vid, depth_vid, image_paths
240
 
241
  @spaces.GPU()
242
- def generate_splats_from_images(image_paths, session_id=None):
243
 
244
  if session_id is None:
245
  session_id = uuid.uuid4().hex
@@ -250,9 +256,16 @@ def generate_splats_from_images(image_paths, session_id=None):
250
 
251
  base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
252
 
 
 
 
 
 
 
 
253
  print("Running run_model...")
254
  with torch.no_grad():
255
- plyfile, video, depth_colored = get_reconstructed_scene(base_dir, image_paths, model, device)
256
 
257
  end_time = time.time()
258
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
@@ -322,7 +335,7 @@ if __name__ == "__main__":
322
  with gr.Tab("Video"):
323
  input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
324
  with gr.Tab("Images"):
325
- input_images = gr.File(file_count="multiple", label="Upload Files", type="filepath", height=512)
326
 
327
  submit_btn = gr.Button(
328
  "Generate Gaussian Splat", scale=1, variant="primary"
@@ -400,7 +413,7 @@ if __name__ == "__main__":
400
 
401
  submit_btn.click(
402
  fn=generate_splats_from_images,
403
- inputs=[image_gallery, session_state],
404
  outputs=[reconstruction_output, rgb_video, depth_video])
405
 
406
  input_video.upload(
@@ -417,4 +430,4 @@ if __name__ == "__main__":
417
 
418
  demo.unload(cleanup)
419
  demo.queue()
420
- demo.launch(show_error=True, share=True)
 
82
  with open(output_path, "wb") as f:
83
  f.write(splat_data)
84
 
85
+ def get_reconstructed_scene(outdir, model, device):
86
 
87
+ image_files = sorted(
88
+ [
89
+ os.path.join(outdir, "images", f)
90
+ for f in os.listdir(os.path.join(outdir, "images"))
91
+ ]
92
+ )
93
  images = [process_image(img_path) for img_path in image_files]
94
  images = torch.stack(images, dim=0).unsqueeze(0).to(device) # [1, K, 3, 448, 448]
95
  b, v, c, h, w = images.shape
 
240
  session_id = uuid.uuid4().hex
241
 
242
  images_folder, image_paths = extract_frames(video_path, session_id)
243
+ plyfile, rgb_vid, depth_vid = generate_splats_from_images(images_folder, session_id)
244
 
245
  return plyfile, rgb_vid, depth_vid, image_paths
246
 
247
  @spaces.GPU()
248
+ def generate_splats_from_images(images_folder, session_id=None):
249
 
250
  if session_id is None:
251
  session_id = uuid.uuid4().hex
 
256
 
257
  base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
258
 
259
+ all_files = (
260
+ sorted(os.listdir(images_folder))
261
+ if os.path.isdir(images_folder)
262
+ else []
263
+ )
264
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
265
+
266
  print("Running run_model...")
267
  with torch.no_grad():
268
+ plyfile, video, depth_colored = get_reconstructed_scene(base_dir, model, device)
269
 
270
  end_time = time.time()
271
  print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
 
335
  with gr.Tab("Video"):
336
  input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
337
  with gr.Tab("Images"):
338
+ input_images = gr.File(file_count="multiple", label="Upload Files", height=512)
339
 
340
  submit_btn = gr.Button(
341
  "Generate Gaussian Splat", scale=1, variant="primary"
 
413
 
414
  submit_btn.click(
415
  fn=generate_splats_from_images,
416
+ inputs=[target_dir_output, session_state],
417
  outputs=[reconstruction_output, rgb_video, depth_video])
418
 
419
  input_video.upload(
 
430
 
431
  demo.unload(cleanup)
432
  demo.queue()
433
+ demo.launch(show_error=True, share=True)