File size: 3,576 Bytes
7e06d6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ded78ff
 
 
f4a2ddb
7e06d6b
ded78ff
fd66ee7
 
 
 
7e06d6b
fd66ee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e06d6b
ded78ff
 
 
 
 
 
 
 
 
 
 
 
 
7e06d6b
ded78ff
7e06d6b
ded78ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import gradio as gr
import torch
import subprocess
from PIL import Image
from pathlib import Path

# -------------------------------------------------------------------
# Checkpoint locations on the Hugging Face Hub
# -------------------------------------------------------------------

# Base URL of the ``ckpt`` folder inside the FloVD weights dataset repo.
HF_DATASET_URL = (
    "https://huggingface.co/datasets/roll-ai/FloVD-weights"
    "/resolve/main/ckpt"
)

# Every entry maps a relative checkpoint path to itself: the same relative
# path is used both under HF_DATASET_URL and under the local ./ckpt folder.
WEIGHT_FILES = {
    rel: rel
    for rel in (
        "FVSM/FloVD_FVSM_Controlnet.pt",
        "OMSM/selected_blocks.safetensors",
        "OMSM/pytorch_lora_weights.safetensors",
        "others/depth_anything_v2_metric_hypersim_vitb.pth",
    )
}

def download_weights():
    """Fetch any missing FloVD checkpoint files into the local ``ckpt`` folder.

    Each relative path in ``WEIGHT_FILES`` is downloaded with ``wget`` from
    ``HF_DATASET_URL`` to the same relative path under ``./ckpt``. Files
    already present locally are skipped.

    Each download is written to a temporary ``<name>.part`` file and renamed
    into place only after ``wget`` succeeds. ``wget -O`` creates or truncates
    its output file before transferring. Writing straight to the final path
    would therefore leave an empty or partial file after a failure, and the
    existence check would wrongly treat it as complete on the next run.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
        FileNotFoundError: if the ``wget`` executable cannot be found.
    """
    print("🔄 Downloading model weights...")
    for rel_path in WEIGHT_FILES.values():
        save_path = Path("ckpt") / rel_path
        if save_path.exists():
            print(f"✅ Already exists: {save_path}")
            continue

        save_path.parent.mkdir(parents=True, exist_ok=True)
        url = f"{HF_DATASET_URL}/{rel_path}"
        tmp_path = save_path.with_name(save_path.name + ".part")
        print(f"📥 Downloading {url} → {save_path}")
        try:
            subprocess.run(["wget", "-q", "-O", str(tmp_path), url], check=True)
            # Only a fully successful download ever appears at save_path.
            tmp_path.replace(save_path)
        finally:
            # No-op after a successful rename; removes the partial file on
            # failure or interruption so it is never mistaken for a checkpoint.
            tmp_path.unlink(missing_ok=True)

# Runs at import time, before `inference.flovd_demo` is imported and before
# the Gradio UI is built, so the checkpoint files are fetched first.
download_weights()
import gradio as gr
import torch
import os
from inference.flovd_demo import generate_video

def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Run FloVD generation for one UI request and return the video path.

    Args:
        prompt: Text prompt forwarded to ``generate_video``.
        image: PIL image from the ``gr.Image(type="pil")`` input; ``None``
            when the user has not uploaded one.
        pose_type: Camera pose source chosen in the UI ("manual" or "re10k").
        speed: Camera speed from the slider; converted to ``float``.
        use_flow_integration: Forwarded unchanged to ``generate_video``.
        cam_pose_name: Camera trajectory name; may be empty.

    Returns:
        Path of the generated ``.mp4`` file, consumed by the ``gr.Video``
        output component.

    Raises:
        gr.Error: if no image was provided, generation fails, or the expected
            output file does not exist. Gradio displays the message to the
            user. Returning an error string instead would be handed to
            ``gr.Video`` as a file path and fail there without the message.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    try:
        os.makedirs("input_images", exist_ok=True)
        # NOTE(review): fixed filename; concurrent requests would overwrite
        # each other's input if the queue ever allows parallel runs.
        image_path = "input_images/input_image.png"
        image.save(image_path)

        generate_video(
            prompt=prompt,
            image_path=image_path,
            fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
            omsm_path="./ckpt/OMSM",
            output_path="./outputs",
            num_frames=49,
            fps=16,
            width=None,
            height=None,
            seed=42,
            guidance_scale=6.0,
            dtype=torch.float16,
            controlnet_guidance_end=0.4,
            use_dynamic_cfg=False,
            pose_type=pose_type,
            speed=float(speed),
            use_flow_integration=use_flow_integration,
            cam_pose_name=cam_pose_name,
            depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
        )
    except Exception as e:
        print("🔥 Inference failed:")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"⚠️ Error during inference: {e}") from e

    # NOTE(review): this filename is reconstructed here rather than returned
    # by generate_video; it must match the naming used in inference.flovd_demo.
    safe_prompt = prompt[:30].strip().replace(' ', '_')
    video_path = f"./outputs/generated_videos/{safe_prompt}_{cam_pose_name or 'default'}.mp4"
    if not os.path.isfile(video_path):
        raise gr.Error(f"⚠️ Generated video not found at {video_path}")
    return video_path

# Gradio UI: generation inputs in the left column, the resulting video in the
# right column.
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    # The order of `inputs` must match run_inference's positional parameters.
    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=output_video,
    )

demo.launch()