File size: 6,538 Bytes
e550e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e519db
e550e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# app.py
import os
import oss2
import sys
import uuid
import shutil
import time
import gradio as gr
import requests

# Install dashscope only when it is missing. Running pip via
# ``sys.executable -m pip`` installs into the interpreter that is actually
# running this app, and ``check_call`` surfaces install failures instead of
# silently continuing (the previous ``os.system`` call ignored its exit code
# and ran pip on every start-up).
try:
    import dashscope
except ImportError:
    import importlib
    import subprocess

    subprocess.check_call([sys.executable, "-m", "pip", "install", "dashscope"])
    # Freshly installed packages may be invisible to cached path finders.
    importlib.invalidate_caches()
    import dashscope
from dashscope.utils.oss_utils import check_and_upload_local

# API key read from the environment (None when unset) and registered with the
# dashscope SDK; the same value is also sent in raw HTTP requests below.
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
dashscope.api_key = DASHSCOPE_API_KEY


class WanS2VApp:
    """Thin client for the DashScope ``wan2.2-s2v`` speech-to-video model.

    ``predict`` uploads a reference image and an audio clip, submits an
    asynchronous video-synthesis task, polls the task until it reaches a
    terminal state, and returns the URL of the generated video.
    """

    MODEL = "wan2.2-s2v"
    SUBMIT_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis/"
    TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"

    def __init__(self, poll_interval=5.0, max_wait=1800.0, request_timeout=60.0):
        """Configure polling behaviour.

        Args:
            poll_interval: Seconds to sleep between task-status polls.
            max_wait: Seconds to keep polling before giving up with
                ``TimeoutError``; ``None`` polls indefinitely.
            request_timeout: Timeout in seconds for each individual HTTP
                request, so a stalled connection cannot hang the worker.
        """
        self.poll_interval = poll_interval
        self.max_wait = max_wait
        self.request_timeout = request_timeout

    def predict(
        self,
        ref_img,
        audio,
        resolution="480P",
        style="speech",
    ):
        """Generate a video from a reference image driven by an audio clip.

        Args:
            ref_img: Path of the reference image (the UI passes a filepath);
                uploaded via ``check_and_upload_local`` when needed.
            audio: Path of the driving audio file, handled the same way.
            resolution: Output resolution sent to the service ("480P"/"720P").
            style: Generation style sent to the service.

        Returns:
            The URL of the generated video.

        Raises:
            ValueError: If the image or the audio input is missing.
            RuntimeError: If the API key is unset, a request fails, or the
                task ends in a non-success state.
            TimeoutError: If the task does not finish within ``max_wait``.
        """
        # Gradio passes None when the user clicks without uploading a file.
        if not ref_img:
            raise ValueError("Please upload a reference image.")
        if not audio:
            raise ValueError("Please upload an audio file.")
        if not DASHSCOPE_API_KEY:
            raise RuntimeError("DASHSCOPE_API_KEY environment variable is not set")

        # Upload local files to OSS if needed and get URLs the service can fetch.
        _, image_url = check_and_upload_local(self.MODEL, ref_img, DASHSCOPE_API_KEY)
        _, audio_url = check_and_upload_local(self.MODEL, audio, DASHSCOPE_API_KEY)

        task_id = self._submit_task(image_url, audio_url, resolution, style)
        return self._wait_for_result(task_id)

    def _submit_task(self, image_url, audio_url, resolution, style):
        """Submit the asynchronous synthesis request and return its task ID."""
        payload = {
            "model": self.MODEL,
            "input": {
                "image_url": image_url,
                "audio_url": audio_url,
            },
            "parameters": {
                "style": style,
                "resolution": resolution,
            },
        }
        headers = {
            # Async mode: the response carries only a task ID to poll.
            "X-DashScope-Async": "enable",
            # Presumably lets the service resolve the OSS URLs returned by
            # check_and_upload_local — confirm against DashScope docs.
            "X-DashScope-OssResourceResolve": "enable",
            "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self.SUBMIT_URL, json=payload, headers=headers, timeout=self.request_timeout
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Initial request failed with status code {response.status_code}: {response.text}"
            )

        task_id = (response.json().get("output") or {}).get("task_id")
        if not task_id:
            raise RuntimeError("Failed to get task ID from response")
        return task_id

    def _wait_for_result(self, task_id):
        """Poll the task until it finishes and return the generated video URL."""
        url = self.TASK_URL.format(task_id=task_id)
        headers = {
            "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
            "Content-Type": "application/json",
        }
        deadline = None if self.max_wait is None else time.monotonic() + self.max_wait

        while True:
            response = requests.get(url, headers=headers, timeout=self.request_timeout)
            if response.status_code != 200:
                raise RuntimeError(
                    f"Failed to get task status: {response.status_code}: {response.text}"
                )

            result = response.json()
            print(result)
            video_url = self._interpret_status(result)
            if video_url is not None:
                return video_url

            # Without a deadline, a task stuck in PENDING/RUNNING would block forever.
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Task {task_id} did not finish within {self.max_wait} seconds"
                )
            time.sleep(self.poll_interval)

    @staticmethod
    def _interpret_status(result):
        """Map one task-status response to an outcome.

        Returns the video URL when the task succeeded, ``None`` while it is
        still in progress (or the status is absent), and raises
        ``RuntimeError`` for terminal failure states.
        """
        output = result.get("output") or {}
        status = output.get("task_status")

        if status == "SUCCEEDED":
            return output["results"]["video_url"]
        if status == "FAILED":
            raise RuntimeError(f"Task failed: {output.get('message', 'Unknown error')}")
        # CANCELED and UNKNOWN (e.g. expired/unknown task ID) are also terminal
        # per the DashScope task API; treating them as "still running" would
        # poll forever.
        if status in ("CANCELED", "UNKNOWN"):
            raise RuntimeError(
                f"Task ended with status {status}: {output.get('message', 'no details')}"
            )
        return None

