alexnasa committed
Commit 03f4b66 · verified · 1 parent: a8c20ec

Update app.py

Files changed (1): app.py (+5, -16)
app.py CHANGED
@@ -82,14 +82,8 @@ def save_splat_file(splat_data, output_path):
     with open(output_path, "wb") as f:
         f.write(splat_data)

-def get_reconstructed_scene(outdir, model, device):
+def get_reconstructed_scene(outdir, image_files, model, device):

-    image_files = sorted(
-        [
-            os.path.join(outdir, "images", f)
-            for f in os.listdir(os.path.join(outdir, "images"))
-        ]
-    )
     images = [process_image(img_path) for img_path in image_files]
     images = torch.stack(images, dim=0).unsqueeze(0).to(device)  # [1, K, 3, 448, 448]
     b, v, c, h, w = images.shape
@@ -245,7 +239,7 @@ def generate_splats_from_video(video_path, session_id=None):
     return plyfile, rgb_vid, depth_vid, image_paths

 @spaces.GPU()
-def generate_splats_from_images(images_folder, session_id=None):
+def generate_splats_from_images(image_paths, session_id=None):

     if session_id is None:
         session_id = uuid.uuid4().hex
@@ -256,16 +250,11 @@ def generate_splats_from_images(images_folder, session_id=None):

     base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)

-    all_files = (
-        sorted(os.listdir(images_folder))
-        if os.path.isdir(images_folder)
-        else []
-    )
-    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+    all_files = [os.path.basename(p) for p in image_paths]

     print("Running run_model...")
     with torch.no_grad():
-        plyfile, video, depth_colored = get_reconstructed_scene(base_dir, model, device)
+        plyfile, video, depth_colored = get_reconstructed_scene(base_dir, all_files, model, device)

     end_time = time.time()
     print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
@@ -413,7 +402,7 @@ if __name__ == "__main__":

     submit_btn.click(
         fn=generate_splats_from_images,
-        inputs=[target_dir_output, session_state],
+        inputs=[image_gallery, session_state],
         outputs=[reconstruction_output, rgb_video, depth_video])

     input_video.upload(
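
After this change, generate_splats_from_images consumes an explicit list of image paths (wired from the image_gallery component) rather than scanning a folder on disk. Below is a minimal usage sketch, assuming app.py is importable as a module and its model/environment setup has already run; the image paths and ANYSPLAT_PROCESSED directory are hypothetical placeholders, not values from this commit:

    import os

    # Hypothetical location; ANYSPLAT_PROCESSED must point at a writable
    # directory, as app.py expects when it builds base_dir per session.
    os.environ.setdefault("ANYSPLAT_PROCESSED", "/tmp/anysplat_processed")

    # Assumes app.py can be imported as a module; in the Space itself the
    # function is invoked through submit_btn.click instead.
    from app import generate_splats_from_images

    # Placeholder image paths standing in for the gallery selection.
    image_paths = [
        "/data/capture/frame_0001.png",
        "/data/capture/frame_0002.png",
        "/data/capture/frame_0003.png",
    ]

    # Returns the three values wired to
    # outputs=[reconstruction_output, rgb_video, depth_video] in the UI.
    plyfile, rgb_video, depth_video = generate_splats_from_images(image_paths)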