File size: 6,036 Bytes
e9fa53a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import asyncio
import json
import mimetypes
import os
import re
import uuid
from typing import Tuple

import aiohttp
import gradio as gr
from PIL import Image


def get_ext(url):
    """Return the file extension of the resource addressed by ``url``.

    The extension is read from the last segment of the URL *path*, so the
    host name, query string and fragment never influence the result.

    Args:
        url: URL of the resource, with or without a query string.

    Returns:
        The text after the final dot of the last path segment, without the
        dot and with its case preserved (``"mp3"`` for ``.../a.mp3?sig=1``),
        or ``""`` when that segment contains no dot.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import urlparse

    # A regex anchored on "?" (the previous approach) raised IndexError for
    # URLs without a query string; parsing the URL handles both forms.
    last_segment = urlparse(url).path.rsplit("/", 1)[-1]
    _, dot, ext = last_segment.rpartition(".")
    return ext if dot else ""


async def download_file(url, local_filename):
    """Download ``url`` to disk and return the absolute path of the saved file.

    The file is written as ``<local_filename>.<ext>``, where ``ext`` is what
    ``get_ext`` reports for ``url``. A ``webp`` download is re-encoded as JPEG
    into ``<local_filename>.jpg`` and the intermediate ``.webp`` file is
    removed.

    Args:
        url: Address of the file to fetch.
        local_filename: Destination path without an extension.

    Returns:
        Absolute path of the file left on disk (the ``.jpg`` for WebP input).

    Raises:
        RuntimeError: If the server responds with a status other than 200.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Fail on the HTTP status first so a bad response is reported as
            # a download failure rather than as an extension-parsing error.
            if response.status != 200:
                raise RuntimeError(f"{url} download failed")
            content = await response.read()

    ext = get_ext(url)
    downloaded_path = os.path.abspath(f"{local_filename}.{ext}")
    with open(downloaded_path, "wb") as f:
        f.write(content)

    if ext != "webp":
        return downloaded_path

    # Return an absolute path here too, matching the non-WebP branch.
    jpg_path = os.path.abspath(f"{local_filename}.jpg")
    # The context manager closes the image's file handle before the source
    # file is deleted; an open handle would leak and blocks removal on Windows.
    with Image.open(downloaded_path) as im:
        im.convert("RGB").save(jpg_path, "jpeg")
    os.remove(downloaded_path)
    return jpg_path


