roll-ai commited on
Commit
b93ca3e
·
verified ·
1 Parent(s): a4e9b00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -55
app.py CHANGED
@@ -5,66 +5,78 @@ from PIL import Image
5
  from inference.flovd_demo import generate_video
6
  import requests
7
  import shutil
 
 
8
 
9
- # --------- SETUP CODE: Download models if not found ---------
10
- FVSM_PATH = "./ckpt/FVSM/FloVD_FVSM_Controlnet.pt"
11
- OMSM_PATH = "./ckpt/OMSM"
12
- POSE_DIR = "./assets/manual_poses"
13
- EXAMPLE_POSE = os.path.join(POSE_DIR, "example.txt")
14
 
15
- def download_if_missing():
16
- os.makedirs("ckpt/FVSM", exist_ok=True)
17
- os.makedirs("ckpt/OMSM", exist_ok=True)
18
- os.makedirs("assets/manual_poses", exist_ok=True)
19
- os.makedirs("output/generated_videos", exist_ok=True)
 
 
20
 
21
- # Download FVSM model
22
- if not os.path.exists(FVSM_PATH):
23
- print("Downloading FVSM model...")
24
- url = "https://huggingface.co/datasets/mutqa/FloVD-HF-Assets/resolve/main/FloVD_FVSM_Controlnet.pt"
25
- r = requests.get(url, stream=True)
26
- with open(FVSM_PATH, 'wb') as f:
27
- shutil.copyfileobj(r.raw, f)
28
 
29
- # Download OMSM weights (as folder with .safetensors inside)
30
- if not os.listdir(OMSM_PATH):
31
- print("Cloning OMSM weights...")
32
- os.system("git clone https://huggingface.co/datasets/mutqa/FloVD-HF-OMSM ckpt/OMSM")
 
33
 
34
- # Download example camera pose
35
- if not os.path.exists(EXAMPLE_POSE):
36
- print("Downloading example pose...")
37
- url = "https://huggingface.co/datasets/mutqa/FloVD-HF-Assets/resolve/main/example.txt"
38
- r = requests.get(url)
39
- with open(EXAMPLE_POSE, "w") as f:
40
- f.write(r.text)
41
 
42
- # --------- UI Function ---------
43
- def run_flovd(prompt, image, cam_pose_name):
44
- download_if_missing()
45
- image_path = "./temp_input.png"
46
- image.save(image_path)
47
 
48
- generate_video(
49
- prompt=prompt,
50
- fvsm_path=FVSM_PATH,
51
- omsm_path=OMSM_PATH,
52
- image_path=image_path,
53
- cam_pose_name=cam_pose_name,
54
- output_path="./output/",
55
- dtype=torch.float16,
56
- )
57
- return "./output/generated_videos/your_video.mp4"
58
 
59
- # --------- Launch Gradio ---------
60
- iface = gr.Interface(
61
- fn=run_flovd,
62
- inputs=[
63
- gr.Textbox(label="Prompt"),
64
- gr.Image(type="pil", label="Input Image"),
65
- gr.Textbox(label="Camera Pose File Name", value="example.txt"),
66
- ],
67
- outputs=gr.Video(label="Generated Video"),
68
- title="FloVD - Optical Flow Video Diffusion with Camera Motion",
69
- )
70
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from inference.flovd_demo import generate_video
6
  import requests
7
  import shutil
8
+ import os
9
+ from huggingface_hub import snapshot_download
10
 
11
+ hf_token = os.getenv("HF_TOKEN")
 
 
 
 
12
 
13
+ snapshot_download(
14
+ repo_id="roll-ai/FloVD-weights",
15
+ repo_type="dataset",
16
+ local_dir="./ckpt",
17
+ allow_patterns="ckpt/**",
18
+ token=hf_token
19
+ )
20
 
21
+ # # --------- SETUP CODE: Download models if not found ---------
22
+ # FVSM_PATH = "./ckpt/FVSM/FloVD_FVSM_Controlnet.pt"
23
+ # OMSM_PATH = "./ckpt/OMSM"
24
+ # POSE_DIR = "./assets/manual_poses"
25
+ # EXAMPLE_POSE = os.path.join(POSE_DIR, "example.txt")
 
 
26
 
27
+ # def download_if_missing():
28
+ # os.makedirs("ckpt/FVSM", exist_ok=True)
29
+ # os.makedirs("ckpt/OMSM", exist_ok=True)
30
+ # os.makedirs("assets/manual_poses", exist_ok=True)
31
+ # os.makedirs("output/generated_videos", exist_ok=True)
32
 
33
+ # # Download FVSM model
34
+ # if not os.path.exists(FVSM_PATH):
35
+ # print("Downloading FVSM model...")
36
+ # url = "https://huggingface.co/datasets/mutqa/FloVD-HF-Assets/resolve/main/FloVD_FVSM_Controlnet.pt"
37
+ # r = requests.get(url, stream=True)
38
+ # with open(FVSM_PATH, 'wb') as f:
39
+ # shutil.copyfileobj(r.raw, f)
40
 
41
+ # # Download OMSM weights (as folder with .safetensors inside)
42
+ # if not os.listdir(OMSM_PATH):
43
+ # print("Cloning OMSM weights...")
44
+ # os.system("git clone https://huggingface.co/datasets/mutqa/FloVD-HF-OMSM ckpt/OMSM")
 
45
 
46
+ # # Download example camera pose
47
+ # if not os.path.exists(EXAMPLE_POSE):
48
+ # print("Downloading example pose...")
49
+ # url = "https://huggingface.co/datasets/mutqa/FloVD-HF-Assets/resolve/main/example.txt"
50
+ # r = requests.get(url)
51
+ # with open(EXAMPLE_POSE, "w") as f:
52
+ # f.write(r.text)
 
 
 
53
 
54
+ # # --------- UI Function ---------
55
+ # def run_flovd(prompt, image, cam_pose_name):
56
+ # download_if_missing()
57
+ # image_path = "./temp_input.png"
58
+ # image.save(image_path)
59
+
60
+ # generate_video(
61
+ # prompt=prompt,
62
+ # fvsm_path=FVSM_PATH,
63
+ # omsm_path=OMSM_PATH,
64
+ # image_path=image_path,
65
+ # cam_pose_name=cam_pose_name,
66
+ # output_path="./output/",
67
+ # dtype=torch.float16,
68
+ # )
69
+ # return "./output/generated_videos/your_video.mp4"
70
+
71
+ # # --------- Launch Gradio ---------
72
+ # iface = gr.Interface(
73
+ # fn=run_flovd,
74
+ # inputs=[
75
+ # gr.Textbox(label="Prompt"),
76
+ # gr.Image(type="pil", label="Input Image"),
77
+ # gr.Textbox(label="Camera Pose File Name", value="example.txt"),
78
+ # ],
79
+ # outputs=gr.Video(label="Generated Video"),
80
+ # title="FloVD - Optical Flow Video Diffusion with Camera Motion",
81
+ # )
82
+ # iface.launch(server_name="0.0.0.0", server_port=7860)