# Spaces: Paused (page-header residue from the Hugging Face Spaces listing; not part of the program)
import os
import traceback

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image

from inference.flovd_demo import generate_video
# -----------------------------------
# Step 1: Download model checkpoints
# -----------------------------------
# Optional access token; stays None when HF_TOKEN is not set in the environment.
hf_token = os.environ.get("HF_TOKEN")

# Fetch only the files matching "ckpt/**" from the weights repository and
# place them under the current working directory.
# NOTE(review): the checkpoint paths in Step 2 are built from BASE_DIR = "app"
# (app/ckpt/...), while this call downloads relative to "./" — confirm both
# resolve to the same directory in the deployment environment.
snapshot_download(
    repo_id="roll-ai/FloVD-weights",
    repo_type="dataset",
    allow_patterns="ckpt/**",
    local_dir="./",
    token=hf_token,
)
# -----------------------------------
# Step 2: Setup paths and config
# -----------------------------------
# Root directory (relative to the working directory) for checkpoints, the
# temporary input image and all generated results.
# NOTE(review): Step 1 downloads the checkpoints relative to "./", not to
# BASE_DIR — confirm these paths point at the downloaded files.
BASE_DIR = "app"

# Fixed generation settings forwarded unchanged to generate_video().
POSE_TYPE = "re10k"
CONTROLNET_GUIDANCE_END = 0.4
SPEED = 1.0
NUM_FRAMES = 81
FPS = 16
INFER_STEPS = 50

# Checkpoint locations handed to generate_video().
FVSM_PATH = os.path.join(BASE_DIR, "ckpt/FVSM/FloVD_FVSM_Controlnet.pt")
OMSM_PATH = os.path.join(BASE_DIR, "ckpt/OMSM/")
DEPTH_CKPT_PATH = os.path.join(
    BASE_DIR, "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
)

# Output root passed to generate_video(), and the sub-directory this app
# scans for finished .mp4 files.
OUTPUT_PATH = os.path.join(BASE_DIR, "results")
GEN_VID_DIR = os.path.join(OUTPUT_PATH, "generated_videos")

# Create the video directory (and its parents) up front so listing it works
# before the first generation has run.
os.makedirs(GEN_VID_DIR, exist_ok=True)
# -----------------------------------
# Helper Functions
# -----------------------------------
def list_generated_videos():
    """Return the sorted names of the ``.mp4`` entries in ``GEN_VID_DIR``.

    Any error raised while reading the directory is printed and an empty
    list is returned instead, so callers never see an exception.
    """
    try:
        mp4_names = [
            entry for entry in os.listdir(GEN_VID_DIR) if entry.endswith(".mp4")
        ]
        mp4_names.sort()
        return mp4_names
    except Exception as err:
        # Best-effort listing: report the problem and fall back to "no videos".
        print(f"[β οΈ] Could not list contents: {err!s}")
        return []
def _expected_video_filename(prompt, cam_pose_name):
    """Build the file name this app looks for after a generation run."""
    # assumes generate_video() names its output the same way — TODO confirm
    # against inference.flovd_demo.
    prompt_short = prompt[:30].strip().replace(" ", "_").replace(".", "").replace(",", "")
    return f"{prompt_short}_{cam_pose_name}.mp4"


def _video_choices_update():
    """Return a Gradio update that reloads the download dropdown's choices."""
    # Returning a bare list to a Dropdown output would set its *value*;
    # gr.update(choices=...) is what actually refreshes the selectable items.
    return gr.update(choices=list_generated_videos())


def run_flovd(prompt, image, cam_pose_name):
    """Generate a video from a prompt, an input image and a camera pose file.

    The image is written to ``BASE_DIR/temp_input.png``, ``generate_video`` is
    invoked with the module-level settings, and the expected output file is
    then looked up in ``GEN_VID_DIR``.

    Args:
        prompt: Text prompt; its first 30 characters are also used to derive
            the expected output file name.
        image: PIL image from the Gradio image input, or ``None`` when no
            image was provided.
        cam_pose_name: Camera pose file name entered by the user.

    Returns:
        A 4-tuple matching the Gradio outputs
        ``(output_video, status_text, video_selector, file_output)``:
        the video path (or ``None`` on failure), a status message, an update
        refreshing the dropdown choices, and ``None`` for the download slot.
        Errors are reported through the status message, never raised.
    """
    # Validate user input up front so the status shows an actionable message
    # instead of an attribute error raised deep inside the pipeline.
    if image is None:
        return None, "Please upload an input image.", _video_choices_update(), None
    if not cam_pose_name or not cam_pose_name.strip():
        return None, "Please enter a camera pose filename.", _video_choices_update(), None

    # NOTE(review): the leading symbols in the messages below look like
    # mis-decoded emoji; they are kept verbatim.
    try:
        print("\n----------------------------")
        print("π Starting video generation")
        print("----------------------------")
        image_path = os.path.join(BASE_DIR, "temp_input.png")
        image.save(image_path)
        print(f"πΈ Image saved at {image_path}")
        generate_video(
            prompt=prompt,
            fvsm_path=FVSM_PATH,
            omsm_path=OMSM_PATH,
            image_path=image_path,
            cam_pose_name=cam_pose_name,
            output_path=OUTPUT_PATH,
            controlnet_guidance_end=CONTROLNET_GUIDANCE_END,
            pose_type=POSE_TYPE,
            speed=SPEED,
            use_flow_integration=True,
            depth_ckpt_path=DEPTH_CKPT_PATH,
            dtype=torch.float16,
            num_frames=NUM_FRAMES,
            fps=FPS,
            num_inference_steps=INFER_STEPS,
        )
        video_path = os.path.join(
            GEN_VID_DIR, _expected_video_filename(prompt, cam_pose_name)
        )
        print(f"\nπ Looking for generated video at: {video_path}")
        # A zero-byte file counts as a failed generation.
        if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
            print("β Video file found.")
            return video_path, "β Video generated successfully.", _video_choices_update(), None
        print("β File missing or empty.")
        return None, f"β File not found at: {video_path}", _video_choices_update(), None
    except Exception as e:
        # UI boundary: surface the error in the status panel rather than
        # crashing the request, but keep the full stack trace in the logs.
        print(f"π₯ Exception occurred: {str(e)}")
        traceback.print_exc()
        return None, f"π₯ Exception: {str(e)}", _video_choices_update(), None
def get_video_file(video_name):
    """Resolve a dropdown selection to the path of a generated video.

    Args:
        video_name: File name chosen in the download dropdown. May be
            ``None`` or empty when nothing is selected.

    Returns:
        The path of the file inside ``GEN_VID_DIR``, or ``None`` when the
        selection is missing, is not a bare file name, or does not refer to
        an existing regular file.
    """
    # Nothing selected (or a non-string value): os.path.join would raise
    # TypeError, so report "no file" instead.
    if not isinstance(video_name, str) or not video_name:
        return None
    # Only bare file names are valid; a name with a directory component
    # (e.g. "../x" or an absolute path) could resolve outside GEN_VID_DIR.
    if os.path.basename(video_name) != video_name:
        return None
    video_path = os.path.join(GEN_VID_DIR, video_name)
    # isfile (not exists) so a directory is never offered as a download.
    return video_path if os.path.isfile(video_path) else None
# -----------------------------------
# Step 3: Launch Gradio Interface
# -----------------------------------
with gr.Blocks(title="FloVD + CogVideoX") as demo:
    gr.Markdown("## π₯ FloVD - Camera Motion Guided Video Generation + Downloader")

    with gr.Row():
        # Left column: generation inputs and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            image = gr.Image(type="pil", label="Input Image")
            pose_file = gr.Textbox(label="Camera Pose Filename (e.g. abc.txt)")
            run_btn = gr.Button("Generate Video")

        # Right column: generation result, status, and the download controls.
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Markdown(label="Status / Logs")  # shows errors/success
            video_selector = gr.Dropdown(
                label="Select Video to Download",
                choices=list_generated_videos(),
            )
            download_btn = gr.Button("Download Selected")
            file_output = gr.File(label="Download Link")

    # run_flovd returns one value per output component, in this order.
    run_btn.click(
        run_flovd,
        [prompt, image, pose_file],
        [output_video, status_text, video_selector, file_output],
    )

    # Turn the selected file name into a downloadable file.
    download_btn.click(get_video_file, [video_selector], [file_output])

# Listen on all interfaces on port 7860.
demo.launch(server_port=7860, server_name="0.0.0.0")