AnySplat / app.py
alexnasa's picture
Update app.py
1c31d44 verified
raw
history blame
16.4 kB
import spaces
import torch
print(f'torch version:{torch.__version__}')
import huggingface_hub
print(f' huggingface_hub.__version__ {huggingface_hub.__version__}')
import functools
import gc
import os
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
import subprocess
import shutil
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
import uuid
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image
os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"
from plyfile import PlyData
import numpy as np
import argparse
from io import BytesIO
def process_ply_to_splat(ply_file_path):
    """Serialize the vertices of a Gaussian-splat PLY file into ``.splat`` bytes.

    Vertices are written in ascending order of
    ``-exp(scale_0 + scale_1 + scale_2) / (1 + exp(-opacity))``. Every vertex
    produces one 32-byte record:

    * 3 x float32 -- ``x``, ``y``, ``z``
    * 3 x float32 -- ``exp`` of ``scale_0..2``
    * 4 x uint8   -- ``0.5 + SH_C0 * f_dc_0..2`` and the sigmoid of
      ``opacity``, each multiplied by 255 and clipped to [0, 255]
    * 4 x uint8   -- ``rot_0..3`` normalised to unit length, mapped through
      ``* 128 + 128`` and clipped to [0, 255]

    Args:
        ply_file_path: Location of the PLY file, handed to ``PlyData.read``.

    Returns:
        bytes: All vertex records concatenated in sorted order.
    """
    # Numerically 1 / (2 * sqrt(pi)); presumably the degree-0 spherical
    # harmonics factor that turns f_dc_* into RGB -- confirm with the exporter.
    sh_c0 = 0.28209479177387814

    vertices = PlyData.read(ply_file_path)["vertex"]

    # Sort key: negated exp-of-summed-log-scales divided by (1 + exp(-opacity)).
    volume = np.exp(vertices["scale_0"] + vertices["scale_1"] + vertices["scale_2"])
    inverse_alpha = 1 + np.exp(-vertices["opacity"])
    order = np.argsort(-volume / inverse_alpha)

    chunks = []
    for index in order:
        vertex = vertices[index]

        xyz = np.array(
            [vertex["x"], vertex["y"], vertex["z"]],
            dtype=np.float32,
        )
        log_scales = np.array(
            [vertex["scale_0"], vertex["scale_1"], vertex["scale_2"]],
            dtype=np.float32,
        )
        quaternion = np.array(
            [vertex["rot_0"], vertex["rot_1"], vertex["rot_2"], vertex["rot_3"]],
            dtype=np.float32,
        )
        rgba = np.array(
            [
                0.5 + sh_c0 * vertex["f_dc_0"],
                0.5 + sh_c0 * vertex["f_dc_1"],
                0.5 + sh_c0 * vertex["f_dc_2"],
                1 / (1 + np.exp(-vertex["opacity"])),
            ]
        )

        rgba_bytes = (rgba * 255).clip(0, 255).astype(np.uint8)
        unit_quaternion = quaternion / np.linalg.norm(quaternion)
        rotation_bytes = (unit_quaternion * 128 + 128).clip(0, 255).astype(np.uint8)

        chunks.extend(
            (
                xyz.tobytes(),
                np.exp(log_scales).tobytes(),
                rgba_bytes.tobytes(),
                rotation_bytes.tobytes(),
            )
        )

    return b"".join(chunks)
def save_splat_file(splat_data, output_path):
    """Write already-serialized splat bytes to disk.

    Args:
        splat_data: Bytes-like payload to store.
        output_path: Destination file; it is created or truncated.
    """
    with open(output_path, mode="wb") as sink:
        sink.write(splat_data)
