File size: 7,929 Bytes
c7e3d6d
 
 
3a4123b
c7e3d6d
 
 
 
534be3f
3a4123b
 
 
 
 
 
 
534be3f
3a4123b
534be3f
3a4123b
 
 
c7e3d6d
 
3a4123b
c7e3d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a4123b
c7e3d6d
534be3f
 
 
 
c7e3d6d
 
 
 
 
3a4123b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e3d6d
 
534be3f
3a4123b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534be3f
3a4123b
 
 
 
 
 
534be3f
3a4123b
 
 
 
 
 
534be3f
 
3a4123b
 
 
 
 
 
534be3f
 
 
c7e3d6d
 
534be3f
 
 
 
 
c7e3d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a4123b
 
534be3f
 
 
3a4123b
c7e3d6d
 
 
 
3a4123b
 
 
 
 
c7e3d6d
 
 
3a4123b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e3d6d
 
3a4123b
c7e3d6d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import asyncio
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config via env ----------
# Bearer token for the Hugging Face router (e.g. hf_jwt_...); None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Queue submission endpoint. The default (judging by its path) is the HF
# router's fal-ai Wan v2.2 text-to-video route in queue mode.
HF_SUBMIT_URL = os.environ.get(
    "HF_SUBMIT_URL",
    "https://router.huggingface.co/fal-ai/fal-ai/wan/v2.2-a14b/text-to-video"
    "?_subdomain=queue",
)

# Snapzion public-upload endpoint and the bearer token used to call it.
UPLOAD_URL = os.environ.get(
    "UPLOAD_URL", "https://upload.snapzion.com/api/public-upload"
)
UPLOAD_ACCESS_TOKEN = os.environ.get("UPLOAD_ACCESS_TOKEN")

# Queue polling: seconds to wait between status checks, and the overall cap
# in seconds (default 900 = 15 min) before the wait gives up.
POLL_INTERVAL_SEC = float(os.environ.get("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.environ.get("POLL_TIMEOUT_SEC", "900"))


# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    # One entry of the OpenAI-style `messages` array in a chat request.
    #
    # Roles: besides the original system/user/assistant/tool, accept
    # "developer" (OpenAI's newer name for system-level instructions) and the
    # legacy "function" role, so requests from current OpenAI clients are not
    # rejected by validation. Prompt extraction only reads role == "user",
    # so the extra roles are otherwise ignored.
    role: Literal["system", "developer", "user", "assistant", "tool", "function"]
    # Content may be null in OpenAI requests (e.g. assistant messages that only
    # carry tool calls). Prompt extraction skips falsy content, so accepting
    # None here only widens what validates.
    content: Optional[str] = None


class ChatCompletionsRequest(BaseModel):
    # Request body for POST /v1/chat/completions, mirroring OpenAI's schema.
    # The handler in this file reads only `model` (echoed back in the
    # response) and `messages` (the source of the prompt). The remaining
    # fields are accepted for client compatibility and ignored.
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False  # NOTE(review): streaming is not implemented; a single JSON body is always returned
    n: Optional[int] = 1  # ignored: exactly one choice is always produced
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    # The reply message inside a response choice. The role is constrained to
    # "assistant"; the content is the text returned to the client.
    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    # One element of the response `choices` list.
    index: int  # position of this choice within `choices`
    message: ChatChoiceMessage
    # Why generation ended; restricted to OpenAI's values and defaulting to "stop".
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    # Non-streaming OpenAI "chat.completion" envelope, returned by the
    # /v1/chat/completions handler in this file.
    id: str  # the handler generates "chatcmpl-<uuid4 hex>"
    object: Literal["chat.completion"] = "chat.completion"
    created: int  # Unix timestamp in seconds (handler uses int(time.time()))
    model: str  # echoed back from the request
    choices: List[ChatChoice]
    # This proxy does not count tokens, so usage defaults to all zeros.
    # NOTE(review): the dict literal is a mutable default; this relies on
    # pydantic copying field defaults per instance. Confirm for the pinned
    # pydantic version.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }


# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """
    Return the latest non-blank user message as the generation prompt.

    Messages are scanned from newest to oldest. Only ``role == "user"``
    entries are considered, and the first one whose content contains
    non-whitespace text is returned with surrounding whitespace stripped.

    Args:
        messages: The chat history from the request.

    Returns:
        The stripped text of the most recent non-blank user message.

    Raises:
        HTTPException: 400 when no user message contains non-whitespace text.
    """
    for message in reversed(messages):
        # `or ""` tolerates a missing/None content instead of crashing on .strip().
        text = (message.content or "").strip() if message.role == "user" else ""
        if text:
            return text
    # Every user message is absent or whitespace-only. The old fallback joined
    # and stripped those texts, which yields "" and silently submitted an empty
    # prompt upstream; reject the request instead.
    raise HTTPException(status_code=400, detail="No user prompt provided.")


async def hf_queue_submit(
    prompt: str, *, timeout: Optional[float] = 60.0
) -> Dict[str, str]:
    """
    Submit a text-to-video job to the HF router queue endpoint.

    Args:
        prompt: Text prompt, sent as ``{"prompt": prompt}`` to ``HF_SUBMIT_URL``.
        timeout: Per-request httpx timeout in seconds. ``None`` disables it.
            The finite default keeps a stalled connection from hanging the
            request indefinitely.

    Returns:
        ``{"status_url": ..., "response_url": ...}``, read either from the top
        level of the submit response or from its nested ``urls`` object.

    Raises:
        HTTPException: 500 if ``HF_TOKEN`` is unset. 502 on a transport error,
            an HTTP status >= 400, a non-JSON or non-object body, or a response
            lacking either URL.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "*/*",
    }
    payload = {"prompt": prompt}

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(HF_SUBMIT_URL, headers=headers, json=payload)
    except httpx.HTTPError as exc:
        # DNS/connect/timeout failures would otherwise surface as a bare 500.
        raise HTTPException(
            status_code=502,
            detail=f"HF submit request failed: {type(exc).__name__}: {exc}",
        ) from exc
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF submit error: {resp.text}")

    try:
        data = resp.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(
            status_code=502,
            detail=f"HF submit returned non-JSON body: {resp.text[:500]}",
        ) from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")

    # URLs may be top-level or nested under "urls". `or {}` also covers an
    # explicit null, which the old `.get("urls", {})` did not.
    nested = data.get("urls") or {}
    if not isinstance(nested, dict):
        nested = {}
    status_url = data.get("status_url") or nested.get("status_url")
    response_url = data.get("response_url") or nested.get("response_url")
    if not status_url or not response_url:
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")

    return {"status_url": status_url, "response_url": response_url}


async def hf_queue_wait(status_url: str, *, timeout: Optional[float] = 60.0) -> None:
    """
    Poll the HF queue ``status_url`` until the job completes, fails, or times out.

    The status endpoint is polled every ``POLL_INTERVAL_SEC`` seconds. The wait
    gives up once more than ``POLL_TIMEOUT_SEC`` seconds have elapsed.

    Args:
        status_url: Status endpoint returned by ``hf_queue_submit``.
        timeout: Per-request httpx timeout in seconds. ``None`` disables it.
            Bounding each request keeps one stalled poll from defeating the
            overall ``POLL_TIMEOUT_SEC`` deadline.

    Raises:
        HTTPException: 500 if ``HF_TOKEN`` is unset. 502 on a transport error,
            an HTTP status >= 400, or a non-JSON or non-object body. 502 also
            when the job reports FAILED/ERROR/CANCELLED/CANCELED. 504 once the
            deadline has passed.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
    # Monotonic clock: elapsed-time checks are immune to wall-clock jumps.
    deadline = time.monotonic() + POLL_TIMEOUT_SEC

    async with httpx.AsyncClient(timeout=timeout) as client:
        while True:
            try:
                resp = await client.get(status_url, headers=headers)
            except httpx.HTTPError as exc:
                raise HTTPException(
                    status_code=502,
                    detail=f"HF status request failed: {type(exc).__name__}: {exc}",
                ) from exc
            if resp.status_code >= 400:
                raise HTTPException(status_code=502, detail=f"HF status error: {resp.text}")

            try:
                data = resp.json()
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                raise HTTPException(
                    status_code=502,
                    detail=f"HF status returned non-JSON body: {resp.text[:500]}",
                ) from None
            if not isinstance(data, dict):
                raise HTTPException(
                    status_code=502, detail=f"Unexpected HF status response: {data}"
                )
            status = data.get("status")

            if status in ("COMPLETED", "SUCCEEDED"):
                return
            if status in ("FAILED", "ERROR", "CANCELLED", "CANCELED"):
                raise HTTPException(status_code=502, detail=f"HF job failed: {data}")

            if time.monotonic() > deadline:
                raise HTTPException(status_code=504, detail="HF queue timed out.")

            # Must yield to the event loop. time.sleep() here would freeze every
            # other request served by this process for the whole wait.
            await asyncio.sleep(POLL_INTERVAL_SEC)


async def hf_queue_fetch_result(
    response_url: str, *, timeout: Optional[float] = 60.0
) -> Dict[str, Any]:
    """
    Fetch the final result JSON of a completed HF queue job.

    The result is expected to include ``{"video": {"url": ...}, ...}``.

    Args:
        response_url: Result endpoint returned by ``hf_queue_submit``.
        timeout: Per-request httpx timeout in seconds. ``None`` disables it.

    Returns:
        The decoded JSON object.

    Raises:
        HTTPException: 500 if ``HF_TOKEN`` is unset. 502 on a transport error,
            an HTTP status >= 400, or a body that is not a JSON object.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.get(response_url, headers=headers)
    except httpx.HTTPError as exc:
        raise HTTPException(
            status_code=502,
            detail=f"HF result request failed: {type(exc).__name__}: {exc}",
        ) from exc
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF result error: {resp.text}")

    try:
        data = resp.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(
            status_code=502,
            detail=f"HF result returned non-JSON body: {resp.text[:500]}",
        ) from None
    # Callers use dict access (.get) on the result, so enforce the declared type.
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Unexpected HF result response: {data}")
    return data


async def download_video(url: str, *, timeout: Optional[float] = None) -> bytes:
    """
    Download the generated video and return its raw bytes.

    Args:
        url: Location of the generated file.
        timeout: Per-request httpx timeout in seconds. The default ``None``
            keeps the original no-limit behaviour.

    Returns:
        The (non-empty) response body.

    Raises:
        HTTPException: 502 on a transport error, an HTTP status >= 400, or an
            empty body.
    """
    try:
        # httpx does not follow redirects by default. Without this flag a 3xx
        # from the storage/CDN host would be "successful" and its (redirect)
        # body would be uploaded as the video.
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
            resp = await client.get(url)
    except httpx.HTTPError as exc:
        raise HTTPException(
            status_code=502, detail=f"Download failed: {type(exc).__name__}: {exc}"
        ) from exc
    if resp.status_code >= 400:
        # Truncate: the error body may be large or binary.
        raise HTTPException(status_code=502, detail=f"Download failed: {resp.text[:500]}")
    if not resp.content:
        raise HTTPException(status_code=502, detail="Download failed: empty response body.")
    return resp.content


async def upload_video_bytes(mp4_bytes: bytes, *, timeout: Optional[float] = None) -> str:
    """
    Upload MP4 bytes to ``UPLOAD_URL`` as multipart field ``file`` and return its public URL.

    Args:
        mp4_bytes: Video content, sent as ``video.mp4`` with type ``video/mp4``.
        timeout: Per-request httpx timeout in seconds. The default ``None``
            keeps the original no-limit behaviour.

    Returns:
        The first non-empty URL among the response's ``url``, ``fileUrl``,
        ``publicUrl`` or ``data.url`` fields, checked in that order.

    Raises:
        HTTPException: 500 if ``UPLOAD_ACCESS_TOKEN`` is unset. 502 on a
            transport error, an HTTP status >= 400, a body that is not a JSON
            object, or a response with no URL.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(UPLOAD_URL, headers=headers, files=files)
    except httpx.HTTPError as exc:
        raise HTTPException(
            status_code=502, detail=f"Upload failed: {type(exc).__name__}: {exc}"
        ) from exc
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    try:
        data = resp.json()
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(
            status_code=502, detail=f"Upload returned non-JSON body: {resp.text[:500]}"
        ) from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")

    # `data.get("data", {}).get(...)` crashed when "data" was null or not an
    # object; only descend into it when it really is a dict.
    nested = data.get("data")
    nested_url = nested.get("url") if isinstance(nested, dict) else None
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or nested_url
    )
    if not url:
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url


# ---------- FastAPI ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (HF Router Queue)")


# Liveness probe: always reports "ok", plus the HF submit endpoint in use.
@app.get("/health")
async def health():
    body = dict(status="ok", submit_url=HF_SUBMIT_URL)
    return body

@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    Generate a video from the latest user message and reply in OpenAI chat format.

    Pipeline:
      * submit the prompt to the HF router queue (Bearer HF_TOKEN),
      * poll the job's status URL until it finishes,
      * read the result JSON and take ``video.url`` from it,
      * download the MP4 and re-upload it to Snapzion,
      * return the public URL as the assistant message of a chat.completion.
    """
    prompt = extract_prompt(req.messages)

    job = await hf_queue_submit(prompt)
    await hf_queue_wait(job["status_url"])
    result = await hf_queue_fetch_result(job["response_url"])

    video = result.get("video") or {}
    video_url = video.get("url")
    if not video_url:
        raise HTTPException(
            status_code=502, detail=f"HF result missing video.url: {result}"
        )

    public_url = await upload_video_bytes(await download_video(video_url))

    reply_text = "\n".join(
        (
            "✅ Video generated & uploaded.",
            f"**Prompt:** {prompt}",
            f"**URL:** {public_url}",
        )
    )
    reply = ChatChoiceMessage(role="assistant", content=reply_text)

    return ChatCompletionsResponse(
        id="chatcmpl-" + uuid.uuid4().hex,
        created=int(time.time()),
        model=req.model,
        choices=[ChatChoice(index=0, message=reply, finish_reason="stop")],
    )