"""OpenAI-compatible text-to-video proxy backed by the Hugging Face router queue.

The single chat-completions endpoint takes the latest user message as a
prompt, submits it to the HF queue endpoint, polls until the job finishes,
downloads the resulting video, re-uploads it to the configured upload
service, and returns the public URL in an OpenAI chat-completion shape.
"""

import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config via env ----------
HF_TOKEN = os.getenv("HF_TOKEN")  # e.g. hf_jwt_...

# Default matches your curl submit endpoint
HF_SUBMIT_URL = os.getenv(
    "HF_SUBMIT_URL",
    "https://router.huggingface.co/fal-ai/fal-ai/wan/v2.2-a14b/text-to-video?_subdomain=queue",
)

UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # your Snapzion bearer

# Polling/backoff
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "900"))  # 15 min max


# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    """One message of the incoming chat conversation."""

    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    """Request body of the chat-completions endpoint.

    Only ``model`` and ``messages`` are read by this service; the remaining
    fields are accepted so OpenAI-style clients validate, but are not used.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    """Assistant message carried by a single completion choice."""

    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    """A single completion choice."""

    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    """Response body of the chat-completions endpoint."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    # Token accounting is not tracked by this proxy; zeros are reported.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }


# ---------- Helpers ----------
def _hf_auth_headers() -> Dict[str, str]:
    """Return the auth headers for HF calls, or raise 500 if HF_TOKEN is unset."""
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    return {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}


def _json_object(resp: httpx.Response, what: str) -> Dict[str, Any]:
    """Decode ``resp`` as a JSON object, mapping malformed bodies to a 502.

    ``what`` names the upstream call and is only used in the error detail.
    """
    try:
        data = resp.json()
    except ValueError as exc:  # JSON decode errors are ValueError subclasses
        raise HTTPException(
            status_code=502,
            detail=f"{what} returned a non-JSON body: {resp.text}",
        ) from exc
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=502,
            detail=f"{what} returned unexpected JSON: {data}",
        )
    return data


def _nested(data: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``data[key]`` when it is a dict, otherwise an empty dict."""
    value = data.get(key)
    return value if isinstance(value, dict) else {}


def extract_prompt(messages: List[ChatMessage]) -> str:
    """Return the stripped text of the most recent non-blank user message.

    Raises:
        HTTPException: 400 when no user message has non-whitespace content.
    """
    for message in reversed(messages):
        if message.role != "user" or not message.content:
            continue
        text = message.content.strip()
        if text:
            return text
    # A blank prompt is rejected here rather than forwarded upstream.
    raise HTTPException(status_code=400, detail="No user prompt provided.")


async def hf_queue_submit(prompt: str) -> Dict[str, str]:
    """Submit a job to the HF router queue endpoint.

    Returns:
        A dict with ``status_url`` and ``response_url`` taken from the HF
        response (either top-level or nested under ``urls``).

    Raises:
        HTTPException: 500 if HF_TOKEN is unset; 502 on an HTTP error status
            or when the response lacks either URL.
    """
    headers = {**_hf_auth_headers(), "Content-Type": "application/json"}
    payload = {"prompt": prompt}

    # timeout=None disables httpx's default request timeouts.
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(HF_SUBMIT_URL, headers=headers, json=payload)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF submit error: {resp.text}")

    data = _json_object(resp, "HF submit")
    urls = _nested(data, "urls")
    status_url = data.get("status_url") or urls.get("status_url")
    response_url = data.get("response_url") or urls.get("response_url")
    if not status_url or not response_url:
        raise HTTPException(
            status_code=502, detail=f"Unexpected HF submit response: {data}"
        )
    return {"status_url": status_url, "response_url": response_url}


async def hf_queue_wait(status_url: str) -> None:
    """Poll the HF queue ``status_url`` until the job completes.

    Returns once the reported status is ``COMPLETED`` or ``SUCCEEDED``.

    Raises:
        HTTPException: 500 if HF_TOKEN is unset; 502 on an HTTP error status
            or a failed/cancelled job; 504 once POLL_TIMEOUT_SEC has elapsed.
    """
    headers = _hf_auth_headers()
    # Monotonic clock: elapsed time is unaffected by system clock changes.
    started = time.monotonic()

    async with httpx.AsyncClient(timeout=None) as client:
        while True:
            resp = await client.get(status_url, headers=headers)
            if resp.status_code >= 400:
                raise HTTPException(
                    status_code=502, detail=f"HF status error: {resp.text}"
                )

            data = _json_object(resp, "HF status")
            status = data.get("status")
            if status in ("COMPLETED", "SUCCEEDED"):
                return
            if status in ("FAILED", "ERROR", "CANCELLED", "CANCELED"):
                raise HTTPException(status_code=502, detail=f"HF job failed: {data}")

            if time.monotonic() - started > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="HF queue timed out.")

            # Must be an awaited sleep: time.sleep() would block the event
            # loop and stall every other in-flight request while polling.
            await asyncio.sleep(POLL_INTERVAL_SEC)


async def hf_queue_fetch_result(response_url: str) -> Dict[str, Any]:
    """Fetch the final response JSON, which includes {"video": {"url": ...}, ...}.

    Raises:
        HTTPException: 500 if HF_TOKEN is unset; 502 on an HTTP error status
            or a body that is not a JSON object.
    """
    headers = _hf_auth_headers()
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.get(response_url, headers=headers)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF result error: {resp.text}")
    return _json_object(resp, "HF result")


async def download_video(url: str) -> bytes:
    """Download ``url`` and return the raw response body.

    Raises:
        HTTPException: 502 on an HTTP error status.
    """
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.get(url)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Download failed: {resp.text}")
    return resp.content


async def upload_video_bytes(mp4_bytes: bytes) -> str:
    """Upload the video as ``video.mp4`` to UPLOAD_URL and return its URL.

    The URL is read from the first of ``url``, ``fileUrl``, ``publicUrl`` or
    ``data.url`` that is present in the upload response.

    Raises:
        HTTPException: 500 if UPLOAD_ACCESS_TOKEN is unset; 502 on an HTTP
            error status or when the response carries no URL.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}

    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(UPLOAD_URL, headers=headers, files=files)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    data = _json_object(resp, "Upload")
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or _nested(data, "data").get("url")
    )
    if not url:
        raise HTTPException(
            status_code=502, detail=f"Upload response missing URL: {data}"
        )
    return url


# ---------- FastAPI ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (HF Router Queue)")


@app.get("/health")
async def health():
    """Report liveness together with the configured submit URL."""
    return {"status": "ok", "submit_url": HF_SUBMIT_URL}


@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """Generate a video for the latest user prompt and return its public URL.

    1) submit to HF router queue (Bearer HF_TOKEN)
    2) poll status_url until COMPLETED
    3) fetch response_url -> video.url
    4) download MP4, upload to Snapzion
    5) return URL in OpenAI chat shape
    """
    prompt = extract_prompt(req.messages)

    # 1) Submit
    urls = await hf_queue_submit(prompt)

    # 2) Wait
    await hf_queue_wait(urls["status_url"])

    # 3) Fetch result JSON
    result = await hf_queue_fetch_result(urls["response_url"])
    video_url = _nested(result, "video").get("url")
    if not video_url:
        raise HTTPException(
            status_code=502, detail=f"HF result missing video.url: {result}"
        )

    # 4) Download + re-upload
    mp4 = await download_video(video_url)
    public_url = await upload_video_bytes(mp4)

    # 5) Respond OpenAI-style
    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {public_url}"

    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )