File size: 6,056 Bytes
c7e3d6d
 
 
534be3f
c7e3d6d
 
 
 
534be3f
 
c7e3d6d
534be3f
 
 
c7e3d6d
534be3f
 
 
 
 
 
c7e3d6d
 
534be3f
c7e3d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534be3f
c7e3d6d
534be3f
 
 
 
 
c7e3d6d
 
 
 
 
534be3f
 
c7e3d6d
 
534be3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7e3d6d
 
534be3f
 
 
 
 
c7e3d6d
 
 
 
 
534be3f
c7e3d6d
 
 
 
 
 
 
 
 
 
 
534be3f
 
 
 
 
 
 
 
c7e3d6d
 
 
 
 
534be3f
 
 
 
c7e3d6d
 
534be3f
 
c7e3d6d
 
 
534be3f
c7e3d6d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import time
import uuid
from typing import List, Optional, Literal, Any, Dict, Union

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import asyncio


# ---------------- Config (env) ----------------
# Every setting below is read once, at import time, from the process environment.

# Hugging Face token handed to InferenceClient (provider="fal-ai").
HF_TOKEN = os.environ.get("HF_TOKEN")
# Model id passed to InferenceClient.text_to_video.
WAN_MODEL = os.environ.get("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")

# Upload endpoint and the bearer token sent in the Authorization header.
UPLOAD_URL = os.environ.get("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.environ.get("UPLOAD_ACCESS_TOKEN")

# Optional tuning: seconds to wait for a generation before answering 504 (default 900 = 15 min).
GEN_TIMEOUT_SEC = int(os.environ.get("GEN_TIMEOUT_SEC", "900"))


# ---------------- OpenAI-compatible schemas ----------------
class ChatMessage(BaseModel):
    """One message of an incoming OpenAI-style chat request."""

    # Only these four roles validate; anything else is rejected by pydantic.
    role: Literal["system", "user", "assistant", "tool"]
    # Plain text only: null or list (multi-part) content fails validation.
    # NOTE(review): OpenAI clients may send those shapes (e.g. multimodal input,
    # assistant tool-call turns) — confirm whether they need to be accepted.
    content: str


class ChatCompletionsRequest(BaseModel):
    """Request body for POST /v1/chat/completions, shaped like OpenAI's API.

    chat_completions reads only `model` and `messages`; every other field is
    parsed so standard OpenAI clients validate, but it has no effect.
    """

    model: str
    messages: List[ChatMessage]
    # --- accepted for compatibility, not used by chat_completions ---
    temperature: Union[float, None] = None
    max_tokens: Union[int, None] = None
    stream: Union[bool, None] = False
    n: Union[int, None] = 1
    top_p: Union[float, None] = None
    presence_penalty: Union[float, None] = None
    frequency_penalty: Union[float, None] = None
    tools: Union[Any, None] = None
    tool_choice: Union[Any, None] = None


class ChatChoiceMessage(BaseModel):
    """The assistant message carried inside a completion choice."""

    # Responses only ever come from the assistant.
    role: Literal["assistant"]
    # Text shown to the client (here: a status line plus the uploaded video URL).
    content: str


class ChatChoice(BaseModel):
    """One entry of the `choices` list in a chat.completion response."""

    # Position of this choice within `choices`.
    index: int
    message: ChatChoiceMessage
    # Restricted to OpenAI's finish-reason values; defaults to a normal stop.
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"


class ChatCompletionsResponse(BaseModel):
    """OpenAI `chat.completion` response envelope returned by the proxy."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    # Creation time as integer seconds (chat_completions passes int(time.time())).
    created: int
    model: str
    choices: List[ChatChoice]
    # chat_completions never sets usage, so clients always see all-zero counts.
    usage: Union[Dict[str, int], None] = dict(
        prompt_tokens=0,
        completion_tokens=0,
        total_tokens=0,
    )


# ---------------- Helpers ----------------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """Return the text of the most recent user message that is not blank.

    Args:
        messages: Conversation messages in request order.

    Returns:
        That message's content with surrounding whitespace stripped.

    Raises:
        HTTPException: 400 if no user message contains non-whitespace text.
    """
    for message in reversed(messages):
        if message.role != "user":
            continue
        # `or ""` keeps this safe even if content is ever None.
        text = (message.content or "").strip()
        if text:
            return text
    # Every user message (if any) was empty or whitespace-only. Reject instead of
    # forwarding an empty prompt to the video generator.
    raise HTTPException(status_code=400, detail="No user prompt provided.")


async def generate_video_bytes(prompt: str) -> bytes:
    """Render *prompt* to MP4 bytes with WAN_MODEL via InferenceClient(provider="fal-ai").

    The blocking SDK call runs in the event loop's default thread-pool executor,
    so the event loop is not blocked while the video renders.

    Args:
        prompt: Text-to-video prompt.

    Returns:
        The generated video as raw bytes.

    Raises:
        HTTPException: 500 if HF_TOKEN is unset; 504 if generation takes longer
            than GEN_TIMEOUT_SEC; 502 if the SDK call raises or returns a payload
            without video bytes.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)

    def _sync_generate() -> Union[bytes, Dict[str, Any]]:
        # Equivalent to: client.text_to_video("prompt", model="Wan-AI/Wan2.2-T2V-A14B")
        return client.text_to_video(prompt, model=WAN_MODEL)

    # The asyncio docs prefer get_running_loop() over get_event_loop() inside coroutines.
    loop = asyncio.get_running_loop()
    try:
        result = await asyncio.wait_for(
            loop.run_in_executor(None, _sync_generate),
            timeout=GEN_TIMEOUT_SEC,
        )
    except asyncio.TimeoutError as e:
        # wait_for stops waiting, but the executor thread cannot be interrupted:
        # the SDK call may keep running in the background after this 504.
        raise HTTPException(status_code=504, detail="Video generation timed out.") from e
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Video generation failed: {e}") from e

    # Accept raw bytes, or a dict carrying bytes under "video".
    # NOTE(review): huggingface_hub types text_to_video as returning bytes; the dict
    # branch is kept defensively — confirm whether any provider still returns it.
    if isinstance(result, (bytes, bytearray)):
        return bytes(result)

    if isinstance(result, dict):
        vid = result.get("video")
        if isinstance(vid, (bytes, bytearray)):
            return bytes(vid)

    raise HTTPException(status_code=502, detail=f"Unexpected generation result: {type(result)}")


