# Captured from a Hugging Face Space (status at capture time: Running).
import asyncio
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ---------- Environment-driven configuration ----------

# Bearer token sent to the Hugging Face router (e.g. hf_jwt_...); None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Queue submit endpoint; the default mirrors the curl submit endpoint.
HF_SUBMIT_URL = os.environ.get(
    "HF_SUBMIT_URL",
    "https://router.huggingface.co/fal-ai/fal-ai/wan/v2.2-a14b/text-to-video?_subdomain=queue",
)

# Upload target for the finished video, and the Snapzion bearer token for it.
UPLOAD_URL = os.environ.get("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.environ.get("UPLOAD_ACCESS_TOKEN")

# Polling cadence and overall budget while waiting on the queue (default 15 min).
POLL_INTERVAL_SEC = float(os.environ.get("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.environ.get("POLL_TIMEOUT_SEC", "900"))
# ---------- OpenAI-compatible schemas ---------- | |
class ChatMessage(BaseModel):
    # One entry of the incoming chat transcript.
    # role: which participant produced the message; restricted to the four
    #       literal values below.
    role: Literal["system", "user", "assistant", "tool"]
    # content: plain-text body of the message (required, must be a string).
    content: str
class ChatCompletionsRequest(BaseModel):
    # Request body for the chat-completions handler.
    # Within the visible source only `model` (echoed back in the response) and
    # `messages` (searched for the prompt) are read; the remaining fields are
    # accepted and validated but not otherwise used.
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    # Untyped pass-through fields: any JSON value is accepted.
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None
class ChatChoiceMessage(BaseModel):
    # Message carried inside a response choice; the role is always "assistant".
    role: Literal["assistant"]
    # content: text returned to the caller.
    content: str
class ChatChoice(BaseModel):
    # A single completion choice in the response.
    # index: position of this choice in the `choices` list.
    index: int
    message: ChatChoiceMessage
    # finish_reason: defaults to "stop" when not supplied.
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
class ChatCompletionsResponse(BaseModel):
    # Top-level response body returned by the chat-completions handler.
    id: str
    # object: fixed discriminator string.
    object: Literal["chat.completion"] = "chat.completion"
    # created: integer timestamp (the handler passes int(time.time())).
    created: int
    model: str
    choices: List[ChatChoice]
    # usage: token accounting; defaults to all zeros when not supplied.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
# ---------- Helpers ---------- | |
def extract_prompt(messages: List[ChatMessage]) -> str:
    """Return the text of the most recent non-blank user message.

    Messages are scanned from newest to oldest; the first one whose role is
    ``"user"`` and whose content is not empty or whitespace-only is returned
    with surrounding whitespace stripped.

    Raises:
        HTTPException: 400 when no user message carries any non-blank text.
            Whitespace-only user messages count as "no prompt" so that an
            empty prompt is never forwarded downstream.
    """
    for message in reversed(messages):
        if message.role != "user" or not message.content:
            continue
        text = message.content.strip()
        if text:
            return text
    raise HTTPException(status_code=400, detail="No user prompt provided.")
async def hf_queue_submit(prompt: str) -> Dict[str, str]:
    """Submit a job to the HF router queue endpoint.

    Posts ``{"prompt": prompt}`` to ``HF_SUBMIT_URL`` with the ``HF_TOKEN``
    bearer header.

    Returns:
        A dict with exactly two keys, ``status_url`` and ``response_url``,
        taken from the top level of the reply or from its nested ``urls``
        object.

    Raises:
        HTTPException: 500 when ``HF_TOKEN`` is not set; 502 when the submit
            call returns a status >= 400 or the reply lacks either URL.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "*/*",
    }
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(HF_SUBMIT_URL, headers=headers, json={"prompt": prompt})

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF submit error: {resp.text}")

    data = resp.json()
    # `or {}` also covers an explicit `"urls": null`, which a plain
    # `.get("urls", {})` would turn into an AttributeError.
    nested_urls = data.get("urls") or {}
    status_url = data.get("status_url") or nested_urls.get("status_url")
    response_url = data.get("response_url") or nested_urls.get("response_url")
    if not status_url or not response_url:
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")
    return {"status_url": status_url, "response_url": response_url}
async def hf_queue_wait(status_url: str) -> None:
    """Poll the HF queue ``status_url`` until the job finishes.

    Returns normally once the reported status is ``COMPLETED`` or
    ``SUCCEEDED``. Between polls the coroutine yields to the event loop for
    ``POLL_INTERVAL_SEC`` seconds.

    Raises:
        HTTPException: 500 when ``HF_TOKEN`` is not set; 502 when a status
            request returns >= 400 or the job reports a failed/cancelled
            state; 504 when more than ``POLL_TIMEOUT_SEC`` seconds elapse
            without a terminal status.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
    # Monotonic clock: the elapsed-time check must not be affected by
    # wall-clock adjustments.
    started = time.monotonic()
    async with httpx.AsyncClient(timeout=None) as client:
        while True:
            resp = await client.get(status_url, headers=headers)
            if resp.status_code >= 400:
                raise HTTPException(status_code=502, detail=f"HF status error: {resp.text}")

            data = resp.json()
            status = data.get("status")
            if status in ("COMPLETED", "SUCCEEDED"):
                return
            if status in ("FAILED", "ERROR", "CANCELLED", "CANCELED"):
                raise HTTPException(status_code=502, detail=f"HF job failed: {data}")

            if time.monotonic() - started > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="HF queue timed out.")

            # Must be an awaited sleep: a blocking time.sleep() here would
            # stall the whole event loop (and every other request) per poll.
            await asyncio.sleep(POLL_INTERVAL_SEC)
async def hf_queue_fetch_result(response_url: str) -> Dict[str, Any]:
    """Retrieve the finished job's JSON payload from ``response_url``.

    The payload is expected to include ``{"video": {"url": ...}, ...}``.

    Raises:
        HTTPException: 500 when ``HF_TOKEN`` is not set; 502 when the request
            returns a status >= 400.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    auth_headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
    async with httpx.AsyncClient(timeout=None) as session:
        reply = await session.get(response_url, headers=auth_headers)

    if reply.status_code < 400:
        return reply.json()
    raise HTTPException(status_code=502, detail=f"HF result error: {reply.text}")
async def download_video(url: str) -> bytes:
    """Download the video at ``url`` and return its raw bytes.

    Redirects are followed explicitly: httpx does not follow them by default,
    and a 3xx reply would otherwise slip past the ``>= 400`` check and its
    (typically empty) body would be returned as if it were the video.

    Raises:
        HTTPException: 502 when the final response status is >= 400.
    """
    async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
        resp = await client.get(url)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Download failed: {resp.text}")
    return resp.content
async def upload_video_bytes(mp4_bytes: bytes) -> str:
    """Upload MP4 bytes to ``UPLOAD_URL`` and return the resulting public URL.

    The bytes are sent as a multipart ``file`` field named ``video.mp4`` with
    the ``UPLOAD_ACCESS_TOKEN`` bearer header. The URL is read from the first
    populated key among ``url``, ``fileUrl``, ``publicUrl`` and ``data.url``
    in the JSON reply.

    Raises:
        HTTPException: 500 when ``UPLOAD_ACCESS_TOKEN`` is not set; 502 when
            the upload returns a status >= 400 or no URL is found in the
            reply.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(UPLOAD_URL, headers=headers, files=files)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    data = resp.json()
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        # `or {}` keeps an explicit `"data": null` from raising AttributeError.
        or (data.get("data") or {}).get("url")
    )
    if not url:
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url
# ---------- FastAPI ---------- | |
# ASGI application object; the title string is passed to FastAPI as app metadata.
app = FastAPI(title="OpenAI-Compatible T2V Proxy (HF Router Queue)")
# Without a route decorator the handler is never registered on `app` and is
# unreachable over HTTP.
@app.get("/health")
async def health():
    """Liveness probe: report ``ok`` together with the configured submit URL."""
    return {"status": "ok", "submit_url": HF_SUBMIT_URL}
# Registered on the OpenAI-style path; without a route decorator the handler
# is never attached to `app` and is unreachable over HTTP.
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionsRequest):
    """Generate a video from the latest user prompt and reply in chat shape.

    Steps:
      1) submit to HF router queue (Bearer HF_TOKEN)
      2) poll status_url until COMPLETED
      3) fetch response_url -> video.url
      4) download MP4, upload to Snapzion
      5) return URL in OpenAI chat shape

    Only ``req.messages`` and ``req.model`` are read; the other request
    fields are accepted but ignored.

    Raises:
        HTTPException: propagated from the helper steps, plus 502 when the
            HF result JSON has no ``video.url``.
    """
    prompt = extract_prompt(req.messages)

    # 1) Submit
    urls = await hf_queue_submit(prompt)

    # 2) Wait
    await hf_queue_wait(urls["status_url"])

    # 3) Fetch result JSON
    result = await hf_queue_fetch_result(urls["response_url"])
    video_url = (result.get("video") or {}).get("url")
    if not video_url:
        raise HTTPException(status_code=502, detail=f"HF result missing video.url: {result}")

    # 4) Download + re-upload
    mp4 = await download_video(video_url)
    public_url = await upload_video_bytes(mp4)

    # 5) Respond OpenAI-style
    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {public_url}"
    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )