import spaces
import torch

print(f"torch version: {torch.__version__}")

import argparse
import functools
import gc
import os
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
from datetime import datetime
from io import BytesIO
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
from plyfile import PlyData

# Make the repository root importable so the `src` package resolves.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image

# Root directory for per-session preprocessing output.
os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/preprocess_results"

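# Convert a 3D Gaussian Splatting .ply into the compact binary .splat layout
# used by lightweight web viewers (antimatter15-style): per Gaussian,
# 3 x float32 position, 3 x float32 scale, 4 x uint8 RGBA color, and a
# 4 x uint8 quantized rotation quaternion (32 bytes in total).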
def process_ply_to_splat(ply_file_path):
    plydata = PlyData.read(ply_file_path)
    vert = plydata["vertex"]
    # Order Gaussians by descending size * opacity so the most visible
    # splats are written (and rendered) first.
    sorted_indices = np.argsort(
        -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
        / (1 + np.exp(-vert["opacity"]))
    )
    buffer = BytesIO()
    for idx in sorted_indices:
        v = vert[idx]
        position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
        # Scales are stored in log space in the .ply; exponentiate them.
        scales = np.exp(
            np.array(
                [v["scale_0"], v["scale_1"], v["scale_2"]],
                dtype=np.float32,
            )
        )
        rot = np.array(
            [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
            dtype=np.float32,
        )
        # SH_C0 is the zeroth-order spherical-harmonic basis constant; it
        # turns the DC SH coefficients into base RGB. Opacity is stored as
        # a logit, so apply a sigmoid to get alpha.
        SH_C0 = 0.28209479177387814
        color = np.array(
            [
                0.5 + SH_C0 * v["f_dc_0"],
                0.5 + SH_C0 * v["f_dc_1"],
                0.5 + SH_C0 * v["f_dc_2"],
                1 / (1 + np.exp(-v["opacity"])),
            ]
        )
        buffer.write(position.tobytes())
        buffer.write(scales.tobytes())
        buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
        # Normalize the quaternion, then quantize each component to uint8.
        buffer.write(
            ((rot / np.linalg.norm(rot)) * 128 + 128)
            .clip(0, 255)
            .astype(np.uint8)
            .tobytes()
        )

    return buffer.getvalue()


def save_splat_file(splat_data, output_path):
    with open(output_path, "wb") as f:
        f.write(splat_data)

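# Usage sketch (hypothetical paths): these two helpers are not wired into the
# UI below, but they can convert an exported PLY for lightweight web viewing:
#
#     splat_data = process_ply_to_splat("output/gaussians.ply")
#     save_splat_file(splat_data, "output/gaussians.splat")
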
def get_reconstructed_scene(outdir, image_files, model, device):
    # Stack preprocessed frames into a (batch, views, channels, H, W) tensor.
    images = [process_image(img_path) for img_path in image_files]
    images = torch.stack(images, dim=0).unsqueeze(0).to(device)
    b, v, c, h, w = images.shape
    assert c == 3, "Images must have 3 channels"

    # Rescale from [-1, 1] to [0, 1] before feed-forward inference.
    gaussians, pred_context_pose = model.inference((images + 1) * 0.5)

    # Render RGB and depth videos along a trajectory interpolated through
    # the predicted camera poses.
    pred_all_extrinsic = pred_context_pose["extrinsic"]
    pred_all_intrinsic = pred_context_pose["intrinsic"]
    video, depth_colored = save_interpolated_video(
        pred_all_extrinsic,
        pred_all_intrinsic,
        b,
        h,
        w,
        gaussians,
        outdir,
        model.decoder,
    )

    # Export the predicted Gaussians (DC spherical harmonics only) to PLY.
    plyfile = os.path.join(outdir, "gaussians.ply")
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    torch.cuda.empty_cache()
    return plyfile, video, depth_colored

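# Both extractors below stage their inputs under
# $ANYSPLAT_PROCESSED/<session_id>/images, so concurrent Gradio sessions never
# overwrite each other's files.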
def extract_images(input_images, session_id):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir = base_dir
    target_dir_images = os.path.join(target_dir, "images")

    # Start from a clean per-session directory.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    if input_images is not None:
        for file_data in input_images:
            # Gradio may hand back either a dict with a "name" key or a
            # plain file path.
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths

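# Frames are sampled from the uploaded video at roughly one-second intervals;
# AnySplat reconstructs from sparse, unconstrained views, so dense frame
# coverage is unnecessary.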
def extract_frames(input_video, session_id):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir = base_dir
    target_dir_images = os.path.join(target_dir, "images")

    # Start from a clean per-session directory.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        # Keep one frame per second; guard against containers that report
        # an fps of 0, which would otherwise crash the modulo below.
        frame_interval = max(int(fps), 1)

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(
                    target_dir_images, f"{video_frame_num:06}.png"
                )
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
        vs.release()

    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Extracted {len(image_paths)} frames to {target_dir_images}; "
        f"took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths

def update_gallery_on_video_upload(input_video, session_id):
    if not input_video:
        return None, None, None
    target_dir, image_paths = extract_frames(input_video, session_id)
    return None, target_dir, image_paths


def update_gallery_on_images_upload(input_images, session_id):
    if not input_images:
        return None, None, None
    target_dir, image_paths = extract_images(input_images, session_id)
    return None, target_dir, image_paths

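# `@spaces.GPU()` requests a GPU for the duration of the call when running on
# Hugging Face ZeroGPU Spaces; outside a ZeroGPU environment the decorator is
# a no-op.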
@spaces.GPU()
def generate_splats_from_video(video_path, session_id=None):
    """
    Perform feed-forward Gaussian Splatting from the unconstrained views of a given video.

    Args:
        video_path (str): Path to the input video file on disk.

    Returns:
        plyfile: Path to the 3D Gaussian scene reconstructed from the video (.ply).
        rgb_vid: Path to the RGB video rendered from the splats along an interpolated camera trajectory.
        depth_vid: Path to the depth video rendered from the splats along an interpolated camera trajectory.
        image_paths: Paths of the frames extracted from the video and used for reconstruction.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    images_folder, image_paths = extract_frames(video_path, session_id)
    plyfile, rgb_vid, depth_vid = generate_splats_from_images(image_paths, session_id)

    return plyfile, rgb_vid, depth_vid, image_paths

@spaces.GPU()
def generate_splats_from_images(image_paths, session_id=None):
    """
    Perform feed-forward Gaussian Splatting from the unconstrained views of the given images.

    Args:
        image_paths (list): Paths to the input image files on disk.

    Returns:
        plyfile: Path to the 3D Gaussian scene reconstructed from the images (.ply).
        rgb_vid: Path to the RGB video rendered from the splats along an interpolated camera trajectory.
        depth_vid: Path to the depth video rendered from the splats along an interpolated camera trajectory.
    """
    # Gallery items may arrive as (path, caption) tuples; keep only the path.
    processed_image_paths = []
    for file_data in image_paths:
        if isinstance(file_data, tuple):
            file_path, _ = file_data
            processed_image_paths.append(file_path)
        else:
            processed_image_paths.append(file_data)
    image_paths = processed_image_paths
    print(f"Input images: {image_paths}")

    # With a single input image, duplicate it so the model receives two views.
    if len(image_paths) == 1:
        image_paths.append(image_paths[0])

    if session_id is None:
        session_id = uuid.uuid4().hex

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)

    print("Running reconstruction...")
    with torch.no_grad():
        plyfile, rgb_vid, depth_vid = get_reconstructed_scene(
            base_dir, image_paths, model, device
        )

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")

    return plyfile, rgb_vid, depth_vid

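# Gradio calls `cleanup` when a browser session ends (registered below via
# `demo.unload`), removing that session's working directory.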
def cleanup(request: gr.Request):
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["ANYSPLAT_PROCESSED"], sid)
        shutil.rmtree(d1, ignore_errors=True)


def start_session(request: gr.Request):
    # Use Gradio's session hash as the per-session working-directory name.
    return request.session_hash

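# Entry point: load the pretrained AnySplat weights once, freeze them, and
# serve the Gradio UI (with queueing and an MCP server enabled).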
if __name__ == "__main__":
    share = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AnySplat.from_pretrained("lhjiang/anysplat")
    model = model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1024px;
    }
    """
    with gr.Blocks(css=css, title="AnySplat Demo") as demo:
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
                <div style="text-align: center;">
                    <p style="font-size:16px; display: inline; margin: 0;">
                        <strong>AnySplat</strong> – Feed-forward 3D Gaussian Splatting from Unconstrained Views
                    </p>
                    <a href="https://github.com/OpenRobotLab/AnySplat" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                        <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
                    </a>
                </div>
                """
            )

            with gr.Row():
                with gr.Column():
                    with gr.Tab("Video"):
                        input_video = gr.Video(
                            label="Upload Video",
                            sources=["upload"],
                            interactive=True,
                            height=512,
                        )
                    with gr.Tab("Images"):
                        input_images = gr.File(
                            file_count="multiple", label="Upload Files", height=512
                        )

                    submit_btn = gr.Button(
                        "Generate Gaussian Splat", scale=1, variant="primary"
                    )

                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )

                with gr.Column():
                    gr.HTML(
                        """
                        <p style="opacity: 0.6; font-style: italic;">
                            Loading the 3D model may take a few seconds
                        </p>
                        """
                    )
                    reconstruction_output = gr.Model3D(
                        label="Ply Gaussian Model",
                        height=512,
                        zoom_speed=0.5,
                        pan_speed=0.5,
                    )

                    with gr.Row():
                        rgb_video = gr.Video(
                            label="RGB Video", interactive=False, autoplay=True
                        )
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )

            with gr.Row():
                examples = [
                    ["examples/video/re10k_1eca36ec55b88fe4.mp4"],
                    ["examples/video/spann3r.mp4"],
                    ["examples/video/bungeenerf_colosseum.mp4"],
                    ["examples/video/fox.mp4"],
                    ["examples/video/vrnerf_apartment.mp4"],
                ]

                gr.Examples(
                    examples=examples,
                    inputs=[input_video],
                    outputs=[
                        reconstruction_output,
                        rgb_video,
                        depth_video,
                        image_gallery,
                    ],
                    fn=generate_splats_from_video,
                    cache_examples=True,
                )

        submit_btn.click(
            fn=generate_splats_from_images,
            inputs=[image_gallery, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video],
        )

        input_video.upload(
            fn=update_gallery_on_video_upload,
            inputs=[input_video, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False,
        )

        input_images.upload(
            fn=update_gallery_on_images_upload,
            inputs=[input_images, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False,
        )

        demo.unload(cleanup)

    demo.queue()
    demo.launch(show_error=True, share=share, mcp_server=True)