rkihacker committed on
Commit
3a4123b
·
verified ·
1 Parent(s): a0aebc5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +113 -55
main.py CHANGED
@@ -1,27 +1,28 @@
1
  import os
2
  import time
3
  import uuid
4
- from typing import List, Optional, Literal, Any, Dict, Union
5
 
6
  import httpx
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
9
- from huggingface_hub import InferenceClient
10
- import asyncio
11
-
12
-
13
# ---------------- Environment-driven settings ----------------
# Tokens: each stays None when its variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
UPLOAD_ACCESS_TOKEN = os.environ.get("UPLOAD_ACCESS_TOKEN")

# Model id and upload endpoint, each overridable via an env var of the same name.
WAN_MODEL = os.environ.get("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
UPLOAD_URL = os.environ.get("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")

# Upper bound on a single generation, in seconds (900 s = 15 min).
GEN_TIMEOUT_SEC = int(os.environ.get("GEN_TIMEOUT_SEC", "900"))
 
22
 
23
 
24
- # ---------------- OpenAI-compatible schemas ----------------
25
  class ChatMessage(BaseModel):
26
  role: Literal["system", "user", "assistant", "tool"]
27
  content: str
@@ -65,9 +66,8 @@ class ChatCompletionsResponse(BaseModel):
65
  }
66
 
67
 
68
- # ---------------- Helpers ----------------
69
  def extract_prompt(messages: List[ChatMessage]) -> str:
70
- """Use the last user message as the prompt. Fallback to joining all user messages."""
71
  for m in reversed(messages):
72
  if m.role == "user" and m.content and m.content.strip():
73
  return m.content.strip()
@@ -77,42 +77,89 @@ def extract_prompt(messages: List[ChatMessage]) -> str:
77
  return "\n".join(user_texts).strip()
78
 
79
 
80
async def generate_video_bytes(prompt: str) -> bytes:
    """
    Generate a video for ``prompt`` and return its raw bytes.

    Calls ``InferenceClient(provider="fal-ai").text_to_video`` with
    ``model=WAN_MODEL``. That call is synchronous, so it is run in the event
    loop's default executor and awaited with a ``GEN_TIMEOUT_SEC`` ceiling.

    Args:
        prompt: Text prompt forwarded unchanged to ``text_to_video``.

    Returns:
        The video as ``bytes``: either the call's direct bytes/bytearray
        result, or the bytes found under the ``"video"`` key of a dict result.

    Raises:
        HTTPException: 500 if HF_TOKEN is not set; 504 if the call exceeds
            GEN_TIMEOUT_SEC; 502 if the call raises or returns anything other
            than the two shapes described above.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
    client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)

    def _sync_generate() -> Union[bytes, Dict[str, Any]]:
        return client.text_to_video(prompt, model=WAN_MODEL)

    # Inside a coroutine the running loop is the one to use; get_running_loop()
    # is the documented accessor for it (get_event_loop() is discouraged here).
    loop = asyncio.get_running_loop()

    try:
        # NOTE: on timeout only the await is abandoned; the worker thread keeps
        # running until text_to_video returns on its own.
        result = await asyncio.wait_for(
            loop.run_in_executor(None, _sync_generate),
            timeout=GEN_TIMEOUT_SEC,
        )
    except asyncio.TimeoutError as exc:
        raise HTTPException(status_code=504, detail="Video generation timed out.") from exc
    except Exception as exc:
        # Boundary handler: any provider/client failure is surfaced as 502.
        raise HTTPException(status_code=502, detail=f"Video generation failed: {exc}") from exc

    # Result is either raw bytes or a dict carrying the bytes under "video".
    if isinstance(result, (bytes, bytearray)):
        return bytes(result)

    if isinstance(result, dict):
        vid = result.get("video")
        if isinstance(vid, (bytes, bytearray)):
            return bytes(vid)

    raise HTTPException(status_code=502, detail=f"Unexpected generation result: {type(result)}")
 
 
 
 
 
112
 
113
 
114
  async def upload_video_bytes(mp4_bytes: bytes) -> str:
115
- """Uploads MP4 to Snapzion uploader and returns public URL."""
116
  if not UPLOAD_ACCESS_TOKEN:
117
  raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
118
  headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
@@ -120,12 +167,10 @@ async def upload_video_bytes(mp4_bytes: bytes) -> str:
120
 
121
  async with httpx.AsyncClient(timeout=None) as client:
122
  resp = await client.post(UPLOAD_URL, headers=headers, files=files)
123
-
124
  if resp.status_code >= 400:
125
  raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
126
 
127
  data = resp.json()
128
- # Try common URL fields (adjust if your API returns a different shape)
129
  url = (
130
  data.get("url")
131
  or data.get("fileUrl")
@@ -137,31 +182,44 @@ async def upload_video_bytes(mp4_bytes: bytes) -> str:
137
  return url
138
 
139
 
140
# ---------------- FastAPI application ----------------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (FAL via HF)")


@app.get("/health")
async def health():
    """Return a fixed "ok" status together with the configured model id."""
    return dict(status="ok", model=WAN_MODEL)
147
-
148
 
149
  @app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
150
  async def chat_completions(req: ChatCompletionsRequest):
151
  """
