# Flov-space / app.py
# Source: Hugging Face Space by roll-ai (commit cb1cab1, "Update app.py", verified)
import os
import torch
import gradio as gr
from PIL import Image
from inference.flovd_demo import generate_video
from huggingface_hub import snapshot_download
# -----------------------------------
# Step 1: Download model checkpoints
# -----------------------------------
# Pull the FloVD weights from the Hugging Face Hub at import time.
# HF_TOKEN is optional; None means an anonymous download of the dataset repo.
hf_token = os.getenv("HF_TOKEN", None)
snapshot_download(
    repo_id="roll-ai/FloVD-weights",
    repo_type="dataset",
    # NOTE(review): weights land under ./ckpt, but every path below prepends
    # BASE_DIR ("app") — confirm the process CWD makes these two agree,
    # otherwise the checkpoint paths will not resolve.
    local_dir="./",
    allow_patterns="ckpt/**",  # fetch only the checkpoint tree
    token=hf_token,
)
# -----------------------------------
# Step 2: Setup paths and config
# -----------------------------------
BASE_DIR = "app"  # root for checkpoints, temp inputs, and results
FVSM_PATH = os.path.join(BASE_DIR, "ckpt/FVSM/FloVD_FVSM_Controlnet.pt")  # controlnet checkpoint
OMSM_PATH = os.path.join(BASE_DIR, "ckpt/OMSM/")  # OMSM checkpoint directory
DEPTH_CKPT_PATH = os.path.join(BASE_DIR, "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth")  # depth estimator weights
OUTPUT_PATH = os.path.join(BASE_DIR, "results")
GEN_VID_DIR = os.path.join(OUTPUT_PATH, "generated_videos")  # where generate_video writes .mp4 files

# Generation settings passed straight through to generate_video().
POSE_TYPE = "re10k"
CONTROLNET_GUIDANCE_END = 0.4  # presumably the fraction of steps with controlnet guidance — confirm in inference code
SPEED = 1.0
NUM_FRAMES = 81
FPS = 16
INFER_STEPS = 50
# Ensure the output tree exists (also implicitly creates BASE_DIR and OUTPUT_PATH).
os.makedirs(GEN_VID_DIR, exist_ok=True)
# -----------------------------------
# Helper Functions
# -----------------------------------
def list_generated_videos():
    """Return the sorted list of .mp4 filenames in GEN_VID_DIR (best-effort).

    Returns:
        list[str]: sorted video filenames; an empty list when the directory
        cannot be read, so the UI dropdown always receives valid choices.
    """
    try:
        return sorted(f for f in os.listdir(GEN_VID_DIR) if f.endswith(".mp4"))
    except OSError as e:
        # Narrowed from `except Exception`: only filesystem errors (missing
        # or unreadable directory) are expected here; log and degrade.
        print(f"[⚠️] Could not list contents: {str(e)}")
        return []
def run_flovd(prompt, image, cam_pose_name):
    """Run the FloVD pipeline for one prompt/image/camera-pose triple.

    Returns a 4-tuple matching the Gradio outputs wired to the Run button:
    (video path or None, status markdown, refreshed video name list, None).
    Never raises: any failure is reported through the status string.
    """
    try:
        print("\n----------------------------")
        print("πŸš€ Starting video generation")
        print("----------------------------")

        # Persist the uploaded PIL image where the pipeline can read it.
        saved_image = os.path.join(BASE_DIR, "temp_input.png")
        image.save(saved_image)
        print(f"πŸ“Έ Image saved at {saved_image}")

        generate_video(
            prompt=prompt,
            fvsm_path=FVSM_PATH,
            omsm_path=OMSM_PATH,
            image_path=saved_image,
            cam_pose_name=cam_pose_name,
            output_path=OUTPUT_PATH,
            controlnet_guidance_end=CONTROLNET_GUIDANCE_END,
            pose_type=POSE_TYPE,
            speed=SPEED,
            use_flow_integration=True,
            depth_ckpt_path=DEPTH_CKPT_PATH,
            dtype=torch.float16,
            num_frames=NUM_FRAMES,
            fps=FPS,
            num_inference_steps=INFER_STEPS,
        )

        # Reconstruct the filename generate_video is expected to have written:
        # first 30 chars of the prompt, spaces → underscores, dots/commas dropped.
        slug = prompt[:30].strip()
        for old, new in ((" ", "_"), (".", ""), (",", "")):
            slug = slug.replace(old, new)
        expected_path = os.path.join(GEN_VID_DIR, f"{slug}_{cam_pose_name}.mp4")
        print(f"\nπŸ“ Looking for generated video at: {expected_path}")

        # Guard clause: treat a missing or zero-byte file as a failure.
        if not (os.path.exists(expected_path) and os.path.getsize(expected_path) > 0):
            print("❌ File missing or empty.")
            return None, f"❌ File not found at: {expected_path}", list_generated_videos(), None

        print("βœ… Video file found.")
        return expected_path, "βœ… Video generated successfully.", list_generated_videos(), None
    except Exception as e:
        print(f"πŸ”₯ Exception occurred: {str(e)}")
        return None, f"πŸ”₯ Exception: {str(e)}", list_generated_videos(), None
def get_video_file(video_name):
    """Resolve a video filename inside GEN_VID_DIR for download.

    Args:
        video_name: filename selected in the UI dropdown. Treated as
            untrusted input: it is reduced to its basename so a crafted
            value (e.g. containing "../") cannot escape GEN_VID_DIR.

    Returns:
        The path to the video if it exists, otherwise None.
    """
    if not video_name:
        # Dropdown may be empty/unselected; nothing to download.
        return None
    safe_name = os.path.basename(video_name)  # strip any directory components
    video_path = os.path.join(GEN_VID_DIR, safe_name)
    if os.path.exists(video_path):
        return video_path
    else:
        return None
# -----------------------------------
# Step 3: Launch Gradio Interface
# -----------------------------------
with gr.Blocks(title="FloVD + CogVideoX") as demo:
    gr.Markdown("## πŸŽ₯ FloVD - Camera Motion Guided Video Generation + Downloader")
    with gr.Row():
        with gr.Column():
            # Left column: generation inputs.
            prompt = gr.Textbox(label="Prompt")
            image = gr.Image(type="pil", label="Input Image")
            pose_file = gr.Textbox(label="Camera Pose Filename (e.g. abc.txt)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            # Right column: generation result plus a download helper.
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Markdown(label="Status / Logs") # shows errors/success
            # Choices are populated once at build time; run_flovd's third return
            # value is wired back into this component after each generation.
            # NOTE(review): returning a bare list to a Dropdown output likely
            # sets its *value*, not its choices — may need gr.update(choices=...);
            # confirm against the Gradio version in use.
            video_selector = gr.Dropdown(choices=list_generated_videos(), label="Select Video to Download")
            download_btn = gr.Button("Download Selected")
            file_output = gr.File(label="Download Link")
    # Wire the Run button: 3 inputs -> 4 outputs (see run_flovd's return tuple).
    run_btn.click(
        fn=run_flovd,
        inputs=[prompt, image, pose_file],
        outputs=[output_video, status_text, video_selector, file_output],
    )
    # Wire the Download button: resolves the selected name to a file path.
    download_btn.click(
        fn=get_video_file,
        inputs=[video_selector],
        outputs=[file_output],
    )
# Bind to all interfaces on port 7860 (the standard Hugging Face Spaces port).
demo.launch(server_name="0.0.0.0", server_port=7860)