# AnySplat — app.py (Gradio demo)
import spaces
import torch
print(f"torch version: {torch.__version__}")
import functools
import gc
import os
import subprocess
import shutil
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
import uuid
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image
import open3d as o3d
os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/preprocess_results"
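# Per-session working directories live under this root:
#   $ANYSPLAT_PROCESSED/<session_id>/images/   extracted video frames
#   $ANYSPLAT_PROCESSED/<session_id>/          gaussians.ply / gaussians.glb and the rendered videos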
def get_reconstructed_scene(outdir, model, device):
    image_files = sorted(
        [
            os.path.join(outdir, "images", f)
            for f in os.listdir(os.path.join(outdir, "images"))
        ]
    )
    images = [process_image(img_path) for img_path in image_files]
    images = torch.stack(images, dim=0).unsqueeze(0).to(device)  # [1, K, 3, 448, 448]
    b, v, c, h, w = images.shape
    assert c == 3, "Images must have 3 channels"

    # Map pixel values from [-1, 1] to [0, 1] before inference.
    gaussians, pred_context_pose = model.inference((images + 1) * 0.5)

    pred_all_extrinsic = pred_context_pose["extrinsic"]
    pred_all_intrinsic = pred_context_pose["intrinsic"]
    video, depth_colored = save_interpolated_video(
        pred_all_extrinsic,
        pred_all_intrinsic,
        b,
        h,
        w,
        gaussians,
        outdir,
        model.decoder,
    )

    plyfile = os.path.join(outdir, "gaussians.ply")
    glbfile = os.path.join(outdir, "gaussians.glb")
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    import numpy as np
    import trimesh

    # 1. Load the PLY and preserve its attributes.
    mesh = trimesh.load(plyfile, process=False)
    # 2. Ensure RGBA vertex colors (assumes mesh.metadata['vertex_color'] holds an (N, 3) array).
    if mesh.visual.vertex_colors is None or mesh.visual.vertex_colors.shape[1] < 4:
        rgb = np.array(mesh.metadata["vertex_color"], dtype=np.uint8)
        alpha = np.full((rgb.shape[0], 1), 255, dtype=np.uint8)
        mesh.visual.vertex_colors = np.concatenate([rgb, alpha], axis=1)
    # 3. Export GLB for the gr.Model3D viewer.
    mesh.export(glbfile, file_type="glb")
    print(f"Export complete: {glbfile}")

    # Clean up
    torch.cuda.empty_cache()
    return glbfile, video, depth_colored
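
# Minimal usage sketch (assumes <outdir>/images already holds the input frames,
# and that `model` and `device` are set up as in the __main__ block below):
#
#   glb_path, rgb_video_path, depth_video_path = get_reconstructed_scene(
#       "/path/to/session_dir", model, device
#   )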
# Handle an uploaded video --> produce target_dir + extracted frame images
def extract_frames(input_video, session_id):
    """
    Create a fresh '<session>/images' folder and fill it with frames extracted
    from the uploaded video (about one frame per second).
    Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir = base_dir
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        frame_interval = max(int(fps), 1)  # ~1 frame/sec; guard against fps == 0

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(
                    target_dir_images, f"{video_frame_num:06}.png"
                )
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Frames written to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
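
# Example: a 10 s, 30 fps upload with session_id "abc123" yields roughly
#   $ANYSPLAT_PROCESSED/abc123/images/000000.png ... 000009.png
# (one frame sampled per second).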
def update_gallery_on_upload(input_video, session_id):
    """
    Whenever the user uploads or changes a video, extract its frames right away
    and show them in the gallery. Returns (reconstruction, target_dir,
    image_paths); the reconstruction slot is always None here.
    """
    if not input_video:
        return None, None, None
    target_dir, image_paths = extract_frames(input_video, session_id)
    return None, target_dir, image_paths
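
# @spaces.GPU() requests GPU time for each call when this app runs on
# Hugging Face ZeroGPU hardware; on a regular GPU machine it has no effect.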
@spaces.GPU()
def generate_splats_from_video(video_path, session_id=None):
    if session_id is None:
        session_id = uuid.uuid4().hex
    images_folder, image_paths = extract_frames(video_path, session_id)
    glbfile, rgb_vid, depth_vid = generate_splats_from_images(images_folder, session_id)
    return glbfile, rgb_vid, depth_vid, image_paths
@spaces.GPU()
def generate_splats_from_images(images_folder, session_id=None):
    if session_id is None:
        session_id = uuid.uuid4().hex

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)

    # Enumerated file listing (not used downstream).
    all_files = (
        sorted(os.listdir(images_folder))
        if os.path.isdir(images_folder)
        else []
    )
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]

    print("Running reconstruction...")
    with torch.no_grad():
        glbfile, video, depth_colored = get_reconstructed_scene(base_dir, model, device)

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    return glbfile, video, depth_colored
def cleanup(request: gr.Request):
    """
    Clean up session-specific directories and temporary files when the user
    session ends.

    Triggered when the Gradio demo is unloaded (e.g., when the user closes the
    browser tab or navigates away). Removes the files created during the
    session to free up storage space.

    Args:
        request (gr.Request): Gradio request object containing session information
    """
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["ANYSPLAT_PROCESSED"], sid)
        shutil.rmtree(d1, ignore_errors=True)
def start_session(request: gr.Request):
    """
    Initialize a new user session and return the session identifier.

    Triggered when the Gradio demo loads. The unique session hash is used to
    organize outputs and temporary files for this specific user session.

    Args:
        request (gr.Request): Gradio request object containing session information

    Returns:
        str: Unique session hash identifier
    """
    return request.session_hash
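
# Session lifecycle: demo.load(start_session) stores the session hash in
# session_state when a browser tab opens, and demo.unload(cleanup) deletes
# that session's working directory when the tab closes.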
if __name__ == "__main__":
    share = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained model and freeze it for inference.
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    model = model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    theme = gr.themes.Ocean()
    theme.set(
        checkbox_label_background_fill_selected="*button_primary_background_fill",
        checkbox_label_text_color_selected="*button_primary_text_color",
    )

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1024px;
    }
    """
    with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        # Hidden fields used to pass metadata between callbacks.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
                <div style="text-align: center;">
                    <p style="font-size:16px; display: inline; margin: 0;">
                        <strong>AnySplat</strong> – Feed-forward 3D Gaussian Splatting from Unconstrained Views
                    </p>
                    <a href="https://github.com/OpenRobotLab/AnySplat" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                        <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
                    </a>
                </div>
                """
            )
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Upload Video", interactive=True, height=512)
                    submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )
                with gr.Column():
                    reconstruction_output = gr.Model3D(
                        label="3D Reconstructed Gaussian Splat",
                        height=512,
                        zoom_speed=0.5,
                        pan_speed=0.5,
                        camera_position=[20, 20, 20],
                    )
                    with gr.Row():
                        rgb_video = gr.Video(
                            label="RGB Video", interactive=False, autoplay=True
                        )
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )
            with gr.Row():
                examples = [
                    ["examples/video/re10k_1eca36ec55b88fe4.mp4"],
                    ["examples/video/bungeenerf_colosseum.mp4"],
                    ["examples/video/fox.mp4"],
                    ["examples/video/matrixcity_street.mp4"],
                    ["examples/video/vrnerf_apartment.mp4"],
                    # [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
                    # [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
                    # [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
                    # [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
                    # [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
                    # [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
                    # [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
                    # [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
                    # [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
                    # [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
                ]
                gr.Examples(
                    examples=examples,
                    inputs=[input_video],
                    outputs=[
                        reconstruction_output,
                        rgb_video,
                        depth_video,
                        image_gallery,
                    ],
                    fn=generate_splats_from_video,
                    cache_examples=True,
                )
        submit_btn.click(
            fn=generate_splats_from_images,
            inputs=[target_dir_output, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video],
        )

        input_video.change(
            fn=update_gallery_on_upload,
            inputs=[input_video, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
        )

        demo.unload(cleanup)

    demo.queue()
    demo.launch(show_error=True, share=share)