def extract_upload_url(data: Any) -> Optional[str]:
    """Return the public URL from an uploader JSON payload, or None if absent.

    Looks at top-level "url", "fileUrl", "publicUrl", then nested "data"->"url",
    and returns the first truthy value. Non-dict payloads, or a non-dict "data"
    entry, yield None instead of raising.
    """
    if not isinstance(data, dict):
        return None
    nested = data.get("data")
    candidates = (
        data.get("url"),
        data.get("fileUrl"),
        data.get("publicUrl"),
        nested.get("url") if isinstance(nested, dict) else None,
    )
    return next((candidate for candidate in candidates if candidate), None)


async def upload_video_bytes(mp4_bytes: bytes) -> str:
    """Upload an MP4 to the Snapzion uploader and return its public URL.

    Args:
        mp4_bytes: Encoded video to upload as multipart field "file".

    Returns:
        The public URL reported by the uploader.

    Raises:
        HTTPException: 500 if UPLOAD_ACCESS_TOKEN is unset. 502 if the request
            cannot be sent, the uploader answers with status >= 400, the body is
            not JSON, or no URL field is found in it.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}

    try:
        # NOTE(review): timeout=None means a stalled uploader blocks this request
        # indefinitely; consider a bounded httpx.Timeout once typical sizes are known.
        async with httpx.AsyncClient(timeout=None) as client:
            resp = await client.post(UPLOAD_URL, headers=headers, files=files)
    except httpx.HTTPError as e:
        # Connection/transport failures are an upstream problem, not a server bug.
        raise HTTPException(status_code=502, detail=f"Upload request failed: {e}") from e

    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    try:
        data = resp.json()
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(
            status_code=502, detail=f"Upload response is not valid JSON: {resp.text}"
        ) from e

    url = extract_upload_url(data)
    if not url:
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
    return url


# ---------------- FastAPI app ----------------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (FAL via HF)")


@app.get("/health")
async def health():
    """Liveness probe: report that the service is up and which model it targets."""
    return dict(status="ok", model=WAN_MODEL)


@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    OpenAI-compatible endpoint that answers with a link to a generated video:
      - picks the prompt from the request messages via extract_prompt
      - renders it with the configured WAN_MODEL through provider='fal-ai'
      - uploads the MP4 to the configured uploader
      - replies with one assistant message containing the public URL
    Only `model` and `messages` are read from the request.
    """
    prompt = extract_prompt(req.messages)
    video_url = await upload_video_bytes(await generate_video_bytes(prompt))

    reply = ChatChoiceMessage(
        role="assistant",
        content=f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {video_url}",
    )
    only_choice = ChatChoice(index=0, message=reply, finish_reason="stop")

    return ChatCompletionsResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=req.model,
        choices=[only_choice],
    )