def start_app():
    import argparse
    parser = argparse.ArgumentParser(description="Wan2.2-S2V 视频生成工具")
    args = parser.parse_args()
    

    app = WanS2VApp()
    with gr.Blocks(title="Wan2.2-S2V 视频生成") as demo:
        # gr.Markdown("# Wan2.2-S2V 视频生成工具")
        gr.HTML("""
            <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
                Wan2.2-S2V
            </div>
            """)
        gr.Markdown("Generate video from audio and a reference image. This app uses a distilled model; for the full version, deploy [the open-source model](https://huggingface.co/Wan-AI/Wan2.2-S2V-14B).")

        with gr.Row():
            with gr.Column():    
                ref_img = gr.Image(
                    label="Input image(输入图像)",
                    type="filepath",
                    sources=["upload"],
                )
                
                audio = gr.Audio(
                    label="Audio(音频文件)",
                    type="filepath",
                    sources=["upload"],
                )

                resolution = gr.Dropdown(
                    label="Resolution(分辨率)",
                    choices=["480P", "720P"],
                    value="480P",
                    info="Inference Resolution, default: 480P(推理分辨率,默认480P)"
                )
                run_button = gr.Button("Generate Video(生成视频)")

            with gr.Column():
                output_video = gr.Video(label="Output Video(输出视频)")
        

        run_button.click(
            fn=app.predict,
            inputs=[
                ref_img,
                audio,
                resolution,
            ],
            outputs=[output_video],
        )


        examples_dir = "examples"
        if os.path.exists(examples_dir):
            example_data = []
            
            files_dict = {}
            for file in os.listdir(examples_dir):
                file_path = os.path.join(examples_dir, file)
                name, ext = os.path.splitext(file)
                
                if ext.lower() in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"]:
                    if name not in files_dict:
                        files_dict[name] = {}
                    files_dict[name]["image"] = file_path
                elif ext.lower() in [".mp3", ".wav"]:
                    if name not in files_dict:
                        files_dict[name] = {}
                    files_dict[name]["audio"] = file_path
            
            for name, files in files_dict.items():
                if "image" in files and "audio" in files:
                    example_data.append([
                        files["image"], 
                        files["audio"], 
                        "480P"
                    ])

            if example_data:
                gr.Examples(
                    examples=example_data,
                    inputs=[ref_img, audio, resolution],
                    outputs=output_video,
                    fn=app.predict,
                    cache_examples=False,
                )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )


# Launch the Gradio app only when this file is executed directly, not on import.
if __name__ == "__main__":
    start_app()