kelseye commited on
Commit
e550e76
·
verified ·
1 Parent(s): 032b609

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import oss2
4
+ import sys
5
+ import uuid
6
+ import shutil
7
+ import time
8
+ import gradio as gr
9
+ import requests
10
+
11
+ os.system("pip install dashscope")
12
+ import dashscope
13
+ from dashscope.utils.oss_utils import check_and_upload_local
14
+
15
+ DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
16
+ dashscope.api_key = DASHSCOPE_API_KEY
17
+
18
+
19
+ class WanS2VApp:
20
+ def __init__(self):
21
+ pass
22
+
23
+ def predict(
24
+ self,
25
+ ref_img,
26
+ audio,
27
+ resolution="480P",
28
+ style="speech",
29
+ ):
30
+ # Upload files to OSS if needed and get URLs
31
+ _, image_url = check_and_upload_local("wan2.2-s2v", ref_img, DASHSCOPE_API_KEY)
32
+ _, audio_url = check_and_upload_local("wan2.2-s2v", audio, DASHSCOPE_API_KEY)
33
+
34
+ # Prepare the request payload
35
+ payload = {
36
+ "model": "wan2.2-s2v",
37
+ "input": {
38
+ "image_url": image_url,
39
+ "audio_url": audio_url
40
+ },
41
+ "parameters": {
42
+ "style": style,
43
+ "resolution": resolution,
44
+ }
45
+ }
46
+
47
+ # Set up headers
48
+ headers = {
49
+ "X-DashScope-Async": "enable",
50
+ "X-DashScope-OssResourceResolve": "enable",
51
+ "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
52
+ "Content-Type": "application/json"
53
+ }
54
+
55
+ # Make the initial API request
56
+ url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis/"
57
+ response = requests.post(url, json=payload, headers=headers)
58
+
59
+ # Check if request was successful
60
+ if response.status_code != 200:
61
+ raise Exception(f"Initial request failed with status code {response.status_code}: {response.text}")
62
+
63
+ # Get the task ID from response
64
+ result = response.json()
65
+ task_id = result.get("output", {}).get("task_id")
66
+ if not task_id:
67
+ raise Exception("Failed to get task ID from response")
68
+
69
+ # Poll for results
70
+ get_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
71
+ headers = {
72
+ "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
73
+ "Content-Type": "application/json"
74
+ }
75
+
76
+ while True:
77
+ response = requests.get(get_url, headers=headers)
78
+ if response.status_code != 200:
79
+ raise Exception(f"Failed to get task status: {response.status_code}: {response.text}")
80
+
81
+ result = response.json()
82
+ print(result)
83
+ task_status = result.get("output", {}).get("task_status")
84
+
85
+ if task_status == "SUCCEEDED":
86
+ # Task completed successfully, return video URL
87
+ video_url = result["output"]["results"]["video_url"]
88
+ return video_url
89
+ elif task_status == "FAILED":
90
+ # Task failed, raise an exception with error message
91
+ error_msg = result.get("output", {}).get("message", "Unknown error")
92
+ raise Exception(f"Task failed: {error_msg}")
93
+ else:
94
+ # Task is still running, wait and retry
95
+ time.sleep(5) # Wait 5 seconds before polling again
96
+
97
+ def start_app():
98
+ import argparse
99
+ parser = argparse.ArgumentParser(description="Wan2.2-S2V 视频生成工具")
100
+ args = parser.parse_args()
101
+
102
+
103
+ app = WanS2VApp()
104
+ with gr.Blocks(title="Wan2.2-S2V 视频生成") as demo:
105
+ # gr.Markdown("# Wan2.2-S2V 视频生成工具")
106
+ gr.HTML("""
107
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
108
+ Wan2.2-S2V
109
+ </div>
110
+ """)
111
+ gr.Markdown("基于音频和参考图像生成视频")
112
+
113
+ with gr.Row():
114
+ with gr.Column():
115
+ ref_img = gr.Image(
116
+ label="Input image(输入图像)",
117
+ type="filepath",
118
+ sources=["upload"],
119
+ )
120
+
121
+ audio = gr.Audio(
122
+ label="Audio(音频文件)",
123
+ type="filepath",
124
+ sources=["upload"],
125
+ )
126
+
127
+ resolution = gr.Dropdown(
128
+ label="Resolution(分辨率)",
129
+ choices=["480P", "720P"],
130
+ value="480P",
131
+ info="Inference Resolution, default: 480P(推理分辨率,默认480P)"
132
+ )
133
+ run_button = gr.Button("Generate Video(生成视频)")
134
+
135
+ with gr.Column():
136
+ output_video = gr.Video(label="Output Video(输出视频)")
137
+
138
+
139
+ run_button.click(
140
+ fn=app.predict,
141
+ inputs=[
142
+ ref_img,
143
+ audio,
144
+ resolution,
145
+ ],
146
+ outputs=[output_video],
147
+ )
148
+
149
+
150
+ examples_dir = "examples"
151
+ if os.path.exists(examples_dir):
152
+ example_data = []
153
+
154
+ files_dict = {}
155
+ for file in os.listdir(examples_dir):
156
+ file_path = os.path.join(examples_dir, file)
157
+ name, ext = os.path.splitext(file)
158
+
159
+ if ext.lower() in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"]:
160
+ if name not in files_dict:
161
+ files_dict[name] = {}
162
+ files_dict[name]["image"] = file_path
163
+ elif ext.lower() in [".mp3", ".wav"]:
164
+ if name not in files_dict:
165
+ files_dict[name] = {}
166
+ files_dict[name]["audio"] = file_path
167
+
168
+ for name, files in files_dict.items():
169
+ if "image" in files and "audio" in files:
170
+ example_data.append([
171
+ files["image"],
172
+ files["audio"],
173
+ "480P"
174
+ ])
175
+
176
+ if example_data:
177
+ gr.Examples(
178
+ examples=example_data,
179
+ inputs=[ref_img, audio, resolution],
180
+ outputs=output_video,
181
+ fn=app.predict,
182
+ cache_examples=False,
183
+ )
184
+
185
+ demo.launch(
186
+ server_name="0.0.0.0",
187
+ server_port=7860
188
+ )
189
+
190
+
191
+ if __name__ == "__main__":
192
+ start_app()