import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Union

import httpx
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel

# ---------------- Config (env) ----------------
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token (works for provider="fal-ai")
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # your bearer token

# Optional tuning
GEN_TIMEOUT_SEC = int(os.getenv("GEN_TIMEOUT_SEC", "900"))  # 15 min generation ceiling
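
# Example environment for local testing (illustrative values only, not real tokens):
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   export UPLOAD_ACCESS_TOKEN=<your-snapzion-bearer-token>
#   export WAN_MODEL=Wan-AI/Wan2.2-T2V-A14B
#   export GEN_TIMEOUT_SEC=900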

# ---------------- OpenAI-compatible schemas ----------------
class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
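
# Illustrative payload produced by ChatCompletionsResponse (example values only):
# {
#   "id": "chatcmpl-<hex>",
#   "object": "chat.completion",
#   "created": 1700000000,
#   "model": "wan-t2v",
#   "choices": [
#     {"index": 0,
#      "message": {"role": "assistant", "content": "✅ Video generated & uploaded. ..."},
#      "finish_reason": "stop"}
#   ],
#   "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
# }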

# ---------------- Helpers ----------------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """Use the last user message as the prompt; fall back to joining all user messages."""
    for m in reversed(messages):
        if m.role == "user" and m.content and m.content.strip():
            return m.content.strip()
    user_texts = [m.content for m in messages if m.role == "user" and m.content]
    if not user_texts:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return "\n".join(user_texts).strip()
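
# Example (illustrative): given messages
#   [system "...", user "a cat", assistant "...", user "a dog surfing"]
# extract_prompt returns "a dog surfing" (the last non-empty user message).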

async def generate_video_bytes(prompt: str) -> bytes:
    """Calls huggingface_hub.InferenceClient with provider='fal-ai' (Wan T2V) and returns MP4 bytes."""
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)

    def _sync_generate() -> Union[bytes, Dict[str, Any]]:
        # Mirrors the plain-Python example:
        #   video = client.text_to_video("prompt", model="Wan-AI/Wan2.2-T2V-A14B")
        return client.text_to_video(prompt, model=WAN_MODEL)

    try:
        result = await asyncio.wait_for(
            asyncio.get_running_loop().run_in_executor(None, _sync_generate),
            timeout=GEN_TIMEOUT_SEC,
        )
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Video generation timed out.")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Video generation failed: {e}")

    # The fal-ai provider typically returns a dict with "video": bytes; sometimes raw bytes.
    if isinstance(result, (bytes, bytearray)):
        return bytes(result)
    if isinstance(result, dict):
        # Common keys: "video" (bytes), "seed", etc.
        vid = result.get("video")
        if isinstance(vid, (bytes, bytearray)):
            return bytes(vid)
    raise HTTPException(status_code=502, detail=f"Unexpected generation result: {type(result)}")

async def upload_video_bytes(mp4_bytes: bytes) -> str:
    """Uploads MP4 to Snapzion uploader and returns public URL."""
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}

    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(UPLOAD_URL, headers=headers, files=files)

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    data = resp.json()
    # Try common URL fields (adjust if your API returns a different shape).
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or data.get("data", {}).get("url")
    )
    if not url:
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url
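
# Illustrative upload responses the parser above accepts (shapes are assumptions):
#   {"url": "https://cdn.example.com/video.mp4"}
#   {"data": {"url": "https://cdn.example.com/video.mp4"}}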

# ---------------- FastAPI app ----------------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (FAL via HF)")


@app.get("/health")
async def health():
    return {"status": "ok", "model": WAN_MODEL}

@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionsRequest) -> ChatCompletionsResponse:
    """
    OpenAI-compatible endpoint:
      - reads the last user message as the T2V prompt
      - generates a video with Wan-AI/Wan2.2-T2V-A14B via provider='fal-ai'
      - uploads it to your uploader
      - returns the public URL inside the assistant message
    """
    prompt = extract_prompt(req.messages)
    mp4 = await generate_video_bytes(prompt)
    video_url = await upload_video_bytes(mp4)

    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {video_url}"

    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )
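
# Minimal local-run sketch (assumptions: uvicorn is installed and port 7860, the
# usual HF Spaces port, is free; on Spaces the platform launches the app itself).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Illustrative request (example values only; the model field is echoed back, not validated):
#   curl -X POST http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "wan-t2v", "messages": [{"role": "user", "content": "a red fox running through snow"}]}'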