Spaces:
Paused
Paused
File size: 4,621 Bytes
813d218 1dddb04 813d218 b93ca3e 813d218 1dddb04 d8bfbd8 1dddb04 813d218 b93ca3e 927afcf b93ca3e 1dddb04 b93ca3e 1732a51 1dddb04 5d8a662 1dddb04 cb1cab1 b272a96 9170b3e 5d8a662 1dddb04 d8bfbd8 9170b3e 5d8a662 9170b3e 5d8a662 9170b3e 5d8a662 1dddb04 828b5b4 9ebab32 9170b3e b272a96 9ebab32 9170b3e 828b5b4 9170b3e b272a96 9170b3e b272a96 9ebab32 b272a96 9170b3e 9ebab32 b272a96 5d8a662 828b5b4 1dddb04 5d8a662 1dddb04 5d8a662 b272a96 5d8a662 b272a96 5d8a662 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import torch
import gradio as gr
from PIL import Image
from inference.flovd_demo import generate_video
from huggingface_hub import snapshot_download
# -----------------------------------
# Step 1: Download model checkpoints
# -----------------------------------
# Pull only the ``ckpt/`` tree of the weights dataset into the current
# working directory. HF_TOKEN is optional: when the variable is unset,
# ``None`` is handed to the hub client.
hf_token = os.environ.get("HF_TOKEN")

snapshot_download(
    repo_id="roll-ai/FloVD-weights",
    repo_type="dataset",
    allow_patterns="ckpt/**",
    local_dir="./",
    token=hf_token,
)
# -----------------------------------
# Step 2: Setup paths and config
# -----------------------------------
# BASE_DIR has to be the directory that receives the ``ckpt/`` tree from the
# snapshot_download(local_dir="./", allow_patterns="ckpt/**") call above.
# The previous value "app" made every checkpoint path resolve to
# app/ckpt/..., a location the download never writes to.
BASE_DIR = "."

# Checkpoint locations handed to generate_video().
FVSM_PATH = os.path.join(BASE_DIR, "ckpt/FVSM/FloVD_FVSM_Controlnet.pt")
OMSM_PATH = os.path.join(BASE_DIR, "ckpt/OMSM/")
DEPTH_CKPT_PATH = os.path.join(
    BASE_DIR, "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
)

# Output root passed to generate_video(); finished videos are looked up in
# GEN_VID_DIR by the helper functions below.
OUTPUT_PATH = os.path.join(BASE_DIR, "results")
GEN_VID_DIR = os.path.join(OUTPUT_PATH, "generated_videos")

# Fixed generation settings, forwarded unchanged to generate_video().
POSE_TYPE = "re10k"
CONTROLNET_GUIDANCE_END = 0.4
SPEED = 1.0
NUM_FRAMES = 81
FPS = 16
INFER_STEPS = 50

# Make sure the video directory exists before the UI lists its contents.
os.makedirs(GEN_VID_DIR, exist_ok=True)
# -----------------------------------
# Helper Functions
# -----------------------------------
def list_generated_videos(directory=None):
    """Return the sorted names of the ``.mp4`` files in the video directory.

    Args:
        directory: Directory to scan. When omitted, ``GEN_VID_DIR`` is used
            (looked up at call time), so zero-argument callers keep working.

    Returns:
        A sorted list of file names (not full paths) ending in ``.mp4``.
        An empty list is returned, after printing a warning, when the
        directory cannot be read.
    """
    target_dir = GEN_VID_DIR if directory is None else directory
    try:
        entries = os.listdir(target_dir)
    except OSError as e:
        # Missing or unreadable directory is the only expected failure here;
        # the UI should simply show an empty list in that case.
        print(f"[⚠️] Could not list contents: {e}")
        return []
    return sorted(name for name in entries if name.endswith(".mp4"))
