File size: 3,408 Bytes
7e06d6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ded78ff
 
 
 
7e06d6b
ded78ff
 
 
 
7e06d6b
ded78ff
7e06d6b
ded78ff
 
 
 
7e06d6b
 
ded78ff
 
 
 
 
7e06d6b
ded78ff
 
 
 
 
 
7e06d6b
ded78ff
7e06d6b
ded78ff
 
 
 
 
 
 
 
 
 
 
 
 
7e06d6b
ded78ff
7e06d6b
ded78ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import gradio as gr
import torch
import subprocess
from PIL import Image
from pathlib import Path

# =========================================
# 1. Hugging Face checkpoint location and files
# =========================================

# Base URL of the ckpt/ folder in the roll-ai/FloVD-weights dataset repo;
# each relative checkpoint path is appended to it to form a download URL.
HF_DATASET_URL = "https://huggingface.co/datasets/roll-ai/FloVD-weights/resolve/main/ckpt"

# Checkpoint files to fetch. Every key maps to the identical relative path,
# which is used both as the path inside the dataset and under local ckpt/.
WEIGHT_FILES = {
    rel_path: rel_path
    for rel_path in (
        "FVSM/FloVD_FVSM_Controlnet.pt",
        "OMSM/selected_blocks.safetensors",
        "OMSM/pytorch_lora_weights.safetensors",
        "others/depth_anything_v2_metric_hypersim_vitb.pth",
    )
}

def download_weights(root="ckpt", files=None, base_url=None):
    """Fetch any missing FloVD checkpoint files with ``wget``.

    Each relative path is downloaded from ``{base_url}/{rel_path}`` and
    stored at ``{root}/{rel_path}``; parent directories are created as
    needed. Files that already exist locally are skipped, so calling this
    on every start-up only downloads what is missing.

    Args:
        root: Local directory the checkpoints are written under.
            Defaults to ``"ckpt"``.
        files: Iterable of relative checkpoint paths. Defaults to the values
            of ``WEIGHT_FILES``.
        base_url: URL prefix to download from. Defaults to
            ``HF_DATASET_URL``.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits non-zero. In that
            case no file (partial or empty) is left at the destination path.
    """
    if files is None:
        files = WEIGHT_FILES.values()
    if base_url is None:
        base_url = HF_DATASET_URL

    print("🔄 Downloading model weights...")
    for rel_path in files:
        save_path = Path(root) / rel_path
        if save_path.exists():
            print(f"✅ Already exists: {save_path}")
            continue

        save_path.parent.mkdir(parents=True, exist_ok=True)
        url = f"{base_url}/{rel_path}"
        print(f"📥 Downloading {url} → {save_path}")

        # `wget -O` opens (and truncates) its output file before fetching, so
        # a failed or interrupted download would leave a truncated file that
        # the exists() check above would then skip forever. Download into a
        # sibling ".part" file and move it into place only on success.
        tmp_path = save_path.with_name(save_path.name + ".part")
        try:
            subprocess.run(["wget", "-q", "-O", str(tmp_path), url], check=True)
        except BaseException:  # includes KeyboardInterrupt; always re-raised
            tmp_path.unlink(missing_ok=True)
            raise
        tmp_path.replace(save_path)

# Fetch any missing checkpoints at import time, before the inference module
# below is imported.
# NOTE(review): presumably inference_script (or something it imports) expects
# the files under ./ckpt to exist already — confirm before reordering.
download_weights()
# NOTE(review): gradio, torch and os are already imported at the top of the
# file; these re-imports are redundant (Python returns the cached module) and
# could be removed.
import gradio as gr
import torch
import os
from inference_script import generate_video  # project-local module; run_inference calls generate_video

def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Gradio callback: generate one FloVD video from an image and a prompt.

    Saves the uploaded image to ``input_images/input_image.png``, calls
    ``generate_video`` with fixed settings (49 frames at 16 fps, seed 42,
    guidance scale 6.0, float16, checkpoints under ``./ckpt``), and returns
    the path of the resulting video for the ``gr.Video`` output.

    Args:
        prompt: Text prompt describing the video.
        image: Input PIL image, or ``None`` when nothing was uploaded.
        pose_type: Camera-pose source; the UI offers ``"manual"`` or
            ``"re10k"``.
        speed: Camera speed from the slider; converted to ``float``.
        use_flow_integration: Forwarded unchanged to ``generate_video``.
        cam_pose_name: Camera-trajectory name; when empty, ``"default"`` is
            used in the expected output file name.

    Returns:
        Path of the generated ``.mp4`` file.

    Raises:
        gr.Error: if no image was uploaded, or if no video exists at the
            expected output path after generation.
    """
    if image is None:
        # Without this guard the click fails with an opaque AttributeError
        # from `None.save`; gr.Error shows a readable message in the UI.
        raise gr.Error("Please upload an input image.")

    os.makedirs("input_images", exist_ok=True)
    image_path = "input_images/input_image.png"
    image.save(image_path)

    generate_video(
        prompt=prompt,
        image_path=image_path,
        fvsm_path="./ckpt/FVSM",  # Expected to be downloaded from HF dataset
        omsm_path="./ckpt/OMSM",  # Expected to be downloaded from HF dataset
        output_path="./outputs",
        num_frames=49,
        fps=16,
        width=None,
        height=None,
        seed=42,
        guidance_scale=6.0,
        dtype=torch.float16,
        controlnet_guidance_end=0.4,
        use_dynamic_cfg=False,
        pose_type=pose_type,
        speed=float(speed),
        use_flow_integration=use_flow_integration,
        cam_pose_name=cam_pose_name,
        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
    )

    # generate_video's return value is not used; the output location is
    # reconstructed from an assumed naming scheme (first 30 prompt characters
    # with spaces -> underscores, then the trajectory name).
    # NOTE(review): confirm this matches inference_script's file naming.
    stem = prompt[:30].strip().replace(' ', '_')
    video_path = f"./outputs/generated_videos/{stem}_{cam_pose_name or 'default'}.mp4"
    if not os.path.exists(video_path):
        # Fail with a clear message instead of handing Gradio a missing file.
        raise gr.Error(f"Generation finished but no video was found at {video_path}")
    return video_path

# Gradio UI: inputs in the left column, generated video in the right column.
# The heading literal previously held mis-encoded bytes ("πŸŽ₯", i.e. the UTF-8
# bytes of the camera emoji decoded as cp1253); it is restored to the emoji.
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            # type="pil" makes the callback receive a PIL image (or None).
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    # Input order must match run_inference's positional parameters.
    submit.click(fn=run_inference, inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name], outputs=output_video)

# Starts the server when this module runs (no __main__ guard).
demo.launch()