"""OpenAI-compatible chat-completions proxy that turns the last user message
into a text-to-video prompt (Wan T2V via provider='fal-ai' on Hugging Face),
uploads the resulting MP4 to a public uploader, and returns the URL inside
the assistant message.
"""

import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Union

import httpx
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field

# ---------------- Config (env) ----------------
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token (works for provider=fal-ai)
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # your bearer token

# Optional tuning
GEN_TIMEOUT_SEC = int(os.getenv("GEN_TIMEOUT_SEC", "900"))  # 15 min generation ceiling


# ---------------- OpenAI-compatible schemas ----------------
class ChatMessage(BaseModel):
    """One message in an OpenAI-style conversation."""

    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    """Subset of the OpenAI /v1/chat/completions request body.

    Sampling/tool fields are accepted for client compatibility but ignored —
    this service only extracts a prompt from ``messages``.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    """Assistant message inside a completion choice."""

    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    """One completion choice in the OpenAI response shape."""

    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    """OpenAI-compatible /v1/chat/completions response envelope."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    # default_factory instead of a dict literal so the default is not a single
    # mutable object referenced by every instance. Token counts are zero
    # because no real tokenization happens here.
    usage: Optional[Dict[str, int]] = Field(
        default_factory=lambda: {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )


# ---------------- Helpers ----------------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """Use the last user message as the prompt.

    Fallback to joining all user messages.

    Raises:
        HTTPException: 400 when no non-empty user prompt can be derived.
    """
    # Prefer the most recent non-blank user message.
    for m in reversed(messages):
        if m.role == "user" and m.content and m.content.strip():
            return m.content.strip()
    # Fallback: concatenate every user message that has any content.
    user_texts = [m.content for m in messages if m.role == "user" and m.content]
    if not user_texts:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    joined = "\n".join(user_texts).strip()
    # All user messages were whitespace-only — an empty prompt is still "no
    # prompt", so reject it instead of returning "".
    if not joined:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return joined


async def generate_video_bytes(prompt: str) -> bytes:
    """Calls huggingface_hub.InferenceClient with provider='fal-ai' (Wan T2V)
    and returns MP4 bytes.

    Raises:
        HTTPException: 500 when HF_TOKEN is missing, 504 on timeout,
            502 on provider failure or an unexpected result shape.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)

    def _sync_generate() -> Union[bytes, Dict[str, Any]]:
        # mirrors the documented example:
        # video = client.text_to_video("prompt", model="Wan-AI/Wan2.2-T2V-A14B")
        return client.text_to_video(prompt, model=WAN_MODEL)

    try:
        # The HF client is synchronous; run it in the default thread pool so
        # the event loop stays responsive. get_running_loop() is the correct
        # (non-deprecated) accessor inside a coroutine.
        result = await asyncio.wait_for(
            asyncio.get_running_loop().run_in_executor(None, _sync_generate),
            timeout=GEN_TIMEOUT_SEC,
        )
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Video generation timed out.") from None
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Video generation failed: {e}") from e

    # fal-ai provider typically returns a dict with "video": bytes; sometimes raw bytes
    if isinstance(result, (bytes, bytearray)):
        return bytes(result)
    if isinstance(result, dict):
        # common keys: "video" (bytes), "seed", etc.
        vid = result.get("video")
        if isinstance(vid, (bytes, bytearray)):
            return bytes(vid)
    raise HTTPException(status_code=502, detail=f"Unexpected generation result: {type(result)}")


async def upload_video_bytes(mp4_bytes: bytes) -> str:
    """Uploads MP4 to Snapzion uploader and returns public URL.

    Raises:
        HTTPException: 500 when UPLOAD_ACCESS_TOKEN is missing, 502 on an
            upload error, a non-JSON response, or a response missing a URL.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}

    # timeout=None: uploads of large MP4s can legitimately take a long time.
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(UPLOAD_URL, headers=headers, files=files)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    try:
        data = resp.json()
    except ValueError as e:
        # A 2xx with a non-JSON body would otherwise escape as a raw 500.
        raise HTTPException(status_code=502, detail=f"Upload response was not JSON: {resp.text}") from e

    # Try common URL fields (adjust if your API returns a different shape).
    # Guard the nested lookup: "data" may exist without being a dict.
    nested = data.get("data")
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or (nested.get("url") if isinstance(nested, dict) else None)
    )
    if not url:
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url


# ---------------- FastAPI app ----------------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (FAL via HF)")


@app.get("/health")
async def health():
    """Liveness probe; also reports the configured model id."""
    return {"status": "ok", "model": WAN_MODEL}


@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    OpenAI-compatible endpoint:
    - reads last user message as the T2V prompt
    - generates a video with Wan-AI/Wan2.2-T2V-A14B via provider='fal-ai'
    - uploads to your uploader
    - returns the public URL inside the assistant message
    """
    prompt = extract_prompt(req.messages)
    mp4 = await generate_video_bytes(prompt)
    video_url = await upload_video_bytes(mp4)

    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {video_url}"

    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )