roll-ai committed on
Commit
813d218
·
verified ·
1 Parent(s): bbccd3b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ from inference.flovd_demo import generate_video
6
+ import requests
7
+ import shutil
8
+
9
+ # --------- SETUP CODE: Download models if not found ---------
10
# Local filesystem locations used by the setup code and the UI callback.
# All paths are relative to the current working directory.
FVSM_PATH = "./ckpt/FVSM/FloVD_FVSM_Controlnet.pt"  # FVSM model checkpoint file
OMSM_PATH = "./ckpt/OMSM"  # directory that receives the OMSM weights
POSE_DIR = "./assets/manual_poses"  # directory holding camera-pose files
EXAMPLE_POSE = os.path.join(POSE_DIR, "example.txt")  # example pose inside POSE_DIR
14
+
15
def download_if_missing():
    """Create the working directories and fetch any missing model assets.

    Ensures the checkpoint, pose and output directories exist, then:

    * downloads the FVSM checkpoint to ``FVSM_PATH`` if that file is absent,
    * clones the OMSM weights repository into ``OMSM_PATH`` if that
      directory is empty,
    * downloads the example camera pose to ``EXAMPLE_POSE`` if absent.

    Raises:
        requests.HTTPError: If a download responds with an error status.
        requests.RequestException: On connection failures or timeouts.
        subprocess.CalledProcessError: If ``git clone`` exits non-zero.
        OSError: If a directory or file cannot be created or written.
    """
    import subprocess  # local import: only needed for the git clone below

    os.makedirs(os.path.dirname(FVSM_PATH), exist_ok=True)
    os.makedirs(OMSM_PATH, exist_ok=True)
    os.makedirs(POSE_DIR, exist_ok=True)
    os.makedirs("output/generated_videos", exist_ok=True)

    # Download FVSM model
    if not os.path.exists(FVSM_PATH):
        print("Downloading FVSM model...")
        _fetch_to_file(
            "https://huggingface.co/datasets/mutqa/FloVD-HF-Assets/resolve/main/FloVD_FVSM_Controlnet.pt",
            FVSM_PATH,
        )

    # Download OMSM weights (as folder with .safetensors inside)
    if not os.listdir(OMSM_PATH):
        print("Cloning OMSM weights...")
        # Argument list (no shell) and check=True so a failed clone is
        # reported here instead of surfacing later as missing weights.
        subprocess.run(
            [
                "git",
                "clone",
                "https://huggingface.co/datasets/mutqa/FloVD-HF-OMSM",
                OMSM_PATH,
            ],
            check=True,
        )

    # Download example camera pose
    if not os.path.exists(EXAMPLE_POSE):
        print("Downloading example pose...")
        _fetch_to_file(
            "https://huggingface.co/datasets/mutqa/FloVD-HF-Assets/resolve/main/example.txt",
            EXAMPLE_POSE,
        )


def _fetch_to_file(url, dest_path, timeout=30):
    """Stream ``url`` into ``dest_path``, publishing the file only when complete.

    The payload is written to ``dest_path + ".part"`` and renamed into place
    afterwards, so an interrupted download never leaves a truncated file that
    a later ``os.path.exists`` check would mistake for a finished one.
    """
    partial_path = dest_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail on 4xx/5xx rather than saving an error page as the asset.
        response.raise_for_status()
        with open(partial_path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                out_file.write(chunk)
    os.replace(partial_path, dest_path)
41
+
42
+ # --------- UI Function ---------
43
def run_flovd(prompt, image, cam_pose_name):
    """Run FloVD video generation for one UI request.

    Args:
        prompt: Text prompt, forwarded to ``generate_video``.
        image: Input image object exposing ``save(path)`` (the interface is
            configured to pass a PIL image), or ``None`` when nothing was
            uploaded.
        cam_pose_name: Camera pose file name, forwarded to
            ``generate_video`` as ``cam_pose_name``.

    Returns:
        Path of the video file handed back to the UI.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Validate before the (potentially large) asset download: without an
    # image the request cannot succeed and would otherwise die with an
    # AttributeError on ``image.save``.
    if image is None:
        raise gr.Error("Please upload an input image before generating a video.")

    download_if_missing()

    image_path = "./temp_input.png"
    image.save(image_path)

    generate_video(
        prompt=prompt,
        fvsm_path=FVSM_PATH,
        omsm_path=OMSM_PATH,
        image_path=image_path,
        cam_pose_name=cam_pose_name,
        output_path="./output/",
        dtype=torch.float16,
    )
    # NOTE(review): this file name is hard-coded; confirm that
    # generate_video really writes its result to this exact path.
    return "./output/generated_videos/your_video.mp4"
58
+
59
+ # --------- Launch Gradio ---------
60
# Build the UI components first (same construction order as before), then
# wire them into a single Interface around run_flovd.
_prompt_box = gr.Textbox(label="Prompt")
_image_input = gr.Image(type="pil", label="Input Image")
_pose_name_box = gr.Textbox(label="Camera Pose File Name", value="example.txt")
_video_output = gr.Video(label="Generated Video")

iface = gr.Interface(
    fn=run_flovd,
    inputs=[_prompt_box, _image_input, _pose_name_box],
    outputs=_video_output,
    title="FloVD - Optical Flow Video Diffusion with Camera Motion",
)

# Start the web server as soon as the module is executed.
iface.launch()