rkihacker committed
Commit 534be3f · verified · 1 Parent(s): f07cfab

Update main.py

Files changed (1)
  1. main.py +74 -95
main.py CHANGED
@@ -1,27 +1,27 @@
 import os
 import time
 import uuid
-from typing import List, Optional, Literal, Any, Dict
+from typing import List, Optional, Literal, Any, Dict, Union
 
 import httpx
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+from huggingface_hub import InferenceClient
+import asyncio
 
-# ---------- Config (env) ----------
-HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token
-UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
-UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # Bearer token for your uploader
+
+# ---------------- Config (env) ----------------
+HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token (works for provider=fal-ai)
 WAN_MODEL = os.getenv("WAN_MODEL", "Wan-AI/Wan2.2-T2V-A14B")
-HF_ENDPOINT = os.getenv(
-    "HF_ENDPOINT",
-    f"https://api-inference.huggingface.co/models/{WAN_MODEL}",
-)
-# Polling settings for HF async generation
-POLL_INTERVAL_SEC = float(os.getenv("POLL_INTERVAL_SEC", "3"))
-POLL_TIMEOUT_SEC = int(os.getenv("POLL_TIMEOUT_SEC", "600"))  # 10 minutes max
+
+UPLOAD_URL = os.getenv("UPLOAD_URL", "https://upload.snapzion.com/api/public-upload")
+UPLOAD_ACCESS_TOKEN = os.getenv("UPLOAD_ACCESS_TOKEN")  # your bearer token
+
+# Optional tuning
+GEN_TIMEOUT_SEC = int(os.getenv("GEN_TIMEOUT_SEC", "900"))  # 15 min generation ceiling
 
 
-# ---------- OpenAI-compatible schemas ----------
+# ---------------- OpenAI-compatible schemas ----------------
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant", "tool"]
     content: str
@@ -33,7 +33,6 @@ class ChatCompletionsRequest(BaseModel):
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
     stream: Optional[bool] = False
-    # we accept arbitrary extras but ignore them
     n: Optional[int] = 1
     top_p: Optional[float] = None
     presence_penalty: Optional[float] = None
@@ -66,84 +65,67 @@ class ChatCompletionsResponse(BaseModel):
     }
 
 
-# ---------- Helpers ----------
+# ---------------- Helpers ----------------
 def extract_prompt(messages: List[ChatMessage]) -> str:
-    """
-    Basic heuristic: use the content of the last user message as the video prompt.
-    If none found, join all user contents.
-    """
-    for msg in reversed(messages):
-        if msg.role == "user" and msg.content.strip():
-            return msg.content.strip()
-    # fallback
-    user_texts = [m.content for m in messages if m.role == "user"]
+    """Use the last user message as the prompt. Fallback to joining all user messages."""
+    for m in reversed(messages):
+        if m.role == "user" and m.content and m.content.strip():
+            return m.content.strip()
+    user_texts = [m.content for m in messages if m.role == "user" and m.content]
     if not user_texts:
         raise HTTPException(status_code=400, detail="No user prompt provided.")
     return "\n".join(user_texts).strip()
 
 
-async def hf_text_to_video(prompt: str, client: httpx.AsyncClient) -> bytes:
-    """
-    Calls Hugging Face Inference API for text-to-video and returns raw MP4 bytes.
-    Some T2V models run asynchronously; we poll until the asset is ready.
-    """
+async def generate_video_bytes(prompt: str) -> bytes:
+    """Calls huggingface_hub.InferenceClient with provider='fal-ai' (Wan T2V) and returns MP4 bytes."""
     if not HF_TOKEN:
         raise HTTPException(status_code=500, detail="HF_TOKEN is not set.")
-
-    headers = {
-        "Authorization": f"Bearer {HF_TOKEN}",
-        "Accept": "*/*",
-    }
-
-    # Kick off generation (HF will 200 with bytes OR 202 with a status JSON)
-    start = time.time()
-    while True:
-        resp = await client.post(HF_ENDPOINT, headers=headers, json={"inputs": prompt}, timeout=None)
-        ct = resp.headers.get("content-type", "")
-
-        # Direct bytes path
-        if resp.status_code == 200 and ("video" in ct or "octet-stream" in ct):
-            return resp.content
-
-        # 202 - still processing
-        if resp.status_code in (200, 202):
-            # respect suggested wait, else our own backoff
-            await client.aclose()  # close & reopen to avoid sticky connections on HF
-            await httpx.AsyncClient().__aenter__()  # no-op to satisfy type-checkers
-            elapsed = time.time() - start
-            if elapsed > POLL_TIMEOUT_SEC:
-                raise HTTPException(status_code=504, detail="Video generation timed out.")
-            time.sleep(POLL_INTERVAL_SEC)
-            # re-create client for next loop
-            client = httpx.AsyncClient()
-            continue
-
-        # Any other error
-        try:
-            err = resp.json()
-        except Exception:
-            err = {"detail": resp.text}
-        raise HTTPException(status_code=502, detail=f"HF error: {err}")
-
-async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str:
-    """
-    Uploads the MP4 to your uploader service and returns the public URL.
-    """
+    client = InferenceClient(provider="fal-ai", api_key=HF_TOKEN)
+
+    def _sync_generate() -> Union[bytes, Dict[str, Any]]:
+        # mirrors your Python example:
+        # video = client.text_to_video("prompt", model="Wan-AI/Wan2.2-T2V-A14B")
+        return client.text_to_video(prompt, model=WAN_MODEL)
+
+    try:
+        result = await asyncio.wait_for(
+            asyncio.get_event_loop().run_in_executor(None, _sync_generate),
+            timeout=GEN_TIMEOUT_SEC,
+        )
+    except asyncio.TimeoutError:
+        raise HTTPException(status_code=504, detail="Video generation timed out.")
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=f"Video generation failed: {e}")
+
+    # fal-ai provider typically returns a dict with "video": bytes; sometimes raw bytes
+    if isinstance(result, (bytes, bytearray)):
+        return bytes(result)
+
+    if isinstance(result, dict):
+        # common keys: "video" (bytes), "seed", etc.
+        vid = result.get("video")
+        if isinstance(vid, (bytes, bytearray)):
+            return bytes(vid)
+
+    raise HTTPException(status_code=502, detail=f"Unexpected generation result: {type(result)}")
+
+
+async def upload_video_bytes(mp4_bytes: bytes) -> str:
+    """Uploads MP4 to Snapzion uploader and returns public URL."""
     if not UPLOAD_ACCESS_TOKEN:
         raise HTTPException(status_code=500, detail="UPLOAD_ACCESS_TOKEN is not set.")
+    headers = {"Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}"}
+    files = {"file": ("video.mp4", mp4_bytes, "video/mp4")}
+
+    async with httpx.AsyncClient(timeout=None) as client:
+        resp = await client.post(UPLOAD_URL, headers=headers, files=files)
 
-    files = {
-        "file": ("video.mp4", mp4_bytes, "video/mp4"),
-    }
-    headers = {
-        "Authorization": f"Bearer {UPLOAD_ACCESS_TOKEN}",
-    }
-    resp = await client.post(UPLOAD_URL, headers=headers, files=files, timeout=None)
     if resp.status_code >= 400:
         raise HTTPException(status_code=502, detail=f"Upload failed: {resp.text}")
 
     data = resp.json()
-    # Try common field names; adapt if your uploader returns a different shape
+    # Try common URL fields (adjust if your API returns a different shape)
     url = (
         data.get("url")
         or data.get("fileUrl")
@@ -151,38 +133,35 @@ async def upload_video_bytes(mp4_bytes: bytes, client: httpx.AsyncClient) -> str
         or data.get("data", {}).get("url")
     )
     if not url:
-        # last resort: return whole payload for debugging
        raise HTTPException(status_code=502, detail=f"Upload response missing URL: {data}")
-
     return url
 
 
-# ---------- FastAPI app ----------
-app = FastAPI(title="OpenAI-Compatible T2V Proxy")
+# ---------------- FastAPI app ----------------
+app = FastAPI(title="OpenAI-Compatible T2V Proxy (FAL via HF)")
+
+
+@app.get("/health")
+async def health():
+    return {"status": "ok", "model": WAN_MODEL}
+
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionsResponse)
 async def chat_completions(req: ChatCompletionsRequest):
     """
     OpenAI-compatible endpoint:
-    - takes chat messages
-    - generates a video from the last user message
-    - uploads it
-    - returns the link in assistant message content
+    - reads last user message as the T2V prompt
+    - generates a video with Wan-AI/Wan2.2-T2V-A14B via provider='fal-ai'
+    - uploads to your uploader
+    - returns the public URL inside the assistant message
     """
     prompt = extract_prompt(req.messages)
-
-    async with httpx.AsyncClient() as client:
-        mp4 = await hf_text_to_video(prompt, client)
-        video_url = await upload_video_bytes(mp4, client)
+    mp4 = await generate_video_bytes(prompt)
+    video_url = await upload_video_bytes(mp4)
 
     now = int(time.time())
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-
-    content = (
-        f"✅ Video generated & uploaded.\n"
-        f"**Prompt:** {prompt}\n"
-        f"**URL:** {video_url}"
-    )
+    content = f"✅ Video generated & uploaded.\n**Prompt:** {prompt}\n**URL:** {video_url}"
 
     return ChatCompletionsResponse(
         id=completion_id,
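
For orientation (not part of the commit): a minimal client-side sketch of calling the proxy after this change. The base URL, the "model" request field, and the OpenAI-style choices[0].message.content response path are assumptions for illustration, not taken from the diff.

# Hypothetical client call; assumes the proxy is running locally on port 8000 and
# that ChatCompletionsResponse mirrors the OpenAI chat-completions response shape.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Wan-AI/Wan2.2-T2V-A14B",
        "messages": [{"role": "user", "content": "A red fox running through fresh snow at dawn"}],
    },
    timeout=None,  # generation can take several minutes
)
resp.raise_for_status()
# The assistant message content should contain the uploaded video URL.
print(resp.json()["choices"][0]["message"]["content"])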