import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config (env) ----------
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for your uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
HF_ENDPOINT = os.getenv(
    "HF_ENDPOINT",
    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)

# Polling settings for HF async generation
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max
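
# A sketch of the environment this service expects; the values below are
# placeholders, not real credentials:
#   export HF_TOKEN="hf_xxx"
#   export UPLOAD_ACCESS_TOKEN="xxx"
#   export WAN_MODEL="Wan-AI/Wan2.2-T2V-A14B"
#   export POLL_INTERVAL_SEC="3"
#   export POLL_TIMEOUT_SEC="600"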

# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # We accept arbitrary extras but ignore them.
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
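
# Illustrative wire shape these schemas serialize to (example values only):
# {
#   "id": "chatcmpl-<hex>",
#   "object": "chat.completion",
#   "created": 1700000000,
#   "model": "wan-t2v",
#   "choices": [
#     {"index": 0,
#      "message": {"role": "assistant", "content": "...prompt + video URL..."},
#      "finish_reason": "stop"}
#   ],
#   "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
# }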

# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """
    Basic heuristic: use the content of the last non-empty user message as the
    video prompt. If none is found, join all user contents.
    """
    for msg in reversed(messages):
        if msg.role == "user" and msg.content.strip():
            return msg.content.strip()
    # Fallback: concatenate every user message.
    user_texts = [m.content for m in messages if m.role == "user"]
    if not user_texts:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return "\n".join(user_texts).strip()
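
# Example of the heuristic above (hypothetical message list): given
#   [user: "a red fox", assistant: "...", user: "a cat surfing a wave"]
# extract_prompt returns "a cat surfing a wave" (the last non-empty user turn).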


async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
    """
    Calls the Hugging Face Inference API for text-to-video and returns raw MP4
    bytes. Some T2V models run asynchronously; we poll until the asset is ready.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }
    # Kick off generation: HF answers 200 with raw bytes when done, or
    # 200/202 with a status JSON (and 503 while the model loads).
    start = time.time()
    while True:
        resp = await client.post(
            HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None
        )
        ct = resp.headers.get("content-type", "")
        # Direct bytes path: generation finished.
        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
            return resp.content
        # Still processing (or model loading): back off and retry on the same
        # client; asyncio.sleep keeps the event loop responsive.
        if resp.status_code in (200, 202, 503):
            if time.time() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue
        # Any other status is a hard error; surface the payload for debugging.
        try:
            err = resp.json()
        except Exception:
            err = {"detail": resp.text}
        raise HTTPException(status_code=502, detail=f"HF error: {err}")


async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
    """
    Uploads the MP4 to your uploader service and returns the public URL.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
    files = {
        "file": ("video.mp4", mp4_bytes, "video/mp4"),
    }
    headers = {
        "Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}",
    }
    resp = await client.post(UPLOAD_URL, headers=headers, files=files, timeout=None)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
    data = resp.json()
    # Try common field names; adapt if your uploader returns a different shape.
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or data.get("data", {}).get("url")
    )
    if not url:
        # Last resort: surface the whole payload for debugging.
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url
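
# Assumed uploader response shapes the helper above can parse, e.g.:
#   {"url": "https://cdn.example.com/video.mp4"}
#   {"fileUrl": "https://cdn.example.com/video.mp4"}
#   {"data": {"url": "https://cdn.example.com/video.mp4"}}
# (example.com URLs are placeholders; check your uploader's actual schema)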


# ---------- FastAPI app ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy")


@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    OpenAI-compatible endpoint:
    - takes chat messages
    - generates a video from the last user message
    - uploads it
    - returns the link in the assistant message content
    """
    prompt = extract_prompt(req.messages)
    async with httpx.AsyncClient() as client:
        mp4 = await hf_text_to_video(prompt, client)
        video_url = await upload_video_bytes(mp4, client)
    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = (
        f"✅ Video generated & uploaded.\n"
        f"**Prompt:** {prompt}\n"
        f"**URL:** {video_url}"
    )
    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )
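

# ---------- Local run (sketch) ----------
# A minimal entrypoint for local testing, assuming this module is saved as
# `app.py` and uvicorn is installed; host and port are placeholders (7860 is
# the usual Spaces port).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against a local instance (hypothetical prompt; the "model"
# field is echoed back in the response, not used for routing):
#   curl -s http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "wan-t2v", "messages": [{"role": "user", "content": "a cat surfing a wave"}]}'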