def get_reconstructed_scene(outdir, image_files, model, device):
    """Run the model on a set of views and write the resulting artifacts.

    Args:
        outdir: Directory that receives ``gaussians.ply``; it is also passed
            to ``save_interpolated_video`` as its output location.
        image_files: Iterable of image paths, each loaded with
            ``process_image``.
        model: Object exposing ``inference(images)`` and a ``decoder``
            attribute.
        device: Torch device the stacked image batch is moved to.

    Returns:
        tuple: ``(plyfile, video, depth_colored)`` -- the path of the exported
        PLY followed by the two values returned by ``save_interpolated_video``.

    Raises:
        AssertionError: If the stacked batch does not have 3 channels.
    """
    views = []
    for img_path in image_files:
        views.append(process_image(img_path))

    # Stack the views and add a leading batch axis.
    # Original author's shape note: [1, K, 3, 448, 448].
    images = torch.stack(views, dim=0).unsqueeze(0)
    images = images.to(device)

    batch, _num_views, channels, height, width = images.shape
    assert channels == 3, "Images must have 3 channels"

    # Shift/scale before inference; presumably maps [-1, 1] pixels to [0, 1]
    # -- TODO confirm against process_image.
    model_input = (images + 1) * 0.5
    gaussians, pred_context_pose = model.inference(model_input)

    extrinsics = pred_context_pose["extrinsic"]
    intrinsics = pred_context_pose["intrinsic"]
    video, depth_colored = save_interpolated_video(
        extrinsics,
        intrinsics,
        batch,
        height,
        width,
        gaussians,
        outdir,
        model.decoder,
    )

    # Export the first batch element of every Gaussian attribute.
    # NOTE: the additional ``.splat`` export is currently disabled.
    plyfile = os.path.join(outdir, "gaussians.ply")
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    # Hand cached GPU memory back before returning.
    torch.cuda.empty_cache()

    return plyfile, video, depth_colored
def extract_images(input_images, session_id):
    """Copy uploaded images into a fresh per-session working directory.

    The session directory ``$ANYSPLAT_PROCESSED/<session_id>`` is wiped and
    recreated with an ``images`` sub-directory on every call.

    Uploads that share a basename (for example two ``image.png`` files coming
    from different temporary folders) no longer overwrite each other: later
    duplicates are stored as ``<stem>_1<ext>``, ``<stem>_2<ext>``, ... so every
    returned path is unique and points to the matching upload.

    Args:
        input_images: ``None`` or an iterable whose items are either file
            paths or dicts carrying the path under the ``"name"`` key.
        session_id: Name of the per-session sub-directory.

    Returns:
        tuple: ``(target_dir, image_paths)`` -- the session directory and the
        list of copied image paths, in upload order.

    Raises:
        KeyError: If the ``ANYSPLAT_PROCESSED`` environment variable is unset.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir_images = os.path.join(target_dir, "images")

    # Always start from an empty directory so stale files never leak in.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    used_names = set()
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data

            # Pick a destination name that has not been used in this batch.
            file_name = os.path.basename(file_path)
            stem, ext = os.path.splitext(file_name)
            duplicate_index = 1
            while file_name in used_names:
                file_name = f"{stem}_{duplicate_index}{ext}"
                duplicate_index += 1
            used_names.add(file_name)

            dst_path = os.path.join(target_dir_images, file_name)
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
def extract_frames(input_video, session_id):
    """Sample roughly one frame per second from a video into a session folder.

    The session directory ``$ANYSPLAT_PROCESSED/<session_id>`` is wiped and
    recreated with an ``images`` sub-directory on every call. Every
    ``frame_interval``-th decoded frame is written there as a zero-padded,
    sequentially numbered PNG.

    Compared with the previous version this one:

    * never divides by zero -- when the container reports an FPS below 1
      (including 0, which OpenCV returns for unknown rates) or a non-finite
      value, every frame is kept instead of crashing;
    * always releases the ``cv2.VideoCapture`` handle, even if decoding or
      writing raises.

    Args:
        input_video: ``None``, a path to a video file, or a dict carrying the
            path under the ``"name"`` key.
        session_id: Name of the per-session sub-directory.

    Returns:
        tuple: ``(target_dir, image_paths)`` -- the session directory and the
        sorted list of written frame paths (empty when no frame was sampled).

    Raises:
        KeyError: If the ``ANYSPLAT_PROCESSED`` environment variable is unset.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir_images = os.path.join(target_dir, "images")

    # Always start from an empty directory so stale frames never leak in.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # 1 frame/sec; clamp so a missing or sub-1 FPS cannot produce a
            # zero modulus, and tolerate NaN/inf metadata.
            try:
                frame_interval = max(1, int(fps))
            except (TypeError, ValueError, OverflowError):
                frame_interval = 1

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(
                        target_dir_images, f"{video_frame_num:06}.png"
                    )
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # Free the decoder/file handle regardless of how the loop ended.
            vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
