"""Gradio demo for FloVD (optical flow + CogVideoX video generation).

On startup the script downloads the required checkpoints into ./ckpt, imports
the FloVD inference entry point, and launches a Gradio UI that turns a prompt,
an input image, and camera-trajectory settings into a generated video.
"""
import contextlib
import io
import os
import subprocess
import sys
import traceback
from pathlib import Path

import gradio as gr
import torch
from PIL import Image

# =========================================
# 1. Define Hugging Face weights and paths
# =========================================
HF_DATASET_URL = "https://huggingface.co/datasets/roll-ai/FloVD-weights/resolve/main/ckpt"

# Relative checkpoint paths, used both under HF_DATASET_URL and under ./ckpt.
WEIGHT_FILES = {
    "FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth",
}


def download_weights():
    """Download every checkpoint in WEIGHT_FILES into ./ckpt unless already present.

    Each file is fetched with ``wget`` into a temporary ``<name>.part`` file and
    moved into place only after wget succeeds.  This way an interrupted or
    failed download never leaves a truncated checkpoint that a later run would
    mistake for a complete one.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    print("🔄 Downloading model weights...")
    for rel_path in WEIGHT_FILES.values():
        save_path = Path("ckpt") / rel_path
        if save_path.exists():
            print(f"✅ Already exists: {save_path}")
            continue

        save_path.parent.mkdir(parents=True, exist_ok=True)
        url = f"{HF_DATASET_URL}/{rel_path}"
        print(f"📥 Downloading {url} → {save_path}")

        tmp_path = save_path.with_name(save_path.name + ".part")
        try:
            subprocess.run(["wget", "-q", "-O", str(tmp_path), url], check=True)
        except BaseException:
            # wget -O creates the output file up front; drop the partial file
            # so the next start retries instead of treating it as complete.
            tmp_path.unlink(missing_ok=True)
            raise
        os.replace(tmp_path, save_path)


download_weights()

# Imported only after the weights are on disk, preserving the original
# ordering in case the inference module touches ./ckpt at import time.
from inference.flovd_demo import generate_video  # noqa: E402


def _run_generation(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Save the input image, call FloVD's generate_video, and return the expected output path.

    Returns:
        The path where the generated video is expected to be written.  The
        caller checks whether the file actually exists.
    """
    os.makedirs("input_images", exist_ok=True)
    image_path = "input_images/input_image.png"
    image.save(image_path)
    print(f"📸 Saved input image to {image_path}")

    generate_video(
        prompt=prompt,
        image_path=image_path,
        fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
        omsm_path="./ckpt/OMSM",
        output_path="./outputs",
        num_frames=49,
        fps=16,
        width=None,
        height=None,
        seed=42,
        guidance_scale=6.0,
        dtype=torch.float16,
        controlnet_guidance_end=0.4,
        use_dynamic_cfg=False,
        pose_type=pose_type,
        speed=float(speed),
        use_flow_integration=use_flow_integration,
        cam_pose_name=cam_pose_name,
        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
    )

    # NOTE(review): this reconstructs the file name generate_video is assumed
    # to use; it is not returned by generate_video.  Confirm it matches
    # inference.flovd_demo if outputs go missing.
    video_name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
    video_path = f"./outputs/generated_videos/{video_name}"
    print(f"✅ Inference complete. Video saved to {video_path}")
    return video_path


def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Gradio callback: generate a video and return it together with captured logs.

    Everything printed to stdout during the run, including the traceback of
    any failure, is captured and returned as the log text.

    Args:
        prompt: Text prompt for the video.
        image: PIL image from the ``gr.Image(type="pil")`` input, or None if
            nothing was uploaded.
        pose_type: Camera pose source ("manual" or "re10k" in the UI).
        speed: Camera speed slider value.
        use_flow_integration: Flag forwarded to generate_video.
        cam_pose_name: Camera trajectory name (may be an empty string).

    Returns:
        ``(video_path_or_None, logs)``.  The video path is returned only if
        the file exists on disk.
    """
    log_buffer = io.StringIO()
    video_path = None

    # redirect_stdout restores sys.stdout even if a BaseException escapes.
    # NOTE(review): the redirect is process-global, so overlapping requests
    # would interleave their logs; confirm requests are serialized.
    with contextlib.redirect_stdout(log_buffer):
        print("🚀 Starting inference...")
        if image is None:
            print("⚠️ No input image provided. Please upload an image and try again.")
        else:
            try:
                video_path = _run_generation(
                    prompt, image, pose_type, speed, use_flow_integration, cam_pose_name
                )
            except Exception:
                print("🔥 Inference failed with exception:")
                # print_exc defaults to stderr; send it to the captured stdout
                # so the traceback reaches the Logs box.
                traceback.print_exc(file=sys.stdout)

    logs = log_buffer.getvalue()
    log_buffer.close()

    if video_path and os.path.exists(video_path):
        return video_path, logs
    return None, logs


# ========================
# Gradio Interface
# ========================
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")

        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_logs = gr.Textbox(label="Logs", lines=20, interactive=False)

    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=[output_video, output_logs],
    )

# NOTE(review): `enable_queue` was removed from launch() in Gradio 4.x;
# confirm the pinned Gradio version still accepts it.
demo.launch(show_error=True, enable_queue=False)