File size: 4,621 Bytes
813d218
 
1dddb04
813d218
 
b93ca3e
813d218
1dddb04
d8bfbd8
1dddb04
 
813d218
b93ca3e
 
 
927afcf
b93ca3e
1dddb04
b93ca3e
1732a51
1dddb04
5d8a662
1dddb04
cb1cab1
b272a96
 
 
 
9170b3e
5d8a662
1dddb04
 
 
d8bfbd8
 
 
 
9170b3e
 
5d8a662
 
 
 
9170b3e
 
5d8a662
 
 
 
9170b3e
 
5d8a662
1dddb04
828b5b4
9ebab32
9170b3e
 
 
 
b272a96
9ebab32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9170b3e
828b5b4
9170b3e
b272a96
9170b3e
b272a96
9ebab32
b272a96
 
9170b3e
9ebab32
 
b272a96
5d8a662
 
 
 
 
 
 
828b5b4
1dddb04
5d8a662
1dddb04
 
5d8a662
 
 
 
 
 
 
 
 
 
 
b272a96
5d8a662
 
 
 
 
 
 
b272a96
5d8a662
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import torch
import gradio as gr
from PIL import Image
from inference.flovd_demo import generate_video
from huggingface_hub import snapshot_download

# -----------------------------------
# Step 1: Download model checkpoints
# -----------------------------------
# Optional Hugging Face token for gated/private repos; anonymous if unset.
hf_token = os.getenv("HF_TOKEN", None)

# Download only the "ckpt/**" subtree of the weights dataset at import time.
# NOTE(review): files land under ./ckpt (relative to the process cwd), but
# the paths below prepend BASE_DIR="app" — confirm the script is launched
# with cwd set to the "app" directory, otherwise the checkpoints won't be
# found at FVSM_PATH/OMSM_PATH/DEPTH_CKPT_PATH.
snapshot_download(
    repo_id="roll-ai/FloVD-weights",
    repo_type="dataset",
    local_dir="./",
    allow_patterns="ckpt/**",
    token=hf_token,
)

# -----------------------------------
# Step 2: Setup paths and config
# -----------------------------------
BASE_DIR = "app"
# Camera-motion (FVSM) ControlNet checkpoint file.
FVSM_PATH = os.path.join(BASE_DIR, "ckpt/FVSM/FloVD_FVSM_Controlnet.pt")
# Object-motion (OMSM) checkpoint directory (note: directory, not a file).
OMSM_PATH = os.path.join(BASE_DIR, "ckpt/OMSM/")
# Metric depth estimator used when use_flow_integration=True.
DEPTH_CKPT_PATH = os.path.join(BASE_DIR, "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth")
OUTPUT_PATH = os.path.join(BASE_DIR, "results")
# generate_video() is expected to write its .mp4 output into this directory.
GEN_VID_DIR = os.path.join(OUTPUT_PATH, "generated_videos")

# Generation hyperparameters forwarded verbatim to generate_video().
POSE_TYPE = "re10k"
CONTROLNET_GUIDANCE_END = 0.4  # fraction of denoising steps under ControlNet guidance
SPEED = 1.0
NUM_FRAMES = 81
FPS = 16
INFER_STEPS = 50

# Ensure the whole output tree (app/results/generated_videos) exists up front.
os.makedirs(GEN_VID_DIR, exist_ok=True)

# -----------------------------------
# Helper Functions
# -----------------------------------

def list_generated_videos():
    """Return the sorted filenames of all ``.mp4`` files in GEN_VID_DIR.

    Best-effort: any failure (e.g. the directory is missing) is logged
    and an empty list is returned instead of raising.
    """
    try:
        entries = os.listdir(GEN_VID_DIR)
    except Exception as e:
        print(f"[⚠️] Could not list contents: {str(e)}")
        return []
    return sorted(name for name in entries if name.endswith(".mp4"))

