# IMGTOVIDEO / main.py
# rkihacker's picture
# Update main.py
# 3a4123b verified
import asyncio
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ---------- Config via env ----------
# Every setting is read once, at import time, from the process environment.

# Bearer token sent to the HF router (e.g. hf_jwt_...); None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Queue submit endpoint; the default matches the curl submit endpoint.
HF_SUBMIT_URL = os.environ.get(
    "HF_SUBMIT_URL",
    "https://router.huggingface.co/fal-ai/fal-ai/wan/v2.2-a14b/text-to-video?_subdomain=queue",
)

# Upload target for the finished video and its bearer token (Snapzion).
UPLOAD_URL = os.environ.get(
    "UPLOAD_URL",
    "https://upload.snapzion.com/api/public-upload",
)
UPLOAD_ACCESS_TOKEN = os.environ.get("UPLOAD_ACCESS_TOKEN")

# Polling/backoff: seconds between status checks, and the overall budget
# in seconds (default 900 = 15 minutes).
POLL_INTERVAL_SEC = float(os.environ.get("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.environ.get("POLL_TIMEOUT_SEC", "900"))
# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    """One message of an incoming chat conversation."""

    # Speaker of the message; restricted to these four role names.
    role: Literal["system", "user", "assistant", "tool"]
    # Plain-text body; only a string is accepted by this schema.
    content: str
class ChatCompletionsRequest(BaseModel):
    """
    Request body for POST /v1/chat/completions.

    The handler in this file reads only `model` and `messages`; the other
    fields are accepted and validated but never read by it.
    """

    # Required fields.
    model: str
    messages: List[ChatMessage]

    # Sampling / generation options.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None

    # Tool-calling options, accepted without any shape validation.
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None
class ChatChoiceMessage(BaseModel):
    """Message carried inside a response choice; the role is always "assistant"."""

    role: Literal["assistant"]
    # Text returned to the client.
    content: str
class ChatChoice(BaseModel):
    """A single entry of the response's `choices` list."""

    # Position of this choice within `choices`.
    index: int
    message: ChatChoiceMessage
    # Defaults to "stop" when the caller does not supply a reason.
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
class ChatCompletionsResponse(BaseModel):
    """Response body returned by POST /v1/chat/completions."""

    id: str
    # Fixed discriminator string; the only accepted value is the default.
    object: Literal["chat.completion"] = "chat.completion"
    # Creation time as an integer (the handler passes int(time.time())).
    created: int
    model: str
    choices: List[ChatChoice]
    # Token accounting; all zeros unless the caller provides real numbers.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """
    Return the text of the most recent user message that is not blank.

    Messages are scanned from last to first; the first message with role
    "user" whose content contains non-whitespace characters is returned with
    surrounding whitespace stripped.

    Raises:
        HTTPException(400): no user message carries non-whitespace content
            (including the case where there are no user messages at all).
    """
    for message in reversed(messages):
        if message.role != "user":
            continue
        text = (message.content or "").strip()
        if text:
            return text
    # Whitespace-only user messages are rejected here as well: joining and
    # stripping them would only yield "", and an empty prompt must not be
    # submitted as a generation job.
    raise HTTPException(status_code=400, detail="No user prompt provided.")
async def hf_queue_submit(prompt: str) -> Dict[str, str]:
    """
    Submit a job to the HF router queue endpoint.

    POSTs {"prompt": prompt} as JSON to HF_SUBMIT_URL with the HF bearer
    token. The status and response URLs are read from the top level of the
    reply or, failing that, from a nested "urls" object.

    Returns:
        A dict with exactly two keys, "status_url" and "response_url".

    Raises:
        HTTPException(500): HF_TOKEN is not set.
        HTTPException(502): HF answered with HTTP >= 400, the body is not a
            JSON object, or either URL is missing from it.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "*/*",
    }
    payload = {"prompt": prompt}
    # NOTE(review): timeout=None disables every httpx timeout, so a stalled
    # connection waits indefinitely - confirm this is intended.
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(HF_SUBMIT_URL, headers=headers, json=payload)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF submit error: {resp.text}")

    # A non-JSON or non-object body is an upstream fault: report it as 502
    # instead of letting the decode error surface as an unhandled 500.
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Unexpected HF submit response: {resp.text}",
        ) from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")

    # "urls" may be absent, null, or not an object; treat all of these as empty
    # so the missing-URL check below produces the error.
    nested = data.get("urls")
    if not isinstance(nested, dict):
        nested = {}
    status_url = data.get("status_url") or nested.get("status_url")
    response_url = data.get("response_url") or nested.get("response_url")
    if not status_url or not response_url:
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")
    return {"status_url": status_url, "response_url": response_url}