def update_gallery_on_video_upload(input_video, session_id):
    """Refresh the preview gallery after a video upload.

    Args:
        input_video: Uploaded video value; any falsy value means "nothing
            uploaded".
        session_id: Session identifier forwarded to ``extract_frames``.

    Returns:
        tuple: ``(None, target_dir, image_paths)`` when a video is present,
        otherwise ``(None, None, None)``. The leading ``None`` is the value
        sent to the first bound output.
    """
    if input_video:
        target_dir, image_paths = extract_frames(input_video, session_id)
        return None, target_dir, image_paths
    return None, None, None
def update_gallery_on_images_upload(input_images, session_id):
    """Refresh the preview gallery after image files are uploaded.

    Args:
        input_images: Uploaded files value; any falsy value (``None``, empty
            list) means "nothing uploaded".
        session_id: Session identifier forwarded to ``extract_images``.

    Returns:
        tuple: ``(None, target_dir, image_paths)`` when files are present,
        otherwise ``(None, None, None)``. The leading ``None`` is the value
        sent to the first bound output.
    """
    if input_images:
        target_dir, image_paths = extract_images(input_images, session_id)
        return None, target_dir, image_paths
    return None, None, None
@spaces.GPU()
def generate_splats_from_video(video_path, session_id=None):
    """
    Reconstruct a 3D Gaussian Splatting scene from a video with a feed-forward model.

    Frames are first sampled from the video, then passed to the image-based reconstruction.

    Args:
        video_path (str): Path to the input video file on disk.
        session_id (str, optional): Name of the working sub-directory. A random one is generated when omitted.
    Returns:
        plyfile: Path to the reconstructed 3D object from the given video.
        rgb_vid: Path to the interpolated RGB video rendered from the Gaussian splats.
        depth_vid: Path to the interpolated depth video rendered from the Gaussian splats.
        image_paths: A list of paths of the frames extracted from the video and used as input views.
    """
    session_id = uuid.uuid4().hex if session_id is None else session_id

    # Only the frame list is needed here; the folder itself is not used.
    _, image_paths = extract_frames(video_path, session_id)

    splat_outputs = generate_splats_from_images(image_paths, session_id)
    plyfile, rgb_vid, depth_vid = splat_outputs

    return plyfile, rgb_vid, depth_vid, image_paths
@spaces.GPU()
def generate_splats_from_images(image_paths, session_id=None):
    """
    Reconstruct a 3D Gaussian Splatting scene from unconstrained views with a feed-forward model.

    Args:
        image_paths (list[str]): Paths to the input image files on disk. Items may also be (path, caption) tuples; a single path string is accepted as well.
        session_id (str, optional): Name of the working sub-directory. A random one is generated when omitted.
    Returns:
        plyfile: Path to the reconstructed 3D object from the given image files.
        rgb_vid: Path to the interpolated RGB video rendered from the Gaussian splats.
        depth_vid: Path to the interpolated depth video rendered from the Gaussian splats.
    Raises:
        gr.Error: If no input image is provided.
    """
    # A lone path would otherwise be iterated character by character.
    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]

    # An empty/None gallery used to fail deep inside with an opaque error;
    # report it to the user up front instead.
    if not image_paths:
        raise gr.Error(
            "No input images found. Please upload a video or images first."
        )

    # Keep only the path of (path, caption) items; build a new list so the
    # caller's list is never mutated below.
    image_paths = [
        file_data[0] if isinstance(file_data, tuple) else file_data
        for file_data in image_paths
    ]

    print(image_paths)

    # With exactly one view, the same image is fed twice.
    if len(image_paths) == 1:
        image_paths.append(image_paths[0])

    if session_id is None:
        session_id = uuid.uuid4().hex

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    # The directory does not exist yet when this function is called directly
    # with a fresh session id (no prior upload step).
    os.makedirs(base_dir, exist_ok=True)

    print("Running run_model...")
    with torch.no_grad():
        plyfile, rgb_vid, depth_vid = get_reconstructed_scene(base_dir, image_paths, model, device)

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")

    return plyfile, rgb_vid, depth_vid