def run_flovd(prompt, image, cam_pose_name):
    """Generate a camera-motion-guided video for ``prompt`` and ``image``.

    Saves the uploaded PIL image to disk, invokes the FloVD pipeline, then
    verifies that the expected output file exists and is non-empty.

    Returns a 4-tuple of (video_path_or_None, status_message,
    video_filename_list, None), mapped positionally onto the Gradio outputs
    [output_video, status_text, video_selector, file_output].
    """
    try:
        print("\n----------------------------")
        print("πŸš€ Starting video generation")
        print("----------------------------")

        # Persist the uploaded image where the pipeline can read it.
        saved_image = os.path.join(BASE_DIR, "temp_input.png")
        image.save(saved_image)
        print(f"πŸ“Έ Image saved at {saved_image}")

        generate_video(
            prompt=prompt,
            fvsm_path=FVSM_PATH,
            omsm_path=OMSM_PATH,
            image_path=saved_image,
            cam_pose_name=cam_pose_name,
            output_path=OUTPUT_PATH,
            controlnet_guidance_end=CONTROLNET_GUIDANCE_END,
            pose_type=POSE_TYPE,
            speed=SPEED,
            use_flow_integration=True,
            depth_ckpt_path=DEPTH_CKPT_PATH,
            dtype=torch.float16,
            num_frames=NUM_FRAMES,
            fps=FPS,
            num_inference_steps=INFER_STEPS,
        )

        # Slugify the prompt the same way the pipeline names its output:
        # first 30 chars, stripped, spaces -> underscores, dots/commas removed.
        slug = prompt[:30].strip()
        for target, replacement in ((" ", "_"), (".", ""), (",", "")):
            slug = slug.replace(target, replacement)
        expected_path = os.path.join(GEN_VID_DIR, f"{slug}_{cam_pose_name}.mp4")

        print(f"\nπŸ“ Looking for generated video at: {expected_path}")
        # Guard: treat a missing or zero-byte file as a failed generation.
        if not (os.path.exists(expected_path) and os.path.getsize(expected_path) > 0):
            print("❌ File missing or empty.")
            return None, f"❌ File not found at: {expected_path}", list_generated_videos(), None

        print("βœ… Video file found.")
        return expected_path, "βœ… Video generated successfully.", list_generated_videos(), None

    except Exception as e:
        # Surface the error in the UI status text instead of crashing Gradio.
        print(f"πŸ”₯ Exception occurred: {str(e)}")
        return None, f"πŸ”₯ Exception: {str(e)}", list_generated_videos(), None

def get_video_file(video_name):
    """Resolve a dropdown selection to a downloadable file path.

    Args:
        video_name: Filename chosen in the Gradio dropdown. May be ``None``
            or ``""`` when nothing is selected.

    Returns:
        The path to the video inside GEN_VID_DIR if it exists, otherwise
        ``None`` (Gradio renders ``None`` as "no file").
    """
    # Guard: Gradio passes None when the dropdown has no selection;
    # os.path.join(GEN_VID_DIR, None) would raise TypeError.
    if not video_name:
        return None
    video_path = os.path.join(GEN_VID_DIR, video_name)
    return video_path if os.path.exists(video_path) else None

# -----------------------------------
# Step 3: Launch Gradio Interface
# -----------------------------------

# Build the UI: inputs on the left, results + download helpers on the right.
with gr.Blocks(title="FloVD + CogVideoX") as demo:
    gr.Markdown("## πŸŽ₯ FloVD - Camera Motion Guided Video Generation + Downloader")

    with gr.Row():
        # Left column: generation inputs.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            image = gr.Image(type="pil", label="Input Image")
            # NOTE(review): free-text filename; presumably must name a camera
            # pose file the pipeline can find — confirm against generate_video.
            pose_file = gr.Textbox(label="Camera Pose Filename (e.g. abc.txt)")
            run_btn = gr.Button("Generate Video")
        # Right column: playback, status, and download controls.
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            status_text = gr.Markdown(label="Status / Logs")  # shows errors/success
            # Choices are snapshotted once at launch; run_flovd's third return
            # value refreshes this component after each generation.
            # NOTE(review): in some Gradio versions returning a plain list to a
            # Dropdown output updates its value, not its choices — verify.
            video_selector = gr.Dropdown(choices=list_generated_videos(), label="Select Video to Download")
            download_btn = gr.Button("Download Selected")
            file_output = gr.File(label="Download Link")

    # run_flovd returns (video, status, video list, None), mapped positionally
    # onto the four output components below.
    run_btn.click(
        fn=run_flovd,
        inputs=[prompt, image, pose_file],
        outputs=[output_video, status_text, video_selector, file_output],
    )

    # Resolve the selected filename to a path for the File component.
    download_btn.click(
        fn=get_video_file,
        inputs=[video_selector],
        outputs=[file_output],
    )

# Bind on all interfaces (container-friendly) on the standard Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)