alexnasa committed
Commit d0b0cf2 · verified · 1 Parent(s): d9c86b4

Update app.py

Files changed (1): app.py +73 -104
app.py CHANGED
@@ -27,6 +27,8 @@ from src.model.model.anysplat import AnySplat
 from src.model.ply_export import export_ply
 from src.utils.image import process_image
 
+os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"
+
 
 # 1) Core model inference
 def get_reconstructed_scene(outdir, model, device):
@@ -77,7 +79,7 @@ def get_reconstructed_scene(outdir, model, device)
 
 
 # 2) Handle uploaded video/images --> produce target_dir + images
-def handle_uploads(input_video, input_images):
+def handle_uploads(input_video, input_images, session_id):
     """
     Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
     images or extracted frames from video into it. Return (target_dir, image_paths).
@@ -86,9 +88,8 @@ def handle_uploads(input_video, input_images):
     gc.collect()
     torch.cuda.empty_cache()
 
-    # Create a unique folder name
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
-    target_dir = f"input_images_{timestamp}"
+    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
+    target_dir = base_dir
     target_dir_images = os.path.join(target_dir, "images")
 
     # Clean up if somehow that folder already exists
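Taken together with the ANYSPLAT_PROCESSED root set in the first hunk, the change above gives every user session its own workspace on disk instead of a shared timestamp-named folder. A minimal sketch of the resulting layout, assuming the same environment variable; the session_dirs helper is illustrative and not part of the commit:

import os

# Root for all processed results; "proprocess_results" is spelled as in the diff.
os.environ.setdefault("ANYSPLAT_PROCESSED", f"{os.getcwd()}/proprocess_results")

def session_dirs(session_id: str) -> tuple[str, str]:
    # Hypothetical helper mirroring handle_uploads(): one folder per Gradio
    # session hash, with an "images" subfolder for uploaded frames.
    target_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images, exist_ok=True)
    return target_dir, target_dir_images

# session_dirs("abc123") -> (".../proprocess_results/abc123",
#                            ".../proprocess_results/abc123/images")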
@@ -160,34 +161,24 @@ def update_gallery_on_upload(input_video, input_images):
 
 
 @spaces.GPU()
-# 4) Reconstruction: uses the target_dir plus any viz parameters
-def gradio_demo(
-    target_dir,
-):
-    """
-    Perform reconstruction using the already-created target_dir/images.
-    """
-    if not os.path.isdir(target_dir) or target_dir == "None":
-        return None, None, None
+def generate_splat(images_folder, session_id=None):
 
     start_time = time.time()
     gc.collect()
     torch.cuda.empty_cache()
-
-    # Prepare frame_filter dropdown
-    target_dir_images = os.path.join(target_dir, "images")
+
+    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
+
     all_files = (
-        sorted(os.listdir(target_dir_images))
-        if os.path.isdir(target_dir_images)
+        sorted(os.listdir(images_folder))
+        if os.path.isdir(images_folder)
         else []
    )
     all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
 
     print("Running run_model...")
     with torch.no_grad():
-        plyfile, video, depth_colored = get_reconstructed_scene(
-            target_dir, model, device
-        )
+        plyfile, video, depth_colored = get_reconstructed_scene(base_dir, model, device)
 
     end_time = time.time()
     print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
@@ -195,11 +186,21 @@ def gradio_demo(
     return plyfile, video, depth_colored
 
 
-def clear_fields():
+def start_session(request: gr.Request):
     """
-    Clears the 3D viewer, the stored target_dir, and empties the gallery.
+    Initialize a new user session and return the session identifier.
+
+    This function is triggered when the Gradio demo loads and creates a unique
+    session hash that will be used to organize outputs and temporary files
+    for this specific user session.
+
+    Args:
+        request (gr.Request): Gradio request object containing session information
+
+    Returns:
+        str: Unique session hash identifier
     """
-    return None, None, None
+    return request.session_hash
 
 
 if __name__ == "__main__":
@@ -264,8 +265,9 @@ if __name__ == "__main__":
     }
     """
     with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
-
-
+        session_state = gr.State()
+        demo.load(start_session, outputs=[session_state])
+
         target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
         is_example = gr.Textbox(label="is_example", visible=False, value="None")
         num_images = gr.Textbox(label="num_images", visible=False, value="None")
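The start_session docstring above describes the mechanism; as a self-contained sketch, the load-time wiring reduces to the following (the button and textbox are illustrative and not from the commit):

import gradio as gr

def start_session(request: gr.Request) -> str:
    # Gradio injects the Request object when a parameter is annotated with
    # gr.Request; session_hash is unique per browser session.
    return request.session_hash

with gr.Blocks() as demo:
    session_state = gr.State()
    # Runs once when the page loads and stores the hash for later handlers.
    demo.load(start_session, outputs=[session_state])

    show_btn = gr.Button("Show session id")   # illustrative component
    sid_box = gr.Textbox(label="session id")  # illustrative component
    show_btn.click(lambda sid: sid, inputs=[session_state], outputs=[sid_box])

if __name__ == "__main__":
    demo.launch()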
@@ -275,7 +277,6 @@ if __name__ == "__main__":
 
         with gr.Column(elem_id="col-container"):
 
-
             gr.Markdown(
                 """ # AnySplat – Feed-forward 3D Gaussian Splatting from Unconstrained Views
 
@@ -329,92 +330,60 @@ if __name__ == "__main__":
 
 
         # ---------------------- Examples section ----------------------
 
-        examples = [
-            [None, "examples/video/re10k_1eca36ec55b88fe4.mp4", "re10k", "1eca36ec55b88fe4", "2", "Real", "True",],
-            [None, "examples/video/bungeenerf_colosseum.mp4", "bungeenerf", "colosseum", "8", "Synthetic", "True",],
-            [None, "examples/video/fox.mp4", "InstantNGP", "fox", "14", "Real", "True",],
-            [None, "examples/video/matrixcity_street.mp4", "matrixcity", "street", "32", "Synthetic", "True",],
-            [None, "examples/video/vrnerf_apartment.mp4", "vrnerf", "apartment", "32", "Real", "True",],
-            [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
-            [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
-            [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
-            [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
-            [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
-            [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
-            [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
-            [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
-            [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
-            [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
-        ]
-
-        def example_pipeline(
-            input_images,
-            input_video,
-            dataset_name,
-            scene_name,
-            num_images_str,
-            image_type,
-            is_example,
-        ):
-            """
-            1) Copy example images to new target_dir
-            2) Reconstruct
-            3) Return model3D + logs + new_dir + updated dropdown + gallery
-            We do NOT return is_example. It's just an input.
-            """
-            target_dir, image_paths = handle_uploads(input_video, input_images)
-            plyfile, video, depth_colored = gradio_demo(target_dir)
-            return plyfile, video, depth_colored, target_dir, image_paths
+        # examples = [
+        #     [None, "examples/video/re10k_1eca36ec55b88fe4.mp4", "re10k", "1eca36ec55b88fe4", "2", "Real", "True",],
+        #     [None, "examples/video/bungeenerf_colosseum.mp4", "bungeenerf", "colosseum", "8", "Synthetic", "True",],
+        #     [None, "examples/video/fox.mp4", "InstantNGP", "fox", "14", "Real", "True",],
+        #     [None, "examples/video/matrixcity_street.mp4", "matrixcity", "street", "32", "Synthetic", "True",],
+        #     [None, "examples/video/vrnerf_apartment.mp4", "vrnerf", "apartment", "32", "Real", "True",],
+        #     [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
+        #     [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
+        #     [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
+        #     [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
+        #     [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
+        #     [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
+        #     [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
+        #     [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
+        #     [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
+        #     [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
+        # ]
 
-        gr.Examples(
-            examples=examples,
-            inputs=[
-                input_images,
-                input_video,
-                dataset_name,
-                scene_name,
-                num_images,
-                image_type,
-                is_example,
-            ],
-            outputs=[
-                reconstruction_output,
-                rgb_video,
-                depth_video,
-                target_dir_output,
-                image_gallery,
-            ],
-            fn=example_pipeline,
-            cache_examples=False,
-            examples_per_page=50,
-        )
-
-        gr.Markdown("<p style='text-align: center; font-style: italic; color: #666;'>We thank VGGT for their excellent gradio implementation!</p>")
-
+        # gr.Examples(
+        #     examples=examples,
+        #     inputs=[
+        #         input_images,
+        #         input_video,
+        #         dataset_name,
+        #         scene_name,
+        #         num_images,
+        #         image_type,
+        #         is_example,
+        #     ],
+        #     outputs=[
+        #         reconstruction_output,
+        #         rgb_video,
+        #         depth_video,
+        #         target_dir_output,
+        #         image_gallery,
+        #     ],
+        #     fn=example_pipeline,
+        #     cache_examples=False,
+        #     examples_per_page=50,
+        # )
+
         submit_btn.click(
-            fn=clear_fields,
-            inputs=[],
-            outputs=[reconstruction_output, rgb_video, depth_video],
-        ).then(
-            fn=gradio_demo,
-            inputs=[
-                target_dir_output,
-            ],
-            outputs=[reconstruction_output, rgb_video, depth_video],
-        ).then(
-            fn=lambda: "False", inputs=[], outputs=[is_example]
-        )
+            fn=generate_splat,
+            inputs=[target_dir_output,],
+            outputs=[reconstruction_output, rgb_video, depth_video])
 
         input_video.change(
             fn=update_gallery_on_upload,
-            inputs=[input_video, input_images],
+            inputs=[input_video, input_images, session_state],
             outputs=[reconstruction_output, target_dir_output, image_gallery],
         )
         input_images.change(
             fn=update_gallery_on_upload,
-            inputs=[input_video, input_images],
+            inputs=[input_video, input_images, session_state],
             outputs=[reconstruction_output, target_dir_output, image_gallery],
         )
         demo.queue().launch(show_error=True, share=True)
-
-        # We thank VGGT for their excellent gradio implementation
 
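One wiring note: generate_splat() resolves base_dir by joining ANYSPLAT_PROCESSED with session_id, but the submit_btn.click() in this commit passes only target_dir_output, so session_id keeps its None default. A variant that forwards the session state explicitly might look like this sketch (not part of the commit):

submit_btn.click(
    fn=generate_splat,
    inputs=[target_dir_output, session_state],
    outputs=[reconstruction_output, rgb_video, depth_video],
)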