"""Gradio demo app for FloVD camera-motion-guided video generation.

Downloads the FloVD model checkpoints from the Hugging Face Hub, then
serves a Gradio UI that runs FloVD + CogVideoX inference on an input
image / text prompt / camera-pose file, and lets the user browse and
download previously generated videos.
"""

import os
from contextlib import suppress

import torch
import gradio as gr
from PIL import Image  # noqa: F401 — gr.Image(type="pil") delivers PIL images
from huggingface_hub import snapshot_download

from inference.flovd_demo import generate_video

# -----------------------------------
# Step 1: Download model checkpoints
# -----------------------------------
hf_token = os.getenv("HF_TOKEN", None)
snapshot_download(
    repo_id="roll-ai/FloVD-weights",
    repo_type="dataset",
    local_dir="./",
    allow_patterns="ckpt/**",
    token=hf_token,
)

# -----------------------------------
# Step 2: Setup paths and config
# -----------------------------------
# NOTE(review): snapshot_download above materializes weights under ./ckpt,
# but every checkpoint path below is rooted at app/ckpt — confirm the
# deployment working directory actually makes these line up.
BASE_DIR = "app"
FVSM_PATH = os.path.join(BASE_DIR, "ckpt/FVSM/FloVD_FVSM_Controlnet.pt")
OMSM_PATH = os.path.join(BASE_DIR, "ckpt/OMSM/")
DEPTH_CKPT_PATH = os.path.join(
    BASE_DIR, "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
)
OUTPUT_PATH = os.path.join(BASE_DIR, "results")
GEN_VID_DIR = os.path.join(OUTPUT_PATH, "generated_videos")

# Inference configuration (fixed for the demo).
POSE_TYPE = "re10k"
CONTROLNET_GUIDANCE_END = 0.4
SPEED = 1.0
NUM_FRAMES = 81
FPS = 16
INFER_STEPS = 50

# Creates BASE_DIR and OUTPUT_PATH as well (makedirs builds intermediates).
os.makedirs(GEN_VID_DIR, exist_ok=True)


# -----------------------------------
# Helper Functions
# -----------------------------------
def list_generated_videos():
    """Return the sorted list of .mp4 filenames currently in GEN_VID_DIR.

    Returns an empty list (and logs a warning) if the directory cannot
    be read, so UI callbacks never crash on a listing failure.
    """
    try:
        return sorted(f for f in os.listdir(GEN_VID_DIR) if f.endswith(".mp4"))
    except Exception as e:
        print(f"[āš ļø] Could not list contents: {str(e)}")
        return []


def run_flovd(prompt, image, cam_pose_name):
    """Run FloVD inference for one prompt/image/camera-pose triple.

    Args:
        prompt: text prompt driving the generation.
        image: PIL image used as the first frame / conditioning image.
        cam_pose_name: filename of the camera pose trajectory (e.g. "abc.txt").

    Returns:
        Tuple matching the Gradio outputs:
        (video path or None, status markdown, dropdown update, file output).
    """
    image_path = os.path.join(BASE_DIR, "temp_input.png")
    try:
        print("\n----------------------------")
        print("šŸš€ Starting video generation")
        print("----------------------------")
        image.save(image_path)
        print(f"šŸ“ø Image saved at {image_path}")

        generate_video(
            prompt=prompt,
            fvsm_path=FVSM_PATH,
            omsm_path=OMSM_PATH,
            image_path=image_path,
            cam_pose_name=cam_pose_name,
            output_path=OUTPUT_PATH,
            controlnet_guidance_end=CONTROLNET_GUIDANCE_END,
            pose_type=POSE_TYPE,
            speed=SPEED,
            use_flow_integration=True,
            depth_ckpt_path=DEPTH_CKPT_PATH,
            dtype=torch.float16,
            num_frames=NUM_FRAMES,
            fps=FPS,
            num_inference_steps=INFER_STEPS,
        )

        # Reconstruct the filename generate_video is expected to produce:
        # first 30 chars of the prompt, spaces -> underscores, dots/commas
        # stripped, suffixed with the pose name.
        prompt_short = (
            prompt[:30].strip().replace(" ", "_").replace(".", "").replace(",", "")
        )
        video_filename = f"{prompt_short}_{cam_pose_name}.mp4"
        video_path = os.path.join(GEN_VID_DIR, video_filename)
        print(f"\nšŸ“ Looking for generated video at: {video_path}")

        # Wrap the refreshed listing in gr.update(choices=...): returning a
        # bare list would set the dropdown's *value*, not its choices.
        if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
            print("āœ… Video file found.")
            return (
                video_path,
                "āœ… Video generated successfully.",
                gr.update(choices=list_generated_videos()),
                None,
            )
        print("āŒ File missing or empty.")
        return (
            None,
            f"āŒ File not found at: {video_path}",
            gr.update(choices=list_generated_videos()),
            None,
        )
    except Exception as e:
        # Surface the failure in the UI instead of crashing the callback.
        print(f"šŸ”„ Exception occurred: {str(e)}")
        return (
            None,
            f"šŸ”„ Exception: {str(e)}",
            gr.update(choices=list_generated_videos()),
            None,
        )
    finally:
        # Don't leak the per-request temp image between runs.
        with suppress(OSError):
            os.remove(image_path)


def get_video_file(video_name):
    """Resolve a selected filename to a downloadable path, or None if gone."""
    video_path = os.path.join(GEN_VID_DIR, video_name)
    return video_path if os.path.exists(video_path) else None


# -----------------------------------
# Step 3: Launch Gradio Interface
# -----------------------------------
with gr.Blocks(title="FloVD + CogVideoX") as demo:
    gr.Markdown("## šŸŽ„ FloVD - Camera Motion Guided Video Generation + Downloader")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            image = gr.Image(type="pil", label="Input Image")
            pose_file = gr.Textbox(label="Camera Pose Filename (e.g. abc.txt)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Markdown(label="Status / Logs")  # shows errors/success
            video_selector = gr.Dropdown(
                choices=list_generated_videos(),
                label="Select Video to Download",
            )
            download_btn = gr.Button("Download Selected")
            file_output = gr.File(label="Download Link")

    run_btn.click(
        fn=run_flovd,
        inputs=[prompt, image, pose_file],
        outputs=[output_video, status_text, video_selector, file_output],
    )
    download_btn.click(
        fn=get_video_file,
        inputs=[video_selector],
        outputs=[file_output],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)