Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -1,27 +1,28 @@
|
|
1 |
import os
|
2 |
import time
|
3 |
import uuid
|
4 |
-
from typing import List, Optional, Literal, Any, Dict
|
5 |
|
6 |
import httpx
|
7 |
from fastapi import FastAPI, HTTPException
|
8 |
from pydantic import BaseModel
|
9 |
-
from huggingface_hub import InferenceClient
|
10 |
-
import asyncio
|
11 |
-
|
12 |
-
|
13 |
-
# ---------------- Config (env) ----------------
|
14 |
-
HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face token (works for provider=fal-ai)
|
15 |
-
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
|
18 |
-
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN") # your bearer
|
19 |
|
20 |
-
#
|
21 |
-
|
|
|
22 |
|
23 |
|
24 |
-
#
|
25 |
class ChatMessage(BaseModel):
|
26 |
role: Literal["system", "user", "assistant", "tool"]
|
27 |
content: str
|
@@ -65,9 +66,8 @@ class ChatCompletionsResponse(BaseModel):
|
|
65 |
}
|
66 |
|
67 |
|
68 |
-
#
|
69 |
def extract_prompt(messages: List[ChatMessage]) -> str:
|
70 |
-
"""Use the last user message as the prompt. Fallback to joining all user messages."""
|
71 |
for m in reversed(messages):
|
72 |
if m.role == "user" and m.content and m.content.strip():
|
73 |
return m.content.strip()
|
@@ -77,42 +77,89 @@ def extract_prompt(messages: List[ChatMessage]) -> str:
|
|
77 |
return "\n".join(user_texts).strip()
|
78 |
|
79 |
|
80 |
-
async def
|
81 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
if not HF_TOKEN:
|
83 |
raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
|
84 |
-
client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
raise HTTPException(status_code=504, detail="Video generation timed out.")
|
98 |
-
except Exception as e:
|
99 |
-
raise HTTPException(status_code=502, detail=f"Video generation failed: {e}")
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
104 |
|
105 |
-
if isinstance(result, dict):
|
106 |
-
# common keys: "video" (bytes), "seed", etc.
|
107 |
-
vid = result.get("video")
|
108 |
-
if isinstance(vid, (bytes, bytearray)):
|
109 |
-
return bytes(vid)
|
110 |
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
|
114 |
async def upload_video_bytes(mp4_bytes: bytes) -> str:
|
115 |
-
"""Uploads MP4 to Snapzion uploader and returns public URL."""
|
116 |
if not UPLOAD_ACCESS_TOKEN:
|
117 |
raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
|
118 |
headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
|
@@ -120,12 +167,10 @@ async def upload_video_bytes(mp4_bytes: bytes) -> str:
|
|
120 |
|
121 |
async with httpx.AsyncClient(timeout=None) as client:
|
122 |
resp = await client.post(UPLOAD_URL, headers=headers, files=files)
|
123 |
-
|
124 |
if resp.status_code >= 400:
|
125 |
raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
|
126 |
|
127 |
data = resp.json()
|
128 |
-
# Try common URL fields (adjust if your API returns a different shape)
|
129 |
url = (
|
130 |
data.get("url")
|
131 |
or data.get("fileUrl")
|
@@ -137,31 +182,44 @@ async def upload_video_bytes(mp4_bytes: bytes) -> str:
|
|
137 |
return url
|
138 |
|
139 |
|
140 |
-
#
|
141 |
-
app = FastAPI(title="OpenAI-Compatible T2V Proxy (
|
142 |
-
|
143 |
|
144 |
@app.get("/health")
|
145 |
async def health():
|
146 |
-
return {"status": "ok", "
|
147 |
-
|
148 |
|
149 |
@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
|
150 |
async def chat_completions(req: ChatCompletionsRequest):
|
151 |
"""
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
"""
|
158 |
prompt = extract_prompt(req.messages)
|
159 |
-
mp4 = await generate_video_bytes(prompt)
|
160 |
-
video_url = await upload_video_bytes(mp4)
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
now = int(time.time())
|
163 |
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
164 |
-
content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {
|
165 |
|
166 |
return ChatCompletionsResponse(
|
167 |
id=completion_id,
|
|
|
1 |
import os
|
2 |
import time
|
3 |
import uuid
|
4 |
+
from typing import List, Optional, Literal, Any, Dict
|
5 |
|
6 |
import httpx
|
7 |
from fastapi import FastAPI, HTTPException
|
8 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
# ---------- Config via env ----------
|
11 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # e.g. hf_jwt_...
|
12 |
+
# Default matches your curl submit endpoint
|
13 |
+
HF_SUBMIT_URL = os.getenv(
|
14 |
+
"HF_SUBMIT_URL",
|
15 |
+
"https://router.huggingface.co/fal-ai/fal-ai/wan/v2.2-a14b/text-to-video?_subdomain=queue",
|
16 |
+
)
|
17 |
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
|
18 |
+
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN") # your Snapzion bearer
|
19 |
|
20 |
+
# Polling/backoff
|
21 |
+
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
|
22 |
+
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "900")) # 15 min max
|
23 |
|
24 |
|
25 |
+
# ---------- OpenAI-compatible schemas ----------
|
26 |
class ChatMessage(BaseModel):
|
27 |
role: Literal["system", "user", "assistant", "tool"]
|
28 |
content: str
|
|
|
66 |
}
|
67 |
|
68 |
|
69 |
+
# ---------- Helpers ----------
|
70 |
def extract_prompt(messages: List[ChatMessage]) -> str:
|
|
|
71 |
for m in reversed(messages):
|
72 |
if m.role == "user" and m.content and m.content.strip():
|
73 |
return m.content.strip()
|
|
|
77 |
return "\n".join(user_texts).strip()
|
78 |
|
79 |
|
80 |
+
async def hf_queue_submit(prompt: str) -> Dict[str, str]:
|
81 |
+
"""
|
82 |
+
Submit a job to the HF router queue endpoint.
|
83 |
+
Returns a dict containing status_url and response_url (from HF).
|
84 |
+
"""
|
85 |
+
if not HF_TOKEN:
|
86 |
+
raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
|
87 |
+
|
88 |
+
headers = {
|
89 |
+
"Authorization": f"Bearer {HF_TOKEN}",
|
90 |
+
"Content-Type": "application/json",
|
91 |
+
"Accept": "*/*",
|
92 |
+
}
|
93 |
+
payload = {"prompt": prompt}
|
94 |
+
|
95 |
+
async with httpx.AsyncClient(timeout=None) as client:
|
96 |
+
resp = await client.post(HF_SUBMIT_URL, headers=headers, json=payload)
|
97 |
+
if resp.status_code >= 400:
|
98 |
+
raise HTTPException(status_code=502, detail=f"HF submit error: {resp.text}")
|
99 |
+
|
100 |
+
data = resp.json()
|
101 |
+
status_url = data.get("status_url") or data.get("urls", {}).get("status_url")
|
102 |
+
response_url = data.get("response_url") or data.get("urls", {}).get("response_url")
|
103 |
+
if not status_url or not response_url:
|
104 |
+
raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")
|
105 |
+
|
106 |
+
return {"status_url": status_url, "response_url": response_url}
|
107 |
+
|
108 |
+
|
109 |
+
async def hf_queue_wait(status_url: str) -> None:
|
110 |
+
"""
|
111 |
+
Polls the HF queue status_url until COMPLETED or error states.
|
112 |
+
"""
|
113 |
if not HF_TOKEN:
|
114 |
raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
|
|
|
115 |
|
116 |
+
headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
|
117 |
+
start = time.time()
|
118 |
+
|
119 |
+
async with httpx.AsyncClient(timeout=None) as client:
|
120 |
+
while True:
|
121 |
+
resp = await client.get(status_url, headers=headers)
|
122 |
+
if resp.status_code >= 400:
|
123 |
+
raise HTTPException(status_code=502, detail=f"HF status error: {resp.text}")
|
124 |
+
|
125 |
+
data = resp.json()
|
126 |
+
status = data.get("status")
|
127 |
+
|
128 |
+
if status in ("COMPLETED", "SUCCEEDED"):
|
129 |
+
return
|
130 |
+
if status in ("FAILED", "ERROR", "CANCELLED", "CANCELED"):
|
131 |
+
raise HTTPException(status_code=502, detail=f"HF job failed: {data}")
|
132 |
+
|
133 |
+
if time.time() - start > POLL_TIMEOUT_SEC:
|
134 |
+
raise HTTPException(status_code=504, detail="HF queue timed out.")
|
135 |
+
|
136 |
+
time.sleep(POLL_INTERVAL_SEC)
|
137 |
+
|
138 |
|
139 |
+
async def hf_queue_fetch_result(response_url: str) -> Dict[str, Any]:
|
140 |
+
"""
|
141 |
+
Fetch the final response JSON, which includes {"video": {"url": ...}, ...}
|
142 |
+
"""
|
143 |
+
if not HF_TOKEN:
|
144 |
+
raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
|
|
|
|
|
|
|
145 |
|
146 |
+
headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
|
147 |
+
async with httpx.AsyncClient(timeout=None) as client:
|
148 |
+
resp = await client.get(response_url, headers=headers)
|
149 |
+
if resp.status_code >= 400:
|
150 |
+
raise HTTPException(status_code=502, detail=f"HF result error: {resp.text}")
|
151 |
+
return resp.json()
|
152 |
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
+
async def download_video(url: str) -> bytes:
|
155 |
+
async with httpx.AsyncClient(timeout=None) as client:
|
156 |
+
resp = await client.get(url)
|
157 |
+
if resp.status_code >= 400:
|
158 |
+
raise HTTPException(status_code=502, detail=f"Download failed: {resp.text}")
|
159 |
+
return resp.content
|
160 |
|
161 |
|
162 |
async def upload_video_bytes(mp4_bytes: bytes) -> str:
|
|
|
163 |
if not UPLOAD_ACCESS_TOKEN:
|
164 |
raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
|
165 |
headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
|
|
|
167 |
|
168 |
async with httpx.AsyncClient(timeout=None) as client:
|
169 |
resp = await client.post(UPLOAD_URL, headers=headers, files=files)
|
|
|
170 |
if resp.status_code >= 400:
|
171 |
raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
|
172 |
|
173 |
data = resp.json()
|
|
|
174 |
url = (
|
175 |
data.get("url")
|
176 |
or data.get("fileUrl")
|
|
|
182 |
return url
|
183 |
|
184 |
|
185 |
+
# ---------- FastAPI ----------
|
186 |
+
app = FastAPI(title="OpenAI-Compatible T2V Proxy (HF Router Queue)")
|
|
|
187 |
|
188 |
@app.get("/health")
|
189 |
async def health():
|
190 |
+
return {"status": "ok", "submit_url": HF_SUBMIT_URL}
|
|
|
191 |
|
192 |
@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
|
193 |
async def chat_completions(req: ChatCompletionsRequest):
|
194 |
"""
|
195 |
+
1) submit to HF router queue (Bearer HF_TOKEN)
|
196 |
+
2) poll status_url until COMPLETED
|
197 |
+
3) fetch response_url -> video.url
|
198 |
+
4) download MP4, upload to Snapzion
|
199 |
+
5) return URL in OpenAI chat shape
|
200 |
"""
|
201 |
prompt = extract_prompt(req.messages)
|
|
|
|
|
202 |
|
203 |
+
# 1) Submit
|
204 |
+
urls = await hf_queue_submit(prompt)
|
205 |
+
|
206 |
+
# 2) Wait
|
207 |
+
await hf_queue_wait(urls["status_url"])
|
208 |
+
|
209 |
+
# 3) Fetch result JSON
|
210 |
+
result = await hf_queue_fetch_result(urls["response_url"])
|
211 |
+
video_url = (result.get("video") or {}).get("url")
|
212 |
+
if not video_url:
|
213 |
+
raise HTTPException(status_code=502, detail=f"HF result missing video.url: {result}")
|
214 |
+
|
215 |
+
# 4) Download + re-upload
|
216 |
+
mp4 = await download_video(video_url)
|
217 |
+
public_url = await upload_video_bytes(mp4)
|
218 |
+
|
219 |
+
# 5) Respond OpenAI-style
|
220 |
now = int(time.time())
|
221 |
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
222 |
+
content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {public_url}"
|
223 |
|
224 |
return ChatCompletionsResponse(
|
225 |
id=completion_id,
|