async def hf_queue_wait(status_url: str) -> None:
    """
    Poll the HF queue status_url until the job finishes.

    Returns normally once the reported "status" is COMPLETED or SUCCEEDED.
    Any other status (or a missing one) is treated as still running and is
    polled again after POLL_INTERVAL_SEC seconds.

    Raises:
        HTTPException(500): HF_TOKEN is not set.
        HTTPException(502): the status endpoint answered with HTTP >= 400, or
            the job reported FAILED, ERROR, CANCELLED or CANCELED.
        HTTPException(504): the job has not finished after more than
            POLL_TIMEOUT_SEC seconds of polling.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
    # Monotonic clock: the elapsed-time check must not be skewed by
    # wall-clock adjustments while the job is running.
    started = time.monotonic()
    # NOTE(review): timeout=None means one stalled status request can outlive
    # POLL_TIMEOUT_SEC, since the budget is only checked between requests.
    async with httpx.AsyncClient(timeout=None) as client:
        while True:
            resp = await client.get(status_url, headers=headers)
            if resp.status_code >= 400:
                raise HTTPException(status_code=502, detail=f"HF status error: {resp.text}")
            data = resp.json()
            status = data.get("status")
            if status in ("COMPLETED", "SUCCEEDED"):
                return
            if status in ("FAILED", "ERROR", "CANCELLED", "CANCELED"):
                raise HTTPException(status_code=502, detail=f"HF job failed: {data}")
            if time.monotonic() - started > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="HF queue timed out.")
            # Must be an awaitable sleep: time.sleep() would block the event
            # loop and stall every other in-flight request for the interval.
            await asyncio.sleep(POLL_INTERVAL_SEC)
async def hf_queue_fetch_result(response_url: str) -> Dict[str, Any]:
    """
    Download the finished job's JSON from response_url and return it decoded.

    The caller expects the payload to look like {"video": {"url": ...}, ...}.
    Raises HTTPException 500 when HF_TOKEN is unset and 502 when HF answers
    with an HTTP status of 400 or above.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    auth_headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }
    async with httpx.AsyncClient(timeout=None) as http:
        reply = await http.get(response_url, headers=auth_headers)

    if reply.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF result error: {reply.text}")
    return reply.json()
async def download_video(url: str) -> bytes:
    """
    GET `url` (no auth headers are sent) and return the full response body.

    Raises HTTPException 502 when the server answers with HTTP >= 400.
    """
    async with httpx.AsyncClient(timeout=None) as http:
        reply = await http.get(url)

    if reply.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Download failed: {reply.text}")
    return reply.content
async def upload_video_bytes(mp4_bytes: bytes) -> str:
    """
    Upload the video to UPLOAD_URL and return the URL the service reports.

    The bytes are sent as multipart form field "file" with filename
    "video.mp4" and content type "video/mp4", authenticated with
    UPLOAD_ACCESS_TOKEN. The resulting URL is taken from the first non-empty
    of "url", "fileUrl", "publicUrl" or nested "data.url" in the JSON reply.

    Raises:
        HTTPException(500): UPLOAD_ACCESS_TOKEN is not set.
        HTTPException(502): the upload answered with HTTP >= 400, the body is
            not a JSON object, or no URL could be found in it.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(UPLOAD_URL, headers=headers, files=files)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    # A non-JSON or non-object body is an upstream fault: report it as 502
    # instead of letting the decode error surface as an unhandled 500.
    try:
        data = resp.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Upload response missing URL: {resp.text}",
        ) from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")

    # "data" may be absent, null, or not an object; treat all of these as empty
    # so the missing-URL check below produces the error.
    nested = data.get("data")
    if not isinstance(nested, dict):
        nested = {}
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or nested.get("url")
    )
    if not url:
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url
# ---------- FastAPI ----------
# Application object; the route handlers below register themselves on it
# through the @app.get / @app.post decorators.
app = FastAPI(title="OpenAI-Compatible T2V Proxy (HF Router Queue)")
@app.get("/health")
async def health():
    """Report a static "ok" status together with the configured submit URL."""
    report = {
        "status": "ok",
        "submit_url": HF_SUBMIT_URL,
    }
    return report
@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    Turn the latest user message into a video and reply in OpenAI chat shape.

    Pipeline:
      - take the prompt from the request messages
      - submit it to the HF router queue (Bearer HF_TOKEN)
      - poll the returned status_url until the job completes
      - fetch response_url and read video.url from it
      - download the MP4 and re-upload it to Snapzion
      - return the uploaded URL as a single assistant message

    Only req.messages and req.model are read; the remaining request fields
    (stream, n, temperature, ...) have no effect here.
    """
    user_prompt = extract_prompt(req.messages)

    # Queue the job, then block this request until it has finished.
    job = await hf_queue_submit(user_prompt)
    await hf_queue_wait(job["status_url"])

    # The finished job's JSON is expected to carry the video location.
    result_json = await hf_queue_fetch_result(job["response_url"])
    video_info = result_json.get("video") or {}
    source_url = video_info.get("url")
    if not source_url:
        raise HTTPException(status_code=502, detail=f"HF result missing video.url: {result_json}")

    # Copy the file from the HF-provided URL to the upload service.
    video_bytes = await download_video(source_url)
    hosted_url = await upload_video_bytes(video_bytes)

    # Wrap the hosted URL in an OpenAI-style chat completion.
    message_text = f"✅ Video generated & uploaded.\n**Prompt:** {user_prompt}\n**URL:** {hosted_url}"
    only_choice = ChatChoice(
        index=0,
        message=ChatChoiceMessage(role="assistant", content=message_text),
        finish_reason="stop",
    )
    return ChatCompletionsResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=req.model,
        choices=[only_choice],
    )