import asyncio
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config (env) ----------
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for your uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
HF_ENDPOINT = os.getenv(
    "HF_ENDPOINT",
    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)

# Polling settings for HF async generation
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max
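
# Illustrative environment setup (placeholder values, not real credentials):
#   HF_TOKEN=hf_...                  # required
#   UPLOAD_ACCESS_TOKEN=...          # required
#   WAN_MODEL=Wan-AI/Wan2.2-T2V-A14B
#   POLL_INTERVAL_SEC=3
#   POLL_TIMEOUT_SEC=600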


# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # we accept arbitrary extras but ignore them
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
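
# Example request body these schemas accept (mirrors the OpenAI chat format;
# the model string is echoed back in the response, not used for routing):
# {
#   "model": "wan2.2-t2v",
#   "messages": [
#     {"role": "user", "content": "a red fox running through fresh snow"}
#   ]
# }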


# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """
    Basic heuristic: use the content of the last user message as the video prompt.
    If none found, join all user contents.
    """
    for msg in reversed(messages):
        if msg.role == "user" and msg.content.strip():
            return msg.content.strip()
    # fallback: concatenate every user message
    user_texts = [m.content for m in messages if m.role == "user"]
    if not user_texts:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return "\n".join(user_texts).strip()


async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
    """
    Calls the Hugging Face Inference API for text-to-video and returns raw MP4 bytes.
    Some T2V models run asynchronously; we poll by re-POSTing until the asset is ready.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }
    # Kick off generation: HF answers 200 with raw bytes when done, 202 while the
    # job is still processing, or 503 while the model is loading.
    start = time.time()
    while True:
        resp = await client.post(
            HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None
        )
        ct = resp.headers.get("content-type", "")
        # Direct bytes path
        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
            return resp.content
        # Still processing / model loading: back off and retry on the same client.
        # asyncio.sleep (not time.sleep) keeps the event loop responsive.
        if resp.status_code in (202, 503):
            if time.time() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue
        # Anything else (including a 200 with a JSON error body) is a failure.
        try:
            err = resp.json()
        except Exception:
            err = {"detail": resp.text}
        raise HTTPException(status_code=502, detail=f"HF error: {err}")
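
# Example usage (inside an async context; the prompt is illustrative):
#   async with httpx.AsyncClient() as client:
#       mp4_bytes = await hf_text_to_video("a paper boat drifting down a stream", client)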


async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
    """
    Uploads the MP4 to your uploader service and returns the public URL.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
    files = {
        "file": ("video.mp4", mp4_bytes, "video/mp4"),
    }
    headers = {
        "Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}",
    }
    resp = await client.post(UPLOAD_URL, headers=headers, files=files, timeout=None)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
    data = resp.json()
    # Try common field names; adapt if your uploader returns a different shape
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or data.get("data", {}).get("url")
    )
    if not url:
        # last resort: surface the whole payload in the error for debugging
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url
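
# The uploader is assumed to respond with JSON along one of these lines; adjust
# the field probing above if yours differs:
#   {"url": "https://cdn.example.com/videos/<id>.mp4"}
#   {"data": {"url": "https://cdn.example.com/videos/<id>.mp4"}}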


# ---------- FastAPI app ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy")

@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionsRequest):
    """
    OpenAI-compatible endpoint:
      - takes chat messages
      - generates a video from the last user message
      - uploads it
      - returns the link in the assistant message content
    """
    prompt = extract_prompt(req.messages)

    async with httpx.AsyncClient() as client:
        mp4 = await hf_text_to_video(prompt, client)
        video_url = await upload_video_bytes(mp4, client)

    now = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    content = (
        f"✅ Video generated & uploaded.\n"
        f"**Prompt:** {prompt}\n"
        f"**URL:** {video_url}"
    )
    return ChatCompletionsResponse(
        id=completion_id,
        created=now,
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=content),
                finish_reason="stop",
            )
        ],
    )
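

# Optional extras (a minimal sketch): a liveness probe plus a local entrypoint.
# Assumes uvicorn is installed; 7860 is the conventional Hugging Face Spaces port.
@app.get("/health")
async def health():
    # Lightweight probe reporting which model this proxy targets.
    return {"status": "ok", "model": WAN_MODEL}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))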