Spaces:
Sleeping
Sleeping
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config (env) ----------
# All runtime configuration is read once at import time from environment
# variables; sensible defaults are provided where a public value exists.
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token (required at runtime)
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for the uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
# Full inference endpoint; overridable, otherwise derived from WAN_MODEL.
HF_ENDPOINT = os.getenv(
    "HF_ENDPOINT",
    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)
# Polling settings for HF async generation.
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max
|
22 |
+
|
23 |
+
|
24 |
+
# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    """One message of an OpenAI-style chat transcript."""

    # Who produced the message.
    role: Literal["system", "user", "assistant", "tool"]
    # Plain-text body of the message.
    content: str
|
28 |
+
|
29 |
+
|
30 |
+
class ChatCompletionsRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI wire format).

    Only ``model`` and ``messages`` are used by this proxy; the remaining
    fields are accepted for client compatibility and ignored.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # Extra OpenAI parameters — accepted but not acted upon.
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None
|
43 |
+
|
44 |
+
|
45 |
+
class ChatChoiceMessage(BaseModel):
    """The assistant message carried inside a completion choice."""

    # Responses from this proxy are always authored by the assistant.
    role: Literal["assistant"]
    content: str
|
48 |
+
|
49 |
+
|
50 |
+
class ChatChoice(BaseModel):
    """A single completion choice in the OpenAI response shape."""

    index: int
    message: ChatChoiceMessage
    # This proxy always completes normally, so "stop" is the default.
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
|
54 |
+
|
55 |
+
|
56 |
+
class ChatCompletionsResponse(BaseModel):
    """Top-level OpenAI-compatible chat-completion response."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    # Token accounting is meaningless for video generation, so zeros are
    # reported. NOTE(review): a mutable dict default is normally a Python
    # pitfall, but pydantic copies field defaults per instance — confirm
    # this holds for the pydantic version pinned by the project.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
|
67 |
+
|
68 |
+
|
69 |
+
# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """Derive the text-to-video prompt from a chat transcript.

    Uses the content of the last user message that contains non-whitespace
    text.

    Raises:
        HTTPException: 400 when no usable user text exists.
    """
    for msg in reversed(messages):
        if msg.role == "user" and msg.content.strip():
            return msg.content.strip()
    # If the loop found nothing, every user message is empty or whitespace,
    # so the previous fallback — joining all user contents — could only
    # ever yield an empty prompt (which would then be sent to HF). Treat
    # this the same as "no user message at all" and reject the request.
    raise HTTPException(status_code=400, detail="No user prompt provided.")
|
83 |
+
|
84 |
+
|
85 |
+
async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
    """Generate a video from *prompt* via the Hugging Face Inference API.

    Returns raw MP4 bytes. Some T2V models run asynchronously: while the
    job is in flight HF answers 202 (or 200 with a JSON status body), so
    the request is re-issued until video bytes arrive or the configured
    timeout elapses.

    Args:
        prompt: Text prompt for the video model.
        client: Shared AsyncClient owned by the caller; it is reused for
            every poll and is NOT closed here.

    Raises:
        HTTPException: 500 if HF_TOKEN is unset, 504 on poll timeout,
            502 on any other HF error response.
    """
    import asyncio  # local import: only needed for the non-blocking sleep

    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }

    start = time.time()
    while True:
        resp = await client.post(HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None)
        ct = resp.headers.get("content-type", "")

        # Success path: the response body is the video itself.
        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
            return resp.content

        # Still processing: 202, or 200 whose body is a status payload
        # rather than video bytes. Wait and retry with the SAME client.
        # (The original closed the caller's client and spun up fresh,
        # never-closed ones each iteration — that leaked connections and
        # left the caller holding a closed client for the upload step.)
        if resp.status_code in (200, 202):
            if time.time() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            # time.sleep() would block the event loop inside this coroutine;
            # asyncio.sleep yields control while waiting.
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue

        # Any other status is a hard error; surface HF's payload verbatim.
        try:
            err = resp.json()
        except Exception:
            err = {"detail": resp.text}
        raise HTTPException(status_code=502, detail=f"HF error: {err}")
|
127 |
+
|
128 |
+
async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
    """Upload an MP4 payload to the configured uploader and return its public URL.

    Raises:
        HTTPException: 500 when UPLOAD_ACCESS_TOKEN is unset; 502 when the
            uploader rejects the request or its JSON lacks a URL field.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    resp = await client.post(
        UPLOAD_URL,
        headers={"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"},
        files={"file": ("video.mp4", mp4_bytes, "video/mp4")},
        timeout=None,
    )
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    data = resp.json()
    # Uploaders vary in where they put the link; probe the common field
    # names lazily so later accessors only run if earlier ones miss.
    candidates = (
        lambda d: d.get("url"),
        lambda d: d.get("fileUrl"),
        lambda d: d.get("publicUrl"),
        lambda d: d.get("data", {}).get("url"),
    )
    for pick in candidates:
        url = pick(data)
        if url:
            return url

    # Nothing recognizable — include the payload so the shape can be debugged.
    raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
|
158 |
+
|
159 |
+
|
160 |
+
# ---------- FastAPI app ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy")


@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """OpenAI-compatible chat endpoint.

    Extracts a prompt from the chat messages, generates a video from it,
    uploads the result, and replies with the public link embedded in an
    assistant message.
    """
    prompt = extract_prompt(req.messages)

    # One client serves both the generation polling and the upload.
    async with httpx.AsyncClient() as client:
        mp4 = await hf_text_to_video(prompt, client)
        video_url = await upload_video_bytes(mp4, client)

    reply = (
        f"✅ Video generated & uploaded.\n"
        f"**Prompt:** {prompt}\n"
        f"**URL:** {video_url}"
    )

    choice = ChatChoice(
        index=0,
        message=ChatChoiceMessage(role="assistant", content=reply),
        finish_reason="stop",
    )
    return ChatCompletionsResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=req.model,
        choices=[choice],
    )
|