152
- OpenAI-compatible endpoint:
153
- - reads last user message as the T2V prompt
154
- - generates a video with Wan-AI/Wan2.2-T2V-A14B via provider='fal-ai'
155
- - uploads to your uploader
156
- - returns the public URL inside the assistant message
157
  """
158
  prompt = extract_prompt(req.messages)
159
- mp4 = await generate_video_bytes(prompt)
160
- video_url = await upload_video_bytes(mp4)
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  now = int(time.time())
163
  completion_id = f"chatcmpl-{uuid.uuid4().hex}"
164
- content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {video_url}"
165
 
166
  return ChatCompletionsResponse(
167
  id=completion_id,
 
1
  import os
2
  import time
3
  import uuid
4
+ from typing import List, Optional, Literal, Any, Dict
5
 
6
  import httpx
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
 
 
 
 
 
 
 
9
 
10
# ---------- Environment-driven settings ----------
# Tokens: each stays None when its variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
UPLOAD_ACCESS_TOKEN = os.environ.get("UPLOAD_ACCESS_TOKEN")

# Endpoints, each overridable via an env var of the same name.
HF_SUBMIT_URL = os.environ.get(
    "HF_SUBMIT_URL",
    "https://router.huggingface.co/fal-ai/fal-ai/wan/v2.2-a14b/text-to-video?_subdomain=queue",
)
UPLOAD_URL = os.environ.get("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")

# Polling cadence and overall ceiling, both in seconds (900 s = 15 min).
POLL_INTERVAL_SEC = float(os.environ.get("POLL_INTERVAL_SEC", "3"))
POLL_TIMEOUT_SEC = int(os.environ.get("POLL_TIMEOUT_SEC", "900"))
23
 
24
 
25
+ # ---------- OpenAI-compatible schemas ----------
26
  class ChatMessage(BaseModel):
27
  role: Literal["system", "user", "assistant", "tool"]
28
  content: str
 
66
  }
67
 
68
 
69
+ # ---------- Helpers ----------
70
  def extract_prompt(messages: List[ChatMessage]) -> str:
 
71
  for m in reversed(messages):
72
  if m.role == "user" and m.content and m.content.strip():
73
  return m.content.strip()
 
77
  return "\n".join(user_texts).strip()
78
 
79
 
80
async def hf_queue_submit(prompt: str) -> Dict[str, str]:
    """
    Submit a job to the HF router queue endpoint (``HF_SUBMIT_URL``).

    Args:
        prompt: Text prompt, sent as the JSON body ``{"prompt": prompt}``.

    Returns:
        A dict with exactly two keys, ``status_url`` and ``response_url``,
        read from the submit response either at the top level or nested
        under ``urls``.

    Raises:
        HTTPException: 500 if HF_TOKEN is not set; 502 if the endpoint
            answers with status >= 400, with a body that is not a JSON
            object, or with a JSON object lacking either URL.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "*/*",
    }
    payload = {"prompt": prompt}

    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.post(HF_SUBMIT_URL, headers=headers, json=payload)
    if resp.status_code >= 400:
        raise HTTPException(status_code=502, detail=f"HF submit error: {resp.text}")

    # A success status does not guarantee a JSON body; report a bad upstream
    # payload as 502 rather than letting the decode error surface as a 500.
    try:
        data = resp.json()
    except ValueError:
        raise HTTPException(
            status_code=502, detail=f"Unexpected HF submit response: {resp.text}"
        ) from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")

    # "urls" may be absent, null, or not an object; treat all of those as empty
    # so the missing-URL check below produces the error instead of .get() failing.
    nested = data.get("urls")
    if not isinstance(nested, dict):
        nested = {}

    status_url = data.get("status_url") or nested.get("status_url")
    response_url = data.get("response_url") or nested.get("response_url")
    if not status_url or not response_url:
        raise HTTPException(status_code=502, detail=f"Unexpected HF submit response: {data}")

    return {"status_url": status_url, "response_url": response_url}
