rkihacker commited on
Commit
c7e3d6d
·
verified ·
1 Parent(s): b0666d7

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +198 -0
main.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import uuid
4
+ from typing import List, Optional, Literal, Any, Dict
5
+
6
+ import httpx
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+
# ---------- Config (env) ----------
# All runtime settings come from environment variables so the service can be
# reconfigured at deploy time without code changes.
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token
UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for your uploader
WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
# Full inference URL; overriding HF_ENDPOINT takes precedence over WAN_MODEL.
HF_ENDPOINT = os.getenv(
    "HF_ENDPOINT",
    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
)
# Polling settings for HF async generation
POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max
22
+
23
+
24
+ # ---------- OpenAI-compatible schemas ----------
class ChatMessage(BaseModel):
    """A single OpenAI-style chat message."""

    # "tool" is accepted for wire compatibility; tool messages are never
    # produced by this service and are ignored by extract_prompt.
    role: Literal["system", "user", "assistant", "tool"]
    content: str
28
+
29
+
class ChatCompletionsRequest(BaseModel):
    """OpenAI `/v1/chat/completions` request body.

    Only `model` and `messages` are actually used; the remaining OpenAI
    parameters are accepted so off-the-shelf clients don't fail validation,
    but they have no effect on video generation.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False  # streaming is not implemented; flag is ignored
    # we accept arbitrary extras but ignore them
    n: Optional[int] = 1
    top_p: Optional[float] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    tools: Optional[Any] = None
    tool_choice: Optional[Any] = None
43
+
44
+
class ChatChoiceMessage(BaseModel):
    """The assistant message inside a completion choice."""

    # Only assistant replies are ever emitted by this proxy.
    role: Literal["assistant"]
    content: str
48
+
49
+
class ChatChoice(BaseModel):
    """One completion choice, mirroring OpenAI's response shape."""

    index: int
    message: ChatChoiceMessage
    # This service always finishes normally, so "stop" is the default.
    finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = "stop"
54
+
55
+
class ChatCompletionsResponse(BaseModel):
    """OpenAI `/v1/chat/completions` response body."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    # Token accounting is not meaningful for video generation, so a zeroed
    # usage stub is returned for client compatibility.
    # NOTE(review): a mutable dict default is safe here only because pydantic
    # deep-copies field defaults per instance — confirm if migrating off pydantic.
    usage: Optional[Dict[str, int]] = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
67
+
68
+
69
+ # ---------- Helpers ----------
def extract_prompt(messages: List[ChatMessage]) -> str:
    """
    Pick the video prompt from the chat history.

    Prefers the most recent non-empty user message; if there is none,
    falls back to joining every user message. Raises HTTP 400 when the
    conversation contains no user messages at all.
    """
    # Walk the history newest-first and take the first usable user message.
    for message in messages[::-1]:
        if message.role != "user":
            continue
        text = message.content.strip()
        if text:
            return text

    # Fallback: concatenate everything the user said (covers the case where
    # every individual user message is blank but the history is non-empty).
    collected = [m.content for m in messages if m.role == "user"]
    if not collected:
        raise HTTPException(status_code=400, detail="No user prompt provided.")
    return "\n".join(collected).strip()
83
+
84
+
async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
    """
    Generate a video from `prompt` via the Hugging Face Inference API and
    return the raw MP4 bytes.

    Some T2V models run asynchronously: HF answers 202 while the job is
    queued and 503 while the model is loading, so the request is re-submitted
    until bytes come back or POLL_TIMEOUT_SEC elapses. The caller-owned
    `client` is reused for every attempt and is never closed here.

    Raises:
        HTTPException(500): HF_TOKEN is not configured.
        HTTPException(504): generation did not finish within POLL_TIMEOUT_SEC.
        HTTPException(502): any other upstream error.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Accept": "*/*",
    }

    start = time.time()
    while True:
        resp = await client.post(
            HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None
        )
        ct = resp.headers.get("content-type", "")

        # Success path: HF returned the finished asset directly as bytes.
        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
            return resp.content

        # Still processing: 202 (job accepted/queued) or 503 (model loading).
        # A 200 with a non-video body is NOT treated as "processing" — the old
        # code did, which could spin forever on a 200 JSON error payload.
        if resp.status_code in (202, 503):
            if time.time() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="Video generation timed out.")
            # asyncio.sleep (not time.sleep) so the event loop stays
            # responsive while we wait between polls.
            await asyncio.sleep(POLL_INTERVAL_SEC)
            continue

        # Any other status: surface HF's JSON error payload if there is one.
        try:
            err = resp.json()
        except Exception:
            err = {"detail": resp.text}
        raise HTTPException(status_code=502, detail=f"HF error: {err}")
127
+
async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
    """
    Upload the MP4 to the configured uploader service and return its public URL.

    Raises:
        HTTPException(500): UPLOAD_ACCESS_TOKEN is not configured.
        HTTPException(502): upload rejected, non-JSON response, or response
            without a recognizable URL field.
    """
    if not UPLOAD_ACCESS_TOKEN:
        raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")

    files = {
        "file": ("video.mp4", mp4_bytes, "video/mp4"),
    }
    headers = {
        "Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}",
    }
    resp = await client.post(UPLOAD_URL, headers=headers, files=files, timeout=None)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")

    # A 2xx with a non-JSON body previously raised an unhandled exception
    # (surfacing as a 500); report it as an upstream failure instead.
    try:
        data = resp.json()
    except Exception:
        raise HTTPException(status_code=502, detail=f"Upload returned non-JSON: {resp.text}")

    # Try common field names; adapt if your uploader returns a different shape.
    # Guard the nested lookup: if "data" exists but is not an object, the old
    # chained .get() crashed with AttributeError.
    nested = data.get("data")
    url = (
        data.get("url")
        or data.get("fileUrl")
        or data.get("publicUrl")
        or (nested.get("url") if isinstance(nested, dict) else None)
    )
    if not url:
        # last resort: return whole payload for debugging
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")

    return url
158
+
159
+
# ---------- FastAPI app ----------
# Single ASGI application instance; routes are registered via decorators below.
app = FastAPI(title="OpenAI-Compatible T2V Proxy")
162
+
@app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
async def chat_completions(req: ChatCompletionsRequest):
    """
    OpenAI-compatible chat endpoint.

    Treats the most recent user message as a text-to-video prompt, generates
    the clip via Hugging Face, uploads the result, and replies with the
    public link wrapped in an assistant message.
    """
    prompt = extract_prompt(req.messages)

    # One shared HTTP client for both the generation and the upload call.
    async with httpx.AsyncClient() as client:
        video_bytes = await hf_text_to_video(prompt, client)
        public_url = await upload_video_bytes(video_bytes, client)

    assistant_text = (
        f"✅ Video generated & uploaded.\n"
        f"**Prompt:** {prompt}\n"
        f"**URL:** {public_url}"
    )

    return ChatCompletionsResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=req.model,
        choices=[
            ChatChoice(
                index=0,
                message=ChatChoiceMessage(role="assistant", content=assistant_text),
                finish_reason="stop",
            )
        ],
    )