Princeaka commited on
Commit
dd691ea
·
verified ·
1 Parent(s): 10c8690

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -138
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — FastAPI REST API + mounted Gradio UI (Hugging Face Spaces compatible)
2
  import os
3
  import shutil
4
  import asyncio
@@ -6,31 +6,28 @@ import inspect
6
  from typing import Optional
7
 
8
  from fastapi import FastAPI, UploadFile, File, Form
9
- from fastapi.responses import JSONResponse, PlainTextResponse
10
  from fastapi.middleware.cors import CORSMiddleware
 
11
  import gradio as gr
12
- import uvicorn
13
 
14
- # ---- Your module ----
15
  from multimodal_module import MultiModalChatModule
16
 
17
- # Instantiate once at import time
18
  AI = MultiModalChatModule()
19
 
20
  TMP_DIR = "/tmp"
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
- # ---------------- Helpers ----------------
24
  class FileWrapper:
25
- """Tiny adapter so your module can .download_to_drive(path)."""
26
  def __init__(self, path: str):
27
  self._path = path
28
-
29
- async def download_to_drive(self, dst_path: str) -> None:
30
  loop = asyncio.get_event_loop()
31
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
32
 
33
- async def save_upload_to_tmp(up: UploadFile) -> str:
 
34
  if not up or not up.filename:
35
  raise ValueError("No file uploaded")
36
  dest = os.path.join(TMP_DIR, up.filename)
@@ -39,48 +36,28 @@ async def save_upload_to_tmp(up: UploadFile) -> str:
39
  f.write(data)
40
  return dest
41
 
 
42
  async def call_ai(fn, *args, **kwargs):
43
- """Call AI methods whether they are sync or async."""
44
  if fn is None:
45
- raise AttributeError("Requested AI method is not implemented in multimodal_module")
46
  if inspect.iscoroutinefunction(fn):
47
  return await fn(*args, **kwargs)
48
  return await asyncio.to_thread(lambda: fn(*args, **kwargs))
49
 
50
- # ---------------- FastAPI app ----------------
51
- app = FastAPI(title="Multimodal Module API", version="1.0.0")
52
 
53
- # CORS so external apps can call it
54
  app.add_middleware(
55
  CORSMiddleware,
56
- allow_origins=["*"], # tighten for production
57
  allow_credentials=True,
58
  allow_methods=["*"],
59
  allow_headers=["*"],
60
  )
61
 
62
- # ---- Health / root ----
63
- @app.get("/health", response_class=PlainTextResponse)
64
- async def health():
65
- return "ok"
66
-
67
- @app.get("/")
68
- async def root():
69
- return {
70
- "name": "Multimodal Module API",
71
- "status": "ready",
72
- "docs": "/docs",
73
- "gradio_ui": "/ui"
74
- }
75
-
76
- # ---------------- REST Endpoints ----------------
77
- # Text chat
78
  @app.post("/api/text")
79
- async def api_text(
80
- text: str = Form(...),
81
- user_id: Optional[int] = Form(0),
82
- lang: str = Form("en"),
83
- ):
84
  try:
85
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
86
  reply = await call_ai(fn, text, int(user_id), lang)
@@ -88,71 +65,37 @@ async def api_text(
88
  except Exception as e:
89
  return JSONResponse({"error": str(e)}, status_code=500)
90
 
91
- # Hugging Face-style predict (optional)
92
- @app.post("/api/predict")
93
- async def api_predict(
94
- inputs: str = Form(...),
95
- user_id: Optional[int] = Form(0),
96
- lang: str = Form("en"),
97
- ):
98
- try:
99
- fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
100
- reply = await call_ai(fn, inputs, int(user_id), lang)
101
- return {"data": [reply]}
102
- except Exception as e:
103
- return JSONResponse({"error": str(e)}, status_code=500)
104
-
105
- # Voice -> ASR / emotion
106
  @app.post("/api/voice")
107
- async def api_voice(
108
- user_id: Optional[int] = Form(0),
109
- audio_file: UploadFile = File(...),
110
- ):
111
  try:
112
- path = await save_upload_to_tmp(audio_file)
113
  fn = getattr(AI, "process_voice_message", None)
114
  result = await call_ai(fn, FileWrapper(path), int(user_id))
115
- return JSONResponse(result)
116
  except Exception as e:
117
  return JSONResponse({"error": str(e)}, status_code=500)
118
 
119
- # TTS
120
  @app.post("/api/voice_reply")
121
- async def api_voice_reply(
122
- user_id: Optional[int] = Form(0),
123
- reply_text: str = Form(...),
124
- fmt: str = Form("ogg"),
125
- ):
126
  try:
127
  fn = getattr(AI, "generate_voice_reply", None)
128
- out_path = await call_ai(fn, reply_text, int(user_id), fmt)
129
- return {"status": "ok", "file": out_path}
130
  except Exception as e:
131
  return JSONResponse({"error": str(e)}, status_code=500)
132
 
133
- # Image caption
134
  @app.post("/api/image_caption")
135
- async def api_image_caption(
136
- user_id: Optional[int] = Form(0),
137
- image_file: UploadFile = File(...),
138
- ):
139
  try:
140
- path = await save_upload_to_tmp(image_file)
141
  fn = getattr(AI, "process_image_message", None)
142
  caption = await call_ai(fn, FileWrapper(path), int(user_id))
143
  return {"status": "ok", "caption": caption}
144
  except Exception as e:
145
  return JSONResponse({"error": str(e)}, status_code=500)
146
 
147
- # Text-to-image
148
  @app.post("/api/generate_image")
149
- async def api_generate_image(
150
- user_id: Optional[int] = Form(0),
151
- prompt: str = Form(...),
152
- width: int = Form(512),
153
- height: int = Form(512),
154
- steps: int = Form(30),
155
- ):
156
  try:
157
  fn = getattr(AI, "generate_image_from_text", None)
158
  out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
@@ -160,66 +103,41 @@ async def api_generate_image(
160
  except Exception as e:
161
  return JSONResponse({"error": str(e)}, status_code=500)
162
 
163
- # Image edit / inpaint
164
  @app.post("/api/edit_image")
165
- async def api_edit_image(
166
- user_id: Optional[int] = Form(0),
167
- image_file: UploadFile = File(...),
168
- mask_file: Optional[UploadFile] = File(None),
169
- prompt: str = Form(""),
170
- ):
171
  try:
172
- img_path = await save_upload_to_tmp(image_file)
173
  mask_path = None
174
  if mask_file:
175
- mask_path = await save_upload_to_tmp(mask_file)
176
  fn = getattr(AI, "edit_image_inpaint", None)
177
- out_path = await call_ai(
178
- fn,
179
- FileWrapper(img_path),
180
- FileWrapper(mask_path) if mask_path else None,
181
- prompt,
182
- int(user_id),
183
- )
184
  return {"status": "ok", "file": out_path}
185
  except Exception as e:
186
  return JSONResponse({"error": str(e)}, status_code=500)
187
 
188
- # Video
189
  @app.post("/api/video")
190
- async def api_video(
191
- user_id: Optional[int] = Form(0),
192
- video_file: UploadFile = File(...),
193
- ):
194
  try:
195
- path = await save_upload_to_tmp(video_file)
196
  fn = getattr(AI, "process_video", None)
197
  result = await call_ai(fn, FileWrapper(path), int(user_id))
198
- return JSONResponse(result)
199
  except Exception as e:
200
  return JSONResponse({"error": str(e)}, status_code=500)
201
 
202
- # Files (PDF/DOCX/TXT)
203
  @app.post("/api/file")
204
- async def api_file(
205
- user_id: Optional[int] = Form(0),
206
- file_obj: UploadFile = File(...),
207
- ):
208
  try:
209
- path = await save_upload_to_tmp(file_obj)
210
  fn = getattr(AI, "process_file", None)
211
  result = await call_ai(fn, FileWrapper(path), int(user_id))
212
- return JSONResponse(result)
213
  except Exception as e:
214
  return JSONResponse({"error": str(e)}, status_code=500)
215
 
216
- # Code completion
217
  @app.post("/api/code")
218
- async def api_code(
219
- user_id: Optional[int] = Form(0),
220
- prompt: str = Form(...),
221
- max_tokens: int = Form(512),
222
- ):
223
  try:
224
  fn = getattr(AI, "code_complete", None)
225
  try:
@@ -230,28 +148,20 @@ async def api_code(
230
  except Exception as e:
231
  return JSONResponse({"error": str(e)}, status_code=500)
232
 
233
- # ---------------- Gradio UI (mounted at /ui) ----------------
234
- def _gradio_text_fn(text, user_id, lang):
235
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
236
- if fn is None:
237
- return "Error: text handler not implemented in multimodal_module"
238
- # Gradio callbacks run in a worker thread, safe to create/own an event loop
239
- return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))
240
 
241
- with gr.Blocks(title="Multimodal Bot — UI") as demo:
242
- gr.Markdown("# 🧠 Multimodal Bot — UI\nThis is a helper UI. Use the REST API for external apps.")
243
  with gr.Row():
244
- g_uid = gr.Textbox(label="User ID", value="0")
245
- g_lang = gr.Dropdown(["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
246
- g_in = gr.Textbox(lines=3, label="Message")
247
- g_out = gr.Textbox(lines=6, label="Reply")
248
- gr.Button("Send").click(_gradio_text_fn, [g_in, g_uid, g_lang], g_out)
249
-
250
- # Mount Gradio *into* FastAPI at /ui (does not open another port)
251
- app = gr.mount_gradio_app(app, demo, path="/ui")
252
-
253
- # ---------------- Entrypoint ----------------
254
- if __name__ == "__main__":
255
- # Hugging Face Spaces (FastAPI template) sets PORT; bind exactly to it.
256
- port = int(os.environ.get("PORT", "7860"))
257
- uvicorn.run("app:app", host="0.0.0.0", port=port)
 
1
+ # app.py — FastAPI + Gradio (External API + UI)
2
  import os
3
  import shutil
4
  import asyncio
 
6
  from typing import Optional
7
 
8
  from fastapi import FastAPI, UploadFile, File, Form
 
9
  from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse
11
  import gradio as gr
 
12
 
 
13
  from multimodal_module import MultiModalChatModule
14
 
15
+ # Instantiate AI module
16
  AI = MultiModalChatModule()
17
 
18
  TMP_DIR = "/tmp"
19
  os.makedirs(TMP_DIR, exist_ok=True)
20
 
21
+ # --- File wrapper ---
22
  class FileWrapper:
 
23
  def __init__(self, path: str):
24
  self._path = path
25
+ async def download_to_drive(self, dst_path: str):
 
26
  loop = asyncio.get_event_loop()
27
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
28
 
29
+ # --- Save uploaded file ---
30
+ async def save_upload(up: UploadFile) -> str:
31
  if not up or not up.filename:
32
  raise ValueError("No file uploaded")
33
  dest = os.path.join(TMP_DIR, up.filename)
 
36
  f.write(data)
37
  return dest
38
 
39
+ # --- Call AI (sync or async) ---
40
  async def call_ai(fn, *args, **kwargs):
 
41
  if fn is None:
42
+ raise AttributeError("Requested AI method not implemented")
43
  if inspect.iscoroutinefunction(fn):
44
  return await fn(*args, **kwargs)
45
  return await asyncio.to_thread(lambda: fn(*args, **kwargs))
46
 
47
+ # === FASTAPI APP ===
48
+ app = FastAPI(title="Multimodal API")
49
 
 
50
  app.add_middleware(
51
  CORSMiddleware,
52
+ allow_origins=["*"], # change for production
53
  allow_credentials=True,
54
  allow_methods=["*"],
55
  allow_headers=["*"],
56
  )
57
 
58
+ # --- API Endpoints ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  @app.post("/api/text")
60
+ async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
 
 
 
 
61
  try:
62
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
63
  reply = await call_ai(fn, text, int(user_id), lang)
 
65
  except Exception as e:
66
  return JSONResponse({"error": str(e)}, status_code=500)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  @app.post("/api/voice")
69
+ async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
 
 
 
70
  try:
71
+ path = await save_upload(audio_file)
72
  fn = getattr(AI, "process_voice_message", None)
73
  result = await call_ai(fn, FileWrapper(path), int(user_id))
74
+ return {"status": "ok", "result": result}
75
  except Exception as e:
76
  return JSONResponse({"error": str(e)}, status_code=500)
77
 
 
78
  @app.post("/api/voice_reply")
79
+ async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
 
 
 
 
80
  try:
81
  fn = getattr(AI, "generate_voice_reply", None)
82
+ result = await call_ai(fn, reply_text, int(user_id), fmt)
83
+ return {"status": "ok", "file": result}
84
  except Exception as e:
85
  return JSONResponse({"error": str(e)}, status_code=500)
86
 
 
87
  @app.post("/api/image_caption")
88
+ async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
 
 
 
89
  try:
90
+ path = await save_upload(image_file)
91
  fn = getattr(AI, "process_image_message", None)
92
  caption = await call_ai(fn, FileWrapper(path), int(user_id))
93
  return {"status": "ok", "caption": caption}
94
  except Exception as e:
95
  return JSONResponse({"error": str(e)}, status_code=500)
96
 
 
97
  @app.post("/api/generate_image")
98
+ async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
 
 
 
 
 
 
99
  try:
100
  fn = getattr(AI, "generate_image_from_text", None)
101
  out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
 
103
  except Exception as e:
104
  return JSONResponse({"error": str(e)}, status_code=500)
105
 
 
106
  @app.post("/api/edit_image")
107
+ async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
 
 
 
 
 
108
  try:
109
+ img_path = await save_upload(image_file)
110
  mask_path = None
111
  if mask_file:
112
+ mask_path = await save_upload(mask_file)
113
  fn = getattr(AI, "edit_image_inpaint", None)
114
+ out_path = await call_ai(fn, FileWrapper(img_path), FileWrapper(mask_path) if mask_path else None, prompt, int(user_id))
 
 
 
 
 
 
115
  return {"status": "ok", "file": out_path}
116
  except Exception as e:
117
  return JSONResponse({"error": str(e)}, status_code=500)
118
 
 
119
  @app.post("/api/video")
120
+ async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
 
 
 
121
  try:
122
+ path = await save_upload(video_file)
123
  fn = getattr(AI, "process_video", None)
124
  result = await call_ai(fn, FileWrapper(path), int(user_id))
125
+ return {"status": "ok", "result": result}
126
  except Exception as e:
127
  return JSONResponse({"error": str(e)}, status_code=500)
128
 
 
129
  @app.post("/api/file")
130
+ async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
 
 
 
131
  try:
132
+ path = await save_upload(file_obj)
133
  fn = getattr(AI, "process_file", None)
134
  result = await call_ai(fn, FileWrapper(path), int(user_id))
135
+ return {"status": "ok", "result": result}
136
  except Exception as e:
137
  return JSONResponse({"error": str(e)}, status_code=500)
138
 
 
139
  @app.post("/api/code")
140
+ async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
 
 
 
 
141
  try:
142
  fn = getattr(AI, "code_complete", None)
143
  try:
 
148
  except Exception as e:
149
  return JSONResponse({"error": str(e)}, status_code=500)
150
 
151
+ # === GRADIO UI ===
152
+ def gradio_text_fn(text, user_id, lang):
153
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
154
+ loop = asyncio.get_event_loop()
155
+ return loop.run_until_complete(call_ai(fn, text, int(user_id or 0), lang))
 
 
156
 
157
+ with gr.Blocks(title="Multimodal Bot") as demo:
158
+ gr.Markdown("# 🧠 Multimodal Bot — UI")
159
  with gr.Row():
160
+ uid = gr.Textbox(label="User ID", value="0")
161
+ lang = gr.Dropdown(["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
162
+ inp = gr.Textbox(lines=3, label="Message")
163
+ out = gr.Textbox(lines=6, label="Reply")
164
+ gr.Button("Send").click(gradio_text_fn, [inp, uid, lang], out)
165
+
166
+ # Mount Gradio under /ui
167
+ app = gr.mount_gradio_app(app, demo, path="/ui")