IMGTOVIDEO / main.py
import asyncio
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict
import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ---------- Config (env) ----------
HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face API token
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN") # Bearer token for your uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
HF_ENDPOINT = os.getenv(
"HF_ENDPOINT",
f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)
# Polling settings for HF async generation
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600")) # 10 minutes max
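# Example environment (illustrative placeholder values only; HF_TOKEN and
# UPLOAD_ACCESS_TOKEN must be real secrets in your deployment):
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   export UPLOAD_ACCESS_TOKEN=xxxxxxxx
#   export WAN_MODEL=Wan-AI/Wan2.2-T2V-A14B
#   export POLL_INTERVAL_SEC=3
#   export POLL_TIMEOUT_SEC=600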
# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant", "tool"]
content: str
class ChatCompletionsRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = None
max_tokens: Optional[int] = None
stream: Optional[bool] = False
# we accept arbitrary extras but ignore them
n: Optional[int] = 1
top_p: Optional[float] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
tools: Optional[Any] = None
tool_choice: Optional[Any] = None
class ChatChoiceMessage(BaseModel):
role: Literal["assistant"]
content: str
class ChatChoice(BaseModel):
index: int
message: ChatChoiceMessage
finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
class ChatCompletionsResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
choices: List[ChatChoice]
usage: Optional[Dict[str, int]] = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
}
# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
"""
Basic heuristic: use the content of the last user message as the video prompt.
If none found, join all user contents.
"""
for msg in reversed(messages):
if msg.role == "user" and msg.content.strip():
return msg.content.strip()
# fallback
user_texts = [m.content for m in messages if m.role == "user"]
if not user_texts:
raise HTTPException(status_code=400, detail="No user prompt provided.")
return "\n".join(user_texts).strip()
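# Illustrative example (hypothetical messages, not from a real request):
#   extract_prompt([ChatMessage(role="system", content="You generate videos."),
#                   ChatMessage(role="user", content="a red fox running through snow")])
#   -> "a red fox running through snow"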
async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
"""
Calls Hugging Face Inference API for text-to-video and returns raw MP4 bytes.
Some T2V models run asynchronously; we poll until the asset is ready.
"""
if not HF_TOKEN:
raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
headers = {
"Authorization": f"Bearer {HF_TOKEN}",
"Accept": "*/*",
}
# Kick off generation (HF will 200 with bytes OR 202 with a status JSON)
start = time.time()
while True:
resp = await client.post(HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None)
ct = resp.headers.get("content-type", "")
# Direct bytes path
if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
return resp.content
        # 200 with a non-video body or 202 -> generation still in progress; poll again
        if resp.status_code in (200, 202):
            elapsed = time.time() - start
            if elapsed > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            # non-blocking wait, then retry on the same client
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue
# Any other error
try:
err = resp.json()
except Exception:
err = {"detail": resp.text}
raise HTTPException(status_code=502, detail=f"HF error: {err}")
async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
"""
Uploads the MP4 to your uploader service and returns the public URL.
"""
if not UPLOAD_ACCESS_TOKEN:
raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
files = {
"file": ("video.mp4", mp4_bytes, "video/mp4"),
}
headers = {
"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}",
}
resp = await client.post(UPLOAD_URL, headers=headers, files=files, timeout=None)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
data = resp.json()
# Try common field names; adapt if your uploader returns a different shape
url = (
data.get("url")
or data.get("fileUrl")
or data.get("publicUrl")
or data.get("data", {}).get("url")
)
if not url:
        # last resort: surface the whole payload in the error for debugging
raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
return url
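# Example upload responses this helper can parse (shapes are assumptions based on
# the field names probed above; adapt to your uploader's actual payload):
#   {"url": "https://cdn.example.com/abc/video.mp4"}
#   {"data": {"url": "https://cdn.example.com/abc/video.mp4"}}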
# ---------- FastAPI app ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy")
@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
"""
OpenAI-compatible endpoint:
- takes chat messages
- generates a video from the last user message
- uploads it
- returns the link in assistant message content
"""
prompt = extract_prompt(req.messages)
async with httpx.AsyncClient() as client:
mp4 = await hf_text_to_video(prompt, client)
video_url = await upload_video_bytes(mp4, client)
now = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
content = (
f"✅ Video generated & uploaded.\n"
f"**Prompt:** {prompt}\n"
f"**URL:** {video_url}"
)
return ChatCompletionsResponse(
id=completion_id,
created=now,
model=req.model,
choices=[
ChatChoice(
index=0,
message=ChatChoiceMessage(role="assistant", content=content),
finish_reason="stop",
)
],
)
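# Example request (a sketch assuming the app is served locally with
# `uvicorn main:app --port 8000`; the model name and prompt are placeholders):
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "wan-t2v", "messages": [{"role": "user", "content": "a red fox running through snow"}]}'
# The assistant message in the JSON response contains the prompt and the uploaded video URL.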