File size: 6,511 Bytes
c7e3d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import asyncio
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- Config (env) ----------
# All runtime configuration comes from environment variables so the service
# can be deployed without code changes.
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token (required at request time, not import time)
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for your uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")  # default text-to-video model id
# Full inference endpoint; overridable to point at a dedicated HF endpoint.
HF_ENDPOINT = os.getenv(
    "HF_ENDPOINT",
    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)
# Polling settings for HF async generation
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max


# ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    """A single OpenAI-style chat message: a role plus plain-text content."""

    role: Literal["system", "user", "assistant", "tool"]
    content: str


class ChatCompletionsRequest(BaseModel):
    """Request body mirroring OpenAI's /v1/chat/completions schema.

    Only ``model`` and ``messages`` are used by this service; every other
    field is accepted purely for wire compatibility with OpenAI clients.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False  # streaming is not implemented; the flag is ignored
    # we accept arbitrary extras but ignore them
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None


class ChatChoiceMessage(BaseModel):
    """Assistant reply payload inside a completion choice."""

    role: Literal["assistant"]
    content: str


class ChatChoice(BaseModel):
    """One completion choice, OpenAI-shaped (this service always emits exactly one)."""

    index: int
    message: ChatChoiceMessage
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    """OpenAI-compatible chat.completion envelope.

    Token usage is not tracked (video generation has no token accounting),
    so ``usage`` defaults to all zeros; pydantic copies field defaults per
    instance, so the shared dict literal is safe here.
    """

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }


# ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """
    Derive the text-to-video prompt from an OpenAI-style message list.

    Heuristic: use the content of the last non-blank user message. If none
    qualifies, fall back to joining all user contents.

    Raises:
        HTTPException: 400 when no user message exists, or when every user
            message is blank — the original fallback could return an empty
            string in that case, which would be a useless prompt.
    """
    for msg in reversed(messages):
        if msg.role == "user" and msg.content.strip():
            return msg.content.strip()
    # Fallback: concatenate every user message.
    user_texts = [m.content for m in messages if m.role == "user"]
    joined = "\n".join(user_texts).strip()
    if not joined:
        # Covers both "no user messages at all" and "only whitespace content".
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return joined


async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
    """
    Call the Hugging Face Inference API for text-to-video and return raw MP4 bytes.

    Some T2V models run asynchronously: the endpoint may answer 202 (or 200
    with a non-binary status body) while the job is still running, so we poll
    until the binary asset arrives or POLL_TIMEOUT_SEC elapses.

    Args:
        prompt: Text prompt to render into a video.
        client: Shared AsyncClient, reused across polls. The caller owns its
            lifecycle; this function never closes it.

    Returns:
        Raw video bytes from the upstream response.

    Raises:
        HTTPException: 500 when HF_TOKEN is unset, 504 on poll timeout,
            502 for any other upstream error.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }

    start = time.time()
    while True:
        resp = await client.post(
            HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None
        )
        ct = resp.headers.get("content-type", "")

        # Success path: the model returned the finished asset directly.
        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
            return resp.content

        # Still processing: 202, or 200 with a non-binary status payload.
        if resp.status_code in (200, 202):
            if time.time() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            # Fixes three defects in the original polling path:
            #   - time.sleep() blocked the whole event loop -> asyncio.sleep();
            #   - client.aclose() closed the CALLER's client mid-request;
            #   - a fresh httpx.AsyncClient() was created per iteration and
            #     never closed, leaking connections. The shared client is
            #     simply reused instead.
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue

        # Any other status: surface the upstream error body.
        try:
            err = resp.json()
        except Exception:
            err = {"detail": resp.text}
        raise HTTPException(status_code=502, detail=f"HF error: {err}")

async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
    """
    Push the MP4 to the configured uploader service and return its public URL.

    Args:
        mp4_bytes: Raw video payload to upload.
        client: Shared AsyncClient used for the multipart POST.

    Raises:
        HTTPException: 500 when UPLOAD_ACCESS_TOKEN is unset, 502 when the
            upload fails or the response carries no recognizable URL field.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    resp = await client.post(
        UPLOAD_URL,
        headers={"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"},
        files={"file": ("video.mp4", mp4_bytes, "video/mp4")},
        timeout=None,
    )
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    payload = resp.json()
    # The uploader's response shape is not pinned down; probe the usual keys
    # in order, short-circuiting on the first truthy hit.
    url = (
        payload.get("url")
        or payload.get("fileUrl")
        or payload.get("publicUrl")
        or payload.get("data", {}).get("url")
    )
    if not url:
        # last resort: return whole payload for debugging
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {payload}")

    return url


# ---------- FastAPI app ----------
# Single-purpose ASGI app exposing one OpenAI-compatible route below.
app = FastAPI(title="OpenAI-Compatible T2V Proxy")

@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    OpenAI-compatible chat endpoint.

    Extracts a prompt from the chat messages, renders a video from it via
    Hugging Face, uploads the result, and answers with the public link
    packaged as a single assistant message. Streaming and multi-choice
    request options are accepted but not honored.
    """
    prompt = extract_prompt(req.messages)

    async with httpx.AsyncClient() as http:
        mp4_bytes = await hf_text_to_video(prompt, http)
        video_url = await upload_video_bytes(mp4_bytes, http)

    body = (
        f"✅ Video generated & uploaded.\n"
        f"**Prompt:** {prompt}\n"
        f"**URL:** {video_url}"
    )

    reply = ChatChoiceMessage(role="assistant", content=body)
    return ChatCompletionsResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=req.model,
        choices=[
            ChatChoice(index=0, message=reply, finish_reason="stop"),
        ],
    )