107
+
108
+
109
async def hf_queue_wait(status_url: str) -> None:
    """
    Poll the HF queue ``status_url`` until the job reaches a terminal state.

    Each iteration GETs ``status_url`` and inspects the ``status`` field of the
    JSON reply. Any status other than the success/failure values listed below
    is treated as "still running" and polled again after POLL_INTERVAL_SEC.

    Args:
        status_url: Status endpoint returned by the submit call.

    Returns:
        None, once ``status`` is ``COMPLETED`` or ``SUCCEEDED``.

    Raises:
        HTTPException: 500 if HF_TOKEN is not set; 502 if a status request
            answers with status >= 400 or the job reports ``FAILED``,
            ``ERROR``, ``CANCELLED`` or ``CANCELED``; 504 once more than
            POLL_TIMEOUT_SEC seconds have elapsed without a terminal state.
    """
    # Imported locally so this function does not rely on a module-level
    # asyncio import being present.
    import asyncio

    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    headers = {"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"}
    # Monotonic clock: elapsed time must not jump if the wall clock is adjusted.
    start = time.monotonic()

    async with httpx.AsyncClient(timeout=None) as client:
        while True:
            resp = await client.get(status_url, headers=headers)
            if resp.status_code >= 400:
                raise HTTPException(status_code=502, detail=f"HF status error: {resp.text}")

            data = resp.json()
            status = data.get("status")

            if status in ("COMPLETED", "SUCCEEDED"):
                return
            if status in ("FAILED", "ERROR", "CANCELLED", "CANCELED"):
                raise HTTPException(status_code=502, detail=f"HF job failed: {data}")

            if time.monotonic() - start > POLL_TIMEOUT_SEC:
                raise HTTPException(status_code=504, detail="HF queue timed out.")

            # Must be an awaited sleep: a blocking time.sleep() inside a
            # coroutine stalls the event loop, freezing every other request
            # handled by this process for the whole interval.
            await asyncio.sleep(POLL_INTERVAL_SEC)
137
+
138
 
139
async def hf_queue_fetch_result(response_url: str) -> Dict[str, Any]:
    """
    GET ``response_url`` with the HF bearer token and return the decoded JSON.

    The payload is expected to look like {"video": {"url": ...}, ...}; this
    function does not validate that shape.

    Raises:
        HTTPException: 500 if HF_TOKEN is not set; 502 if the request answers
            with status >= 400.
    """
    if not HF_TOKEN:
        raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")

    async with httpx.AsyncClient(timeout=None) as http:
        reply = await http.get(
            response_url,
            headers={"Authorization": f"Bearer {HF_TOKEN}", "Accept": "*/*"},
        )

    request_failed = reply.status_code >= 400
    if request_failed:
        raise HTTPException(status_code=502, detail=f"HF result error: {reply.text}")
    return reply.json()
152
 
 
 
 
 
 
153
 
154
async def download_video(url: str) -> bytes:
    """
    Download the resource at ``url`` and return its body as bytes.

    The request is sent without authentication headers and follows redirects.

    Args:
        url: Location of the generated video.

    Returns:
        The raw response body.

    Raises:
        HTTPException: 502 if the final response is not a 2xx.
    """
    # httpx does not follow redirects unless asked to. Without this, a 3xx
    # answer would slip past a ">= 400" check and its body (not the video)
    # would be returned and re-uploaded as if it were the MP4.
    async with httpx.AsyncClient(timeout=None, follow_redirects=True) as client:
        resp = await client.get(url)

    # Accept only success statuses so no non-2xx body is mistaken for video data.
    if not 200 <= resp.status_code < 300:
        raise HTTPException(status_code=502, detail=f"Download failed: {resp.text}")
    return resp.content
160
 
161
 
162
  async def upload_video_bytes(mp4_bytes: bytes) -> str:
 
163
  if not UPLOAD_ACCESS_TOKEN:
164
  raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
165
  headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
 
167
 
168
  async with httpx.AsyncClient(timeout=None) as client:
169
  resp = await client.post(UPLOAD_URL, headers=headers, files=files)
 
170
  if resp.status_code >= 400:
171
  raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
172
 
173
  data = resp.json()
 
174
  url = (
175
  data.get("url")
176
  or data.get("fileUrl")
 
182
  return url
183
 
184
 
185
# ---------- FastAPI application ----------
app = FastAPI(title="OpenAI-Compatible T2V Proxy (HF Router Queue)")


@app.get("/health")
async def health():
    """Return a fixed "ok" status together with the configured submit URL."""
    return dict(status="ok", submit_url=HF_SUBMIT_URL)
 
191
 
192
  @app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
193
  async def chat_completions(req: ChatCompletionsRequest):
194
  """
195
+ 1) submit to HF router queue (Bearer HF_TOKEN)
196
+ 2) poll status_url until COMPLETED
197
+ 3) fetch response_url -> video.url
198
+ 4) download MP4, upload to Snapzion
199
+ 5) return URL in OpenAI chat shape
200
  """
201
  prompt = extract_prompt(req.messages)
 
 
202
 
203
+ # 1) Submit
204
+ urls = await hf_queue_submit(prompt)
205
+
206
+ # 2) Wait
207
+ await hf_queue_wait(urls["status_url"])
208
+
209
+ # 3) Fetch result JSON
210
+ result = await hf_queue_fetch_result(urls["response_url"])
211
+ video_url = (result.get("video") or {}).get("url")
212
+ if not video_url:
213
+ raise HTTPException(status_code=502, detail=f"HF result missing video.url: {result}")
214
+
215
+ # 4) Download + re-upload
216
+ mp4 = await download_video(video_url)
217
+ public_url = await upload_video_bytes(mp4)
218
+
219
+ # 5) Respond OpenAI-style
220
  now = int(time.time())
221
  completion_id = f"chatcmpl-{uuid.uuid4().hex}"
222
+ content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {public_url}"
223
 
224
  return ChatCompletionsResponse(
225
  id=completion_id,