def cleanup(request: gr.Request):
    """Delete the working directory that belongs to the request's session.

    Does nothing when the request carries no session hash. Errors raised
    while removing the directory are ignored.

    Args:
        request: Gradio request whose ``session_hash`` names the directory.
    """
    session_hash = request.session_hash
    if not session_hash:
        return

    session_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_hash)
    shutil.rmtree(session_dir, ignore_errors=True)
def start_session(request: gr.Request):
    """Return the session hash of the incoming request.

    Args:
        request: Gradio request object.

    Returns:
        The value of ``request.session_hash``.
    """
    session_hash = request.session_hash
    return session_hash
if __name__ == "__main__":
    # Single switch for the public share link; passed to ``demo.launch`` below.
    share = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = AnySplat.from_pretrained(
        "lhjiang/anysplat"
    )
    model = model.to(device)
    model.eval()
    # Inference only: no parameter needs gradients.
    for param in model.parameters():
        param.requires_grad = False

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1024px;
    }
    """

    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from flattened source -- confirm against the rendered UI.
    with gr.Blocks(css=css, title="AnySplat Demo") as demo:
        # Per-browser-session identifier, filled in when the page loads.
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        # Hidden components; only ``target_dir_output`` is wired to an event
        # in this file.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
                <div style="text-align: center;">
                <p style="font-size:16px; display: inline; margin: 0;">
                <strong>AnySplat</strong> – Feed-forward 3D Gaussian Splatting from Unconstrained Views
                </p>
                <a href="https://github.com/OpenRobotLab/AnySplat" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
                </a>
                </div>
                """
            )

            with gr.Row():
                # Left column: inputs and frame preview.
                with gr.Column():
                    with gr.Tab("Video"):
                        input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
                    with gr.Tab("Images"):
                        input_images = gr.File(file_count="multiple", label="Upload Files", height=512)

                    # Label repaired: the previous text was a mis-decoded
                    # paintbrush emoji.
                    submit_btn = gr.Button(
                        "🖌️ Generate Gaussian Splat", scale=1, variant="primary"
                    )

                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )

                # Right column: reconstruction results.
                with gr.Column():
                    with gr.Column():
                        gr.HTML(
                            """
                            <p style="opacity: 0.6; font-style: italic;">
                            This might take a few seconds to load the 3D model
                            </p>
                            """
                        )
                        reconstruction_output = gr.Model3D(
                            label="Ply Gaussian Model",
                            height=512,
                            zoom_speed=0.5,
                            pan_speed=0.5,
                            # camera_position=[20, 20, 20],
                        )

                    with gr.Row():
                        rgb_video = gr.Video(
                            label="RGB Video", interactive=False, autoplay=True
                        )
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )

            with gr.Row():
                examples = [
                    ["examples/video/re10k_1eca36ec55b88fe4.mp4"],
                    ["examples/video/spann3r.mp4"],
                    ["examples/video/bungeenerf_colosseum.mp4"],
                    ["examples/video/fox.mp4"],
                    ["examples/video/vrnerf_apartment.mp4"],
                    # [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
                    # [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
                    # [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
                    # [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
                    # [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
                    # [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
                    # [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
                    # [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
                    # [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
                    # [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
                ]

                # Examples run the full video pipeline; the four outputs match
                # the four values returned by ``generate_splats_from_video``.
                gr.Examples(
                    examples=examples,
                    inputs=[
                        input_video
                    ],
                    outputs=[
                        reconstruction_output,
                        rgb_video,
                        depth_video,
                        image_gallery
                    ],
                    fn=generate_splats_from_video,
                    cache_examples=True,
                )

        # Reconstruction is driven by whatever is currently in the gallery.
        submit_btn.click(
            fn=generate_splats_from_images,
            inputs=[image_gallery, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video])

        # Uploads clear the 3D viewer and repopulate the gallery.
        input_video.upload(
            fn=update_gallery_on_video_upload,
            inputs=[input_video, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False
        )
        input_images.upload(
            fn=update_gallery_on_images_upload,
            inputs=[input_images, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False
        )

        # Remove the session's working directory when the page is closed.
        demo.unload(cleanup)

    demo.queue()
    demo.launch(show_error=True, share=share, mcp_server=True)