"""OpenAI-compatible chat-completions proxy that turns the last user message
into a text-to-video job on the Hugging Face Inference API, uploads the
resulting MP4 to an external uploader, and returns the public URL as the
assistant's reply."""

import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config (env) ----------
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for your uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
HF_ENDPOINT = os.getenv(
    "HF_ENDPOINT",
    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)

# Polling settings for HF async generation
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max


# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # We accept arbitrary extras but ignore them.
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    # Pydantic deep-copies mutable defaults, so sharing this dict between
    # instances is not a concern here.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }


# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """Return the video prompt derived from the chat history.

    Basic heuristic: use the content of the last non-empty user message.
    If none is found, join all user contents.

    Raises:
        HTTPException: 400 if the conversation contains no user message.
    """
    for msg in reversed(messages):
        if msg.role == "user" and msg.content.strip():
            return msg.content.strip()
    # Fallback: concatenate every user message (all were blank individually).
    user_texts = [m.content for m in messages if m.role == "user"]
    if not user_texts:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return "\n".join(user_texts).strip()


async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
    """Call the Hugging Face Inference API for text-to-video; return MP4 bytes.

    Some T2V models run asynchronously; we re-POST and poll until the binary
    asset is ready or POLL_TIMEOUT_SEC elapses.

    Args:
        prompt: Text prompt for the video model.
        client: Shared AsyncClient owned by the caller (never closed here).

    Raises:
        HTTPException: 500 if HF_TOKEN is missing, 504 on timeout,
            502 on any other upstream error.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }

    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    start = time.monotonic()
    while True:
        resp = await client.post(
            HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None
        )
        ct = resp.headers.get("content-type", "")

        # Direct-bytes path: generation finished and HF streamed the asset.
        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
            return resp.content

        # Still processing: 202 (accepted), 200 with a status JSON, or
        # 503 (model cold-loading — HF returns an estimated_time payload).
        if resp.status_code in (200, 202, 503):
            if time.monotonic() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            # Non-blocking sleep: time.sleep() here would stall the whole
            # event loop for every concurrent request.
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue

        # Any other error: surface the upstream payload for debugging.
        try:
            err = resp.json()
        except Exception:
            err = {"detail": resp.text}
        raise HTTPException(status_code=502, detail=f"HF error: {err}")


async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
    """Upload the MP4 to the uploader service and return its public URL.

    Args:
        mp4_bytes: Raw MP4 payload to upload.
        client: Shared AsyncClient owned by the caller.

    Raises:
        HTTPException: 500 if UPLOAD_ACCESS_TOKEN is missing, 502 if the
            upload fails or the response contains no recognizable URL field.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    files = {
        "file": ("video.mp4", mp4_bytes, "video/mp4"),
    }
    headers = {
        "Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}",
    }
    resp = await client.post(UPLOAD_URL, headers=headers, files=files, timeout=None)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    data = resp.json()
    # Try common field names; adapt if your uploader returns a different shape.
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or data.get("data", {}).get("url")
    )
    if not url:
        # Last resort: return whole payload for debugging.
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url


# ---------- FastAPI app ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy")


@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """OpenAI-compatible endpoint.

    Takes chat messages, generates a video from the last user message,
    uploads it, and returns the link in the assistant message content.
    """
    prompt = extract_prompt(req.messages)

    # One client for both the HF call and the upload; closed automatically.
    async with httpx.AsyncClient() as client:
        mp4 = await hf_text_to_video(prompt, client)
        video_url = await upload_video_bytes(mp4, client)

    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = (
        f"✅ Video generated & uploaded.\n"
        f"**Prompt:** {prompt}\n"
        f"**URL:** {video_url}"
    )

    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )