Princeaka committed on
Commit
b2130a6
·
verified ·
1 Parent(s): ea54d29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -37
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py -- HF-ready single-server FastAPI + Gradio mounted app (fixed)
2
  import os
3
  import shutil
4
  import asyncio
@@ -9,10 +9,10 @@ from fastapi.responses import JSONResponse
9
  import gradio as gr
10
  import uvicorn
11
 
12
- # Import your real multimodal module
13
  from multimodal_module import MultiModalChatModule
14
 
15
- # Instantiate your AI module
16
  AI = MultiModalChatModule()
17
 
18
  # ---------- Helpers ----------
@@ -20,7 +20,7 @@ TMP_DIR = "/tmp"
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
  class FileWrapper:
23
- """Simple path wrapper compatible with your existing code expectations."""
24
  def __init__(self, path: str):
25
  self._path = path
26
 
@@ -29,52 +29,42 @@ class FileWrapper:
29
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
30
 
31
async def save_upload_to_tmp(up: UploadFile) -> str:
    """Save a FastAPI UploadFile to /tmp and return the saved path.

    Uses the async ``read()`` API so the event loop is not blocked while
    the upload body is consumed.

    Raises:
        ValueError: if no file was provided or it has no filename.
    """
    if not up or not up.filename:
        raise ValueError("UploadFile missing or has no filename")
    # basename() strips any client-supplied directory components, so a
    # malicious filename like "../../etc/passwd" cannot escape TMP_DIR.
    safe_name = os.path.basename(up.filename)
    if not safe_name:
        raise ValueError("UploadFile missing or has no filename")
    dest = os.path.join(TMP_DIR, safe_name)
    data = await up.read()  # <-- important: async read
    with open(dest, "wb") as f:
        f.write(data)
    return dest
40
 
41
async def call_ai(fn, *args, **kwargs):
    """Invoke an AI handler safely from async context.

    Coroutine functions are awaited in place; plain callables are pushed
    to a worker thread so they never block the event loop. A missing
    handler (``fn is None``) raises a clear AttributeError.
    """
    if fn is None:
        raise AttributeError("Requested AI method is not implemented in multimodal_module")
    if not inspect.iscoroutinefunction(fn):
        # Synchronous handler: offload so the loop stays responsive.
        return await asyncio.to_thread(lambda: fn(*args, **kwargs))
    return await fn(*args, **kwargs)
52
 
53
- # ---------- FastAPI app ----------
54
  app = FastAPI(title="Multimodal Module API")
55
 
56
- # Optional: allow CORS if external web apps will call this
57
  from fastapi.middleware.cors import CORSMiddleware
58
  app.add_middleware(
59
  CORSMiddleware,
60
- allow_origins=["*"], # tighten in production
61
  allow_credentials=True,
62
  allow_methods=["*"],
63
  allow_headers=["*"],
64
  )
65
 
66
  # ----------------- API endpoints -----------------
67
-
68
@app.post("/api/predict")
async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """
    HuggingFace-style /predict compatibility endpoint.

    Form fields:
        inputs: the text prompt to answer.
        user_id: optional numeric caller id (defaults to 0).
        lang: language code forwarded to the AI module.

    Returns an HF-style ``{"data": [reply]}`` payload on success, or a
    JSON error object with HTTP 500 on any failure.
    """
    try:
        # Prefer generate_response; fall back to process_text when the
        # module only implements the older handler name.
        fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        reply = await call_ai(fn, inputs, int(user_id), lang)
        # HF-style returns "data" array
        return {"data": [reply]}
    except Exception as e:
        # Boundary handler: surface failures as a JSON 500 rather than
        # an unhandled server exception.
        return JSONResponse({"error": str(e)}, status_code=500)
@@ -90,9 +80,6 @@ async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang
90
 
91
  @app.post("/api/voice")
92
  async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
93
- """
94
- Upload audio file (multipart/form-data). Returns whatever your AI.process_voice_message returns (JSON/dict).
95
- """
96
  try:
97
  path = await save_upload_to_tmp(audio_file)
98
  fn = getattr(AI, "process_voice_message", None)
@@ -166,8 +153,6 @@ async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File
166
  async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
167
  try:
168
  fn = getattr(AI, "code_complete", None)
169
- # Many implementations expect (user_id, prompt, max_tokens) or (prompt, max_tokens)
170
- # Try user-first signature first, fallback to prompt-first
171
  try:
172
  result = await call_ai(fn, int(user_id), prompt, max_tokens)
173
  except TypeError:
@@ -176,15 +161,13 @@ async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), ma
176
  except Exception as e:
177
  return JSONResponse({"error": str(e)}, status_code=500)
178
 
179
- # ---------- Minimal Gradio UI (mounted) ----------
180
def gradio_text_fn(text, user_id, lang):
    """Gradio click handler: route the text prompt through the AI module.

    Always delegates through call_ai() so sync and async handlers are
    treated uniformly — the previous version only used call_ai on the
    async path and called sync handlers directly, bypassing call_ai's
    thread-offload and None handling.
    """
    fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
    if fn is None:
        return "Error: text handler not implemented in multimodal_module"
    # Gradio runs sync handlers in a worker thread with no event loop,
    # so asyncio.run() is safe here and covers both handler kinds.
    return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))
188
 
189
  with gr.Blocks(title="Multimodal Bot (UI)") as demo:
190
  gr.Markdown("# 🧠 Multimodal Bot — UI")
@@ -195,10 +178,10 @@ with gr.Blocks(title="Multimodal Bot (UI)") as demo:
195
  out = gr.Textbox(lines=6, label="Reply")
196
  gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
197
 
198
- # Mount Gradio app at root
199
  app = gr.mount_gradio_app(app, demo, path="/")
200
 
201
- # ---------- Run server (HF Spaces uses this entrypoint) ----------
202
  if __name__ == "__main__":
203
  port = int(os.environ.get("PORT", 7860))
204
- uvicorn.run(app, host="0.0.0.0", port=port)
 
1
+ # app.py -- HF-ready single-server FastAPI + Gradio mounted app (no double server conflict)
2
  import os
3
  import shutil
4
  import asyncio
 
9
  import gradio as gr
10
  import uvicorn
11
 
12
+ # Import your multimodal module
13
  from multimodal_module import MultiModalChatModule
14
 
15
+ # Instantiate AI module
16
  AI = MultiModalChatModule()
17
 
18
  # ---------- Helpers ----------
 
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
  class FileWrapper:
23
+ """Simple path wrapper for AI methods."""
24
  def __init__(self, path: str):
25
  self._path = path
26
 
 
29
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
30
 
31
async def save_upload_to_tmp(up: UploadFile) -> str:
    """Save FastAPI UploadFile to /tmp and return path.

    Raises:
        ValueError: if no file (or no filename) was supplied.
    """
    if not up or not up.filename:
        raise ValueError("No file uploaded")
    # basename() drops any directory components a client may smuggle in
    # ("../../etc/passwd"), keeping the write inside TMP_DIR.
    safe_name = os.path.basename(up.filename)
    if not safe_name:
        raise ValueError("No file uploaded")
    dest = os.path.join(TMP_DIR, safe_name)
    data = await up.read()  # async read keeps the event loop responsive
    with open(dest, "wb") as f:
        f.write(data)
    return dest
40
 
41
async def call_ai(fn, *args, **kwargs):
    """Run AI method whether it's sync or async.

    Coroutine functions are awaited directly; plain callables run in a
    worker thread via asyncio.to_thread so they cannot stall the event
    loop. A missing handler raises AttributeError.
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    if inspect.iscoroutinefunction(fn):
        result = await fn(*args, **kwargs)
    else:
        result = await asyncio.to_thread(lambda: fn(*args, **kwargs))
    return result
 
48
 
49
+ # ---------- FastAPI ----------
50
  app = FastAPI(title="Multimodal Module API")
51
 
52
+ # CORS (if you call this from the browser)
53
  from fastapi.middleware.cors import CORSMiddleware
54
  app.add_middleware(
55
  CORSMiddleware,
56
+ allow_origins=["*"], # tighten for prod
57
  allow_credentials=True,
58
  allow_methods=["*"],
59
  allow_headers=["*"],
60
  )
61
 
62
  # ----------------- API endpoints -----------------
 
63
@app.post("/api/predict")
async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """
    HuggingFace-style /predict compatibility endpoint.

    Form fields:
        inputs: the text prompt to answer.
        user_id: optional numeric caller id (defaults to 0).
        lang: language code forwarded to the AI module.

    Returns an HF-style ``{"data": [reply]}`` payload on success, or a
    JSON error object with HTTP 500 on any failure.
    """
    try:
        # Prefer generate_response; fall back to process_text when the
        # module only implements the older handler name.
        fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        reply = await call_ai(fn, inputs, int(user_id), lang)
        # HF-style responses wrap the reply in a "data" array.
        return {"data": [reply]}
    except Exception as e:
        # Boundary handler: surface failures as a JSON 500 rather than
        # an unhandled server exception.
        return JSONResponse({"error": str(e)}, status_code=500)
 
80
 
81
  @app.post("/api/voice")
82
  async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
 
 
 
83
  try:
84
  path = await save_upload_to_tmp(audio_file)
85
  fn = getattr(AI, "process_voice_message", None)
 
153
  async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
154
  try:
155
  fn = getattr(AI, "code_complete", None)
 
 
156
  try:
157
  result = await call_ai(fn, int(user_id), prompt, max_tokens)
158
  except TypeError:
 
161
  except Exception as e:
162
  return JSONResponse({"error": str(e)}, status_code=500)
163
 
164
+ # ---------- Minimal Gradio UI ----------
165
def gradio_text_fn(text, user_id, lang):
    """Gradio click handler: send the prompt to the AI module, return the reply.

    Uses asyncio.run() instead of get_event_loop().run_until_complete():
    Gradio invokes sync handlers in a worker thread where no event loop
    exists, so get_event_loop() is deprecated there (and raises on newer
    Python), and run_until_complete() would fail outright if a loop were
    already running. asyncio.run() creates and tears down a fresh loop.
    """
    fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
    if fn is None:
        return "Error: text handler not implemented"
    return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))
 
 
171
 
172
  with gr.Blocks(title="Multimodal Bot (UI)") as demo:
173
  gr.Markdown("# 🧠 Multimodal Bot — UI")
 
178
  out = gr.Textbox(lines=6, label="Reply")
179
  gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
180
 
181
+ # Mount Gradio at /
182
  app = gr.mount_gradio_app(app, demo, path="/")
183
 
184
+ # ---------- Run ----------
185
if __name__ == "__main__":
    # HF Spaces injects PORT; default to the conventional 7860 locally.
    port = int(os.environ.get("PORT", 7860))
    # Import string ("app:app") targets the module-level app object
    # (already Gradio-mounted at import time); reload=False avoids
    # spawning a second watcher process on Spaces.
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)