class HedraClient:
    """Async client that uploads media, submits a generation job and polls it.

    Every request carries the API key in the ``X-API-KEY`` header. Uploads go
    to ``/v1/audio`` and ``/v1/portrait``, jobs are created through
    ``/v1/characters`` and polled through ``/v1/projects/{task_id}``.
    """

    # Seconds to wait between two status polls in ``get_response``.
    _POLL_INTERVAL_SECONDS = 4

    def __init__(self):
        """Set up endpoints, credentials, the request timeout and ``temp/``."""
        self._base_url = "https://mercury.dev.dream-ai.com/api"
        self._check_task_url = "https://mercury.dev.dream-ai.com/api/v1/projects/{task_id}"
        # SECURITY NOTE(review): a credential is committed in source. The
        # HEDRA_API_KEY environment variable now takes precedence; the literal
        # is kept only so existing deployments keep working and should be
        # rotated and removed.
        self._key = os.environ.get(
            "HEDRA_API_KEY",
            "sk_hedra-TxkxBe8htuAuGXwoPYgjHhYpwcQ3gdFmcGdRTLksRKUcSQEpm7VCNzSNj2680fZC",
        )
        self.timeout = aiohttp.ClientTimeout(total=10)
        # Scratch directory for files downloaded before upload.
        os.makedirs("temp", exist_ok=True)

    async def _upload(self, source_url, endpoint, timeout=None):
        """Download ``source_url`` and POST it as multipart field ``file``.

        Args:
            source_url: URL of the media file to fetch and forward.
            endpoint: Path below the base URL, e.g. ``"v1/audio"``.
            timeout: Optional ``aiohttp.ClientTimeout`` for the POST; when
                omitted the session's default timeout applies.

        Returns:
            The decoded JSON body of the upload response.
        """
        headers = {"X-API-KEY": self._key}
        local_path = await download_file(source_url, f"temp/{uuid.uuid4()}")
        try:
            # Only pass ``timeout`` when one was requested: an explicit None
            # would disable the session default instead of using it.
            request_kwargs = {} if timeout is None else {"timeout": timeout}
            async with aiohttp.ClientSession() as session:
                # Open the file exactly once and close it deterministically so
                # the cleanup below can always delete it.
                with open(local_path, "rb") as fh:
                    form = aiohttp.FormData()
                    form.add_field("file", fh)
                    async with session.post(
                        f"{self._base_url}/{endpoint}", headers=headers, data=form, **request_kwargs
                    ) as resp:
                        return await resp.json()
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)

    async def post_audio(self, audio_url):
        """Upload the audio file at ``audio_url``; return the response JSON."""
        return await self._upload(audio_url, "v1/audio")

    async def post_image(self, image_url):
        """Upload the image at ``image_url`` (10 s timeout); return the response JSON."""
        return await self._upload(image_url, "v1/portrait", timeout=self.timeout)

    async def submit_task(self, audio_url: str, image_url: str, aspect_ratio: str) -> Tuple[str, str]:
        """Upload both media files concurrently and create a generation job.

        Args:
            audio_url: URL of the voice track.
            image_url: URL of the avatar image.
            aspect_ratio: Value sent as ``aspectRatio`` in the job payload.

        Returns:
            ``(task_id, request_id)`` taken from the ``jobId`` and
            ``request_id`` fields of the response; ``request_id`` is ``None``
            when the response does not include it.

        Raises:
            RuntimeError: If an upload response has no ``url`` or the job
                response has no ``jobId``.
        """
        headers = {"X-API-KEY": self._key}

        audio_result, image_result = await asyncio.gather(
            self.post_audio(audio_url),
            self.post_image(image_url),
        )
        # Surface a rejected upload with the server's answer instead of an
        # opaque KeyError when building the payload.
        for label, result in (("audio", audio_result), ("image", image_result)):
            if not isinstance(result, dict) or "url" not in result:
                raise RuntimeError(f"Failed to upload {label}, {result}")

        payload = {
            "voiceUrl": audio_result["url"],
            "avatarImage": image_result["url"],
            "aspectRatio": aspect_ratio,
        }

        async with aiohttp.ClientSession(headers=headers, timeout=self.timeout) as session:
            async with session.post(f"{self._base_url}/v1/characters", json=payload) as response:
                data = await response.json()
                task_id = data.get("jobId", None)
                # Explicit raise rather than ``assert``: assertions are
                # stripped when Python runs with -O.
                if task_id is None:
                    raise RuntimeError(f"Failed to submit task, {data}")
                request_id = data.get("request_id", None)
                return task_id, request_id

    async def get_response(self, task_id: str) -> Tuple[str, float]:
        """Poll the job until it finishes and return its video URL.

        Args:
            task_id: Job identifier returned by ``submit_task``.

        Returns:
            ``(video_url, video_duration)``. The duration is the constant 4;
            it is not read from the response.

        Raises:
            RuntimeError: If the status is ``"Failed"`` or missing, or if a
                completed job carries no ``videoUrl``.
        """
        headers = {"X-API-KEY": self._key}
        status_url = self._check_task_url.format(task_id=task_id)
        async with aiohttp.ClientSession(headers=headers, timeout=self.timeout) as session:
            while True:
                async with session.get(status_url) as response:
                    data = await response.json()
                status = data.get("status", None)

                if status == "Completed":
                    video_url = data.get("videoUrl", None)
                    if video_url is None:
                        raise RuntimeError(f"Failed to get video_url from response[{data}]")
                    video_duration = 4
                    return video_url, video_duration

                if status in ["Failed"] or status is None:
                    # ``output`` may be present but null; treat that like an
                    # absent field instead of crashing on ``None.get``.
                    message = (data.get("output") or {}).get("message", "")
                    raise RuntimeError(f"Task {task_id} failed or was canceled. {message}")

                # Any other status means the job is still running.
                await asyncio.sleep(self._POLL_INTERVAL_SECONDS)


def gradio_interface(audio_url, image_url, aspect_ratio):
    """Synchronous entry point: submit a job and block until it completes.

    Args:
        audio_url: URL of the voice track.
        image_url: URL of the avatar image.
        aspect_ratio: Aspect ratio forwarded to ``HedraClient.submit_task``.

    Returns:
        ``(video_url, video_duration)`` as returned by
        ``HedraClient.get_response``.
    """
    client = HedraClient()

    async def process():
        # The request id is returned by the API but not used here.
        task_id, _request_id = await client.submit_task(audio_url, image_url, aspect_ratio)
        return await client.get_response(task_id)

    # asyncio.run creates a fresh event loop and always closes it afterwards;
    # creating a loop by hand on every call without closing it leaks the
    # loop's selector and file descriptors.
    video_url, video_duration = asyncio.run(process())
    return video_url, video_duration


# Gradio UI: three text inputs are passed to gradio_interface and its two
# return values are shown as text.
# Components use the top-level gr.Textbox class; the gr.inputs / gr.outputs
# namespaces are deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Audio URL"),
        gr.Textbox(label="Image URL"),
        gr.Textbox(label="Aspect Ratio"),
    ],
    outputs=[
        gr.Textbox(label="Video URL"),
        gr.Textbox(label="Video Duration"),
    ],
    title="Hedra Gradio Interface",
    description="Submit audio and image URLs to generate a video.",
)

# Starts the web server when the module is executed; share=True requests a
# public share link.
iface.launch(share=True)