def run_flovd(prompt, image, cam_pose_name):
    """Run FloVD video generation and report the outcome to the Gradio UI.

    Args:
        prompt: Text prompt forwarded to ``generate_video``.
        image: Input image object exposing ``save(path)`` (the UI supplies a
            PIL image). It is written to a temporary PNG under ``BASE_DIR``.
        cam_pose_name: Camera pose file name forwarded to ``generate_video``
            and used as the suffix of the expected output file name.

    Returns:
        A 4-tuple ``(video_path, status, dropdown_update, file_value)``:
        ``video_path`` is the generated file or ``None`` on failure,
        ``status`` is a human-readable message, ``dropdown_update`` refreshes
        the choices of the video selector, and ``file_value`` is always
        ``None``. Exceptions are caught and reported through ``status``.
    """

    def _outcome(video, status):
        # A bare list returned to a Dropdown is interpreted as its selected
        # value, so the choices would never refresh; send an update instead.
        return video, status, gr.update(choices=list_generated_videos()), None

    if image is None:
        # Without this guard the user only saw an AttributeError message.
        return _outcome(None, "❌ Please provide an input image.")

    try:
        print("\n----------------------------")
        print("🚀 Starting video generation")
        print("----------------------------")

        image_path = os.path.join(BASE_DIR, "temp_input.png")
        image.save(image_path)
        print(f"📸 Image saved at {image_path}")

        generate_video(
            prompt=prompt,
            fvsm_path=FVSM_PATH,
            omsm_path=OMSM_PATH,
            image_path=image_path,
            cam_pose_name=cam_pose_name,
            output_path=OUTPUT_PATH,
            controlnet_guidance_end=CONTROLNET_GUIDANCE_END,
            pose_type=POSE_TYPE,
            speed=SPEED,
            use_flow_integration=True,
            depth_ckpt_path=DEPTH_CKPT_PATH,
            dtype=torch.float16,
            num_frames=NUM_FRAMES,
            fps=FPS,
            num_inference_steps=INFER_STEPS,
        )

        # NOTE(review): this presumably mirrors the file naming used inside
        # generate_video -- confirm before changing the sanitisation rules.
        prompt_short = (
            prompt[:30].strip().replace(" ", "_").replace(".", "").replace(",", "")
        )
        video_filename = f"{prompt_short}_{cam_pose_name}.mp4"
        video_path = os.path.join(GEN_VID_DIR, video_filename)
        print(f"\n🔍 Looking for generated video at: {video_path}")

        if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
            print("✅ Video file found.")
            return _outcome(video_path, "✅ Video generated successfully.")

        print("❌ File missing or empty.")
        return _outcome(None, f"❌ File not found at: {video_path}")
    except Exception as e:
        # UI boundary: surface any failure as a status message instead of
        # letting the request crash.
        print(f"💥 Exception occurred: {e}")
        return _outcome(None, f"💥 Exception: {e}")
def get_video_file(video_name):
    """Resolve a selected video name to its path inside ``GEN_VID_DIR``.

    Args:
        video_name: File name chosen in the UI. May be ``None`` or empty when
            nothing is selected.

    Returns:
        The path of the file inside ``GEN_VID_DIR`` when it exists as a
        regular file; ``None`` for a missing selection, a non-string value,
        a name containing directory components, or a file that is absent.
    """
    # Nothing selected (or an unexpected value type): os.path.join would
    # raise on None, and "" would resolve to the directory itself.
    if not isinstance(video_name, str) or not video_name:
        return None

    # Only plain file names are accepted so that a crafted value such as
    # "../x" cannot point outside the generated-videos directory.
    if os.path.basename(video_name) != video_name:
        return None

    video_path = os.path.join(GEN_VID_DIR, video_name)
    return video_path if os.path.isfile(video_path) else None
# -----------------------------------
# Step 3: Launch Gradio Interface
# -----------------------------------
# Left column: generation inputs. Right column: the generated video, a
# status line, and a selector/button pair for downloading earlier results.
with gr.Blocks(title="FloVD + CogVideoX") as demo:
    gr.Markdown("## 🎥 FloVD - Camera Motion Guided Video Generation + Downloader")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            image = gr.Image(type="pil", label="Input Image")
            pose_file = gr.Textbox(label="Camera Pose Filename (e.g. abc.txt)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Markdown(label="Status / Logs")  # shows errors/success
            video_selector = gr.Dropdown(
                choices=list_generated_videos(),
                label="Select Video to Download",
            )
            download_btn = gr.Button("Download Selected")
            file_output = gr.File(label="Download Link")

    # The four outputs line up with the 4-tuple returned by run_flovd.
    run_btn.click(
        fn=run_flovd,
        inputs=[prompt, image, pose_file],
        outputs=[output_video, status_text, video_selector, file_output],
    )

    # Turn the selected file name into a downloadable file.
    download_btn.click(
        fn=get_video_file,
        inputs=[video_selector],
        outputs=[file_output],
    )

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)
|