Princeaka committed on
Commit
1de5d77
·
verified ·
1 Parent(s): a657484

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -199
app.py CHANGED
@@ -1,242 +1,188 @@
1
- # app.py - Hybrid Gradio + FastAPI wrapper for multimodal_module.py
2
  import os
3
  import shutil
4
  import asyncio
5
- import json
6
  from typing import Optional
7
-
 
8
  import gradio as gr
9
- from fastapi import FastAPI, Request
 
 
10
  from multimodal_module import MultiModalChatModule
11
 
12
- # Instantiate AI
13
  AI = MultiModalChatModule()
14
 
15
- # ============================================================
16
- # Helper: File wrapper for Gradio uploads
17
- # ============================================================
18
- class GradioFileWrapper:
19
- def __init__(self, gr_file):
20
- if isinstance(gr_file, str):
21
- self._path = gr_file
22
- else:
23
- try:
24
- self._path = gr_file.name
25
- except Exception:
26
- try:
27
- self._path = gr_file["name"]
28
- except Exception:
29
- raise ValueError("Unsupported file object from Gradio")
30
 
31
  async def download_to_drive(self, dst_path: str) -> None:
 
32
  loop = asyncio.get_event_loop()
33
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
34
 
35
-
36
- # ============================================================
37
- # Async-safe helper
38
- # ============================================================
39
- def run_async(coro):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
41
- loop = asyncio.get_running_loop()
42
- except RuntimeError:
43
- loop = asyncio.new_event_loop()
44
- asyncio.set_event_loop(loop)
45
- loop = asyncio.get_event_loop()
46
- return loop.run_until_complete(coro)
47
-
48
-
49
- # ============================================================
50
- # Callback functions (used by Gradio & API)
51
- # ============================================================
52
- def text_chat(user_id: Optional[int], text: str, lang: str = "en"):
53
- try:
54
- uid = int(user_id) if user_id else 0
55
- reply = run_async(AI.generate_response(text, uid, lang))
56
- return reply
57
  except Exception as e:
58
- return f"Error: {e}"
59
-
60
 
61
- def voice_process(user_id: Optional[int], audio_file):
 
62
  try:
63
- uid = int(user_id) if user_id else 0
64
- wrapper = GradioFileWrapper(audio_file)
65
- result = run_async(AI.process_voice_message(wrapper, uid))
66
- return json.dumps(result, ensure_ascii=False, indent=2)
67
  except Exception as e:
68
- return f"Error: {e}"
69
-
70
 
71
- def generate_voice(user_id: Optional[int], reply_text: str, fmt: str = "ogg"):
 
 
 
 
72
  try:
73
- uid = int(user_id) if user_id else 0
74
- path = run_async(AI.generate_voice_reply(reply_text, uid, fmt))
75
- return path
76
  except Exception as e:
77
- return None, f"Error: {e}"
78
-
79
 
80
- def image_caption(user_id: Optional[int], image_file):
 
81
  try:
82
- uid = int(user_id) if user_id else 0
83
- wrapper = GradioFileWrapper(image_file)
84
- caption = run_async(AI.process_image_message(wrapper, uid))
85
- return caption
86
  except Exception as e:
87
- return f"Error: {e}"
88
 
89
-
90
- def generate_image(user_id: Optional[int], prompt: str, width: int = 512, height: int = 512, steps: int = 30):
91
  try:
92
- uid = int(user_id) if user_id else 0
93
- path = run_async(AI.generate_image_from_text(prompt, uid, width=width, height=height, steps=steps))
94
- return path
95
  except Exception as e:
96
- return f"Error: {e}"
97
-
98
 
99
- def edit_image(user_id: Optional[int], image_file, mask_file, prompt: str = ""):
 
100
  try:
101
- uid = int(user_id) if user_id else 0
102
- img_w = GradioFileWrapper(image_file)
103
- mask_w = GradioFileWrapper(mask_file) if mask_file else None
104
- path = run_async(AI.edit_image_inpaint(img_w, mask_w, prompt, uid))
105
- return path
106
  except Exception as e:
107
- return f"Error: {e}"
108
-
109
 
110
- def process_video(user_id: Optional[int], video_file):
 
111
  try:
112
- uid = int(user_id) if user_id else 0
113
- wrapper = GradioFileWrapper(video_file)
114
- res = run_async(AI.process_video(wrapper, uid))
115
- return json.dumps(res, ensure_ascii=False, indent=2)
 
 
116
  except Exception as e:
117
- return f"Error: {e}"
118
 
119
-
120
- def process_file(user_id: Optional[int], file_obj):
121
  try:
122
- uid = int(user_id) if user_id else 0
123
- w = GradioFileWrapper(file_obj)
124
- res = run_async(AI.process_file(w, uid))
125
- return json.dumps(res, ensure_ascii=False, indent=2)
126
  except Exception as e:
127
- return f"Error: {e}"
128
-
129
 
130
- def code_complete(user_id: Optional[int], prompt: str, max_tokens: int = 512):
 
131
  try:
132
- uid = int(user_id) if user_id else 0
133
- out = run_async(AI.code_complete(prompt, max_tokens=max_tokens))
134
- return out
135
  except Exception as e:
136
- return f"Error: {e}"
137
-
138
 
139
- # ============================================================
140
- # FastAPI public API
141
- # ============================================================
142
- api = FastAPI()
143
-
144
- @api.post("/api/predict")
145
- async def api_predict(request: Request):
146
  try:
147
- data = await request.json()
148
- user_id = data.get("user_id", 0)
149
- text = data.get("text", "")
150
- lang = data.get("lang", "en")
151
- reply = text_chat(user_id, text, lang)
152
- return {"status": "ok", "reply": reply}
153
  except Exception as e:
154
- return {"status": "error", "message": str(e)}
155
-
156
-
157
- # ============================================================
158
- # Gradio UI
159
- # ============================================================
160
- with gr.Blocks(title="Multimodal Bot (Gradio)") as demo:
161
- gr.Markdown("# 🧠 Multimodal Bot\nInteract via text, voice, images, video, or files.")
162
-
163
- with gr.Tab("💬 Text Chat"):
164
- with gr.Row():
165
- user_id_txt = gr.Textbox(label="User ID (optional)", placeholder="0")
166
- lang_sel = gr.Dropdown(choices=["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
167
- txt_in = gr.Textbox(label="Your message", lines=4)
168
- txt_out = gr.Textbox(label="Bot reply", lines=6)
169
- gr.Button("Send").click(text_chat, [user_id_txt, txt_in, lang_sel], txt_out)
170
-
171
- with gr.Tab("🎤 Voice (Transcribe + Emotion)"):
172
- user_id_voice = gr.Textbox(label="User ID (optional)", placeholder="0")
173
- voice_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or upload voice (.ogg/.wav)")
174
- voice_out = gr.Textbox(label="Result JSON")
175
- gr.Button("Process Voice").click(voice_process, [user_id_voice, voice_in], voice_out)
176
-
177
- with gr.Tab("🔊 Voice Reply (TTS)"):
178
- user_id_vr = gr.Textbox(label="User ID (optional)", placeholder="0")
179
- vr_text = gr.Textbox(label="Text to speak", lines=4)
180
- vr_fmt = gr.Dropdown(choices=["ogg","wav","mp3"], value="ogg", label="Format")
181
- vr_audio = gr.Audio(label="Generated Voice")
182
- gr.Button("Generate Voice").click(generate_voice, [user_id_vr, vr_text, vr_fmt], vr_audio)
183
-
184
- with gr.Tab("🖼️ Image Caption"):
185
- user_id_img = gr.Textbox(label="User ID (optional)", placeholder="0")
186
- img_in = gr.Image(type="filepath", label="Upload Image")
187
- img_out = gr.Textbox(label="Caption")
188
- gr.Button("Caption Image").click(image_caption, [user_id_img, img_in], img_out)
189
-
190
- with gr.Tab("🎨 Image Generate"):
191
- user_id_gi = gr.Textbox(label="User ID (optional)", placeholder="0")
192
- prompt_in = gr.Textbox(label="Prompt", lines=3)
193
- width = gr.Slider(256, 1024, 512, step=64, label="Width")
194
- height = gr.Slider(256, 1024, 512, step=64, label="Height")
195
- steps = gr.Slider(10, 50, 30, step=5, label="Steps")
196
- gen_out = gr.Image(type="filepath", label="Generated image")
197
- gr.Button("Generate").click(generate_image, [user_id_gi, prompt_in, width, height, steps], gen_out)
198
-
199
- with gr.Tab("✏️ Image Edit (Inpaint)"):
200
- user_id_ie = gr.Textbox(label="User ID (optional)", placeholder="0")
201
- edit_img = gr.Image(type="filepath", label="Image to edit")
202
- edit_mask = gr.Image(type="filepath", label="Mask (optional)")
203
- edit_prompt = gr.Textbox(label="Prompt", lines=2)
204
- edit_out = gr.Image(type="filepath", label="Edited image")
205
- gr.Button("Edit Image").click(edit_image, [user_id_ie, edit_img, edit_mask, edit_prompt], edit_out)
206
-
207
- with gr.Tab("🎥 Video"):
208
- user_id_vid = gr.Textbox(label="User ID (optional)", placeholder="0")
209
- vid_in = gr.Video(label="Upload video")
210
- vid_out = gr.Textbox(label="Result JSON")
211
- gr.Button("Process Video").click(process_video, [user_id_vid, vid_in], vid_out)
212
-
213
- with gr.Tab("📄 Files (PDF/DOCX/TXT)"):
214
- user_id_file = gr.Textbox(label="User ID (optional)", placeholder="0")
215
- file_in = gr.File(label="Upload file")
216
- file_out = gr.Textbox(label="Result JSON")
217
- gr.Button("Process File").click(process_file, [user_id_file, file_in], file_out)
218
-
219
- with gr.Tab("💻 Code Generation"):
220
- user_id_code = gr.Textbox(label="User ID (optional)", placeholder="0")
221
- code_prompt = gr.Textbox(label="Code prompt", lines=6)
222
- code_out = gr.Textbox(label="Generated code", lines=12)
223
- gr.Button("Generate Code").click(code_complete, [user_id_code, code_prompt], code_out)
224
-
225
- gr.Markdown("----\nThis Space runs your exact `multimodal_module.py`. First requests may take longer due to model loading.")
226
-
227
-
228
- # ============================================================
229
- # Launch both API + Gradio
230
- # ============================================================
231
- import uvicorn
232
- from threading import Thread
233
-
234
- def start_api():
235
- uvicorn.run(api, host="0.0.0.0", port=8000)
236
-
237
- # Start FastAPI in a separate thread
238
- Thread(target=start_api, daemon=True).start()
239
-
240
- # Launch Gradio
241
- demo.queue()
242
- demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
1
+ # app.py -- HF-ready single-server FastAPI + Gradio mounted app
2
  import os
3
  import shutil
4
  import asyncio
5
+ import inspect
6
  from typing import Optional
7
+ from fastapi import FastAPI, UploadFile, File, Form
8
+ from fastapi.responses import JSONResponse
9
  import gradio as gr
10
+ import uvicorn
11
+
12
+ # Import your real multimodal module
13
  from multimodal_module import MultiModalChatModule
14
 
15
# Instantiate your AI module
# NOTE(review): this runs at import time — if MultiModalChatModule loads
# models in __init__, the first import of this file may be slow. Confirm.
AI = MultiModalChatModule()

# ---------- Helpers ----------
# Scratch directory for uploaded files (ephemeral storage on HF Spaces).
TMP_DIR = "/tmp"
os.makedirs(TMP_DIR, exist_ok=True)
21
+
22
class FileWrapper:
    """Thin wrapper around a filesystem path.

    Keeps the same interface the AI module already expects from the old
    GradioFileWrapper: a ``download_to_drive`` coroutine that copies the
    wrapped file to a destination path.
    """

    def __init__(self, path: str):
        self._path = path

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to *dst_path* without blocking the loop."""
        # Run the blocking copy in the default executor thread pool.
        await asyncio.get_event_loop().run_in_executor(
            None, shutil.copyfile, self._path, dst_path
        )
31
 
32
def save_upload_to_tmp(up, dst_dir: str = "/tmp") -> str:
    """Persist a FastAPI ``UploadFile`` to *dst_dir* and return the saved path.

    Args:
        up: the uploaded file (anything with ``.filename`` and a ``.file``
            binary stream works).
        dst_dir: destination directory; defaults to ``/tmp`` (the module's
            TMP_DIR), created on demand.

    Returns:
        Absolute path of the saved file. An existing file with the same
        name is overwritten.

    Raises:
        ValueError: if the upload is missing or has no usable filename.
    """
    # raise, don't assert: assert statements are stripped under `python -O`
    if up is None or not up.filename:
        raise ValueError("UploadFile missing filename")
    # basename() strips any client-supplied directory components, so a
    # crafted name like "../../etc/passwd" cannot escape dst_dir
    name = os.path.basename(up.filename)
    if not name:
        raise ValueError("UploadFile has an empty filename")
    os.makedirs(dst_dir, exist_ok=True)
    dest = os.path.join(dst_dir, name)
    # stream the body instead of reading the whole upload into memory
    with open(dest, "wb") as f:
        shutil.copyfileobj(up.file, f)
    return dest
40
+
41
async def call_ai(fn, *args, **kwargs):
    """Invoke an AI-module callable without blocking the event loop.

    Async callables are awaited directly; sync callables run in a worker
    thread via ``asyncio.to_thread``.

    Raises:
        TypeError: if *fn* is None — the ``getattr(AI, ..., None)``
            fallbacks at the call sites can hand us None when the module
            lacks the method; fail with a clear message instead of
            "'NoneType' object is not callable".
    """
    if fn is None:
        raise TypeError("AI module does not provide the requested method")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    # to_thread forwards *args/**kwargs itself; no lambda wrapper needed
    return await asyncio.to_thread(fn, *args, **kwargs)
51
+
52
# ---------- FastAPI app ----------
app = FastAPI(title="Multimodal Module API")

# Optional: allow CORS if external web apps will call this
from fastapi.middleware.cors import CORSMiddleware
# NOTE(review): browsers reject a wildcard Access-Control-Allow-Origin
# combined with credentialed requests — if cookies/auth headers are needed,
# list explicit origins instead. Confirm the intended clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # change to specific domains for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
64
+
65
+ # ----------------- API endpoints -----------------
66
+
67
@app.post("/api/predict")
async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """HuggingFace-compatible predict endpoint.

    The form field ``inputs`` carries the user text.
    """
    try:
        handler = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        answer = await call_ai(handler, inputs, int(user_id), lang)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    # HF-style clients expect the result wrapped in a "data" array
    return {"data": [answer]}
 
79
 
80
@app.post("/api/text")
async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """Plain text-chat endpoint; like /api/predict but with named fields."""
    try:
        handler = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        answer = await call_ai(handler, text, int(user_id), lang)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    return {"status": "ok", "reply": answer}
 
87
 
88
@app.post("/api/voice")
async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
    """Process an uploaded audio file (multipart/form-data).

    Returns whatever ``AI.process_voice_message`` produces (a JSON-able
    dict), or an ``{"error": ...}`` payload with HTTP 500 on failure.
    """
    try:
        stored = save_upload_to_tmp(audio_file)
        outcome = await call_ai(getattr(AI, "process_voice_message", None), FileWrapper(stored), int(user_id))
        return JSONResponse(outcome)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
 
99
 
100
@app.post("/api/voice_reply")
async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
    """Synthesize speech for *reply_text*; returns the generated file path."""
    try:
        voice_path = await call_ai(getattr(AI, "generate_voice_reply", None), reply_text, int(user_id), fmt)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    return {"status": "ok", "file": voice_path}
107
 
108
@app.post("/api/image_caption")
async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
    """Caption an uploaded image via ``AI.process_image_message``."""
    try:
        stored = save_upload_to_tmp(image_file)
        text = await call_ai(getattr(AI, "process_image_message", None), FileWrapper(stored), int(user_id))
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    return {"status": "ok", "caption": text}
 
116
 
117
@app.post("/api/generate_image")
async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
    """Generate an image from a text prompt; returns the output file path.

    ``width``/``height``/``steps`` are passed by keyword: the module API is
    ``generate_image_from_text(prompt, uid, width=..., height=..., steps=...)``
    (as called elsewhere in this project), so positional passing risks
    binding them to the wrong parameters.
    """
    try:
        out_path = await call_ai(
            getattr(AI, "generate_image_from_text", None),
            prompt,
            int(user_id),
            width=width,
            height=height,
            steps=steps,
        )
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
 
124
 
125
@app.post("/api/edit_image")
async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
    """Inpaint-edit an uploaded image; the mask upload is optional."""
    try:
        image_wrapper = FileWrapper(save_upload_to_tmp(image_file))
        mask_wrapper = FileWrapper(save_upload_to_tmp(mask_file)) if mask_file else None
        edited = await call_ai(getattr(AI, "edit_image_inpaint", None), image_wrapper, mask_wrapper, prompt, int(user_id))
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    return {"status": "ok", "file": edited}
136
 
137
@app.post("/api/video")
async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
    """Analyze an uploaded video; returns the module's JSON-able result."""
    try:
        stored = save_upload_to_tmp(video_file)
        outcome = await call_ai(getattr(AI, "process_video", None), FileWrapper(stored), int(user_id))
        return JSONResponse(outcome)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
 
145
 
146
@app.post("/api/file")
async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
    """Process an uploaded document (PDF/DOCX/TXT per the UI's old labels)."""
    try:
        stored = save_upload_to_tmp(file_obj)
        outcome = await call_ai(getattr(AI, "process_file", None), FileWrapper(stored), int(user_id))
        return JSONResponse(outcome)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
 
154
 
155
@app.post("/api/code")
async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
    """Code-completion endpoint.

    Calls ``AI.code_complete(prompt, max_tokens=...)`` — the signature used
    elsewhere in this project. The previous call here passed ``int(user_id)``
    as the first positional argument, so the user id (not the prompt) was
    sent to the model.
    """
    try:
        result = await call_ai(getattr(AI, "code_complete", None), prompt, max_tokens=max_tokens)
        # Some modules return string/code, others dict — pass through as-is:
        return {"status": "ok", "code": result}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
163
+
164
+ # ---------- Minimal Gradio UI (mounted) ----------
165
def gradio_text_fn(text, user_id, lang):
    """Gradio click handler: route the message to the AI text method.

    Gradio runs handlers in worker threads, so blocking here (including
    ``asyncio.run``) is safe.
    """
    # Resolve the handler once instead of repeating the nested getattr.
    handler = getattr(AI, "generate_response", getattr(AI, "process_text", None))
    uid = int(user_id or 0)
    if inspect.iscoroutinefunction(handler):
        # async implementation: drive it to completion on a private loop
        return asyncio.run(call_ai(handler, text, uid, lang))
    # sync implementation: call directly
    return handler(text, uid, lang)
172
+
173
# Minimal mounted UI: a single text-chat panel backed by gradio_text_fn.
with gr.Blocks(title="Multimodal Bot (UI)") as demo:
    gr.Markdown("# 🧠 Multimodal Bot — UI")
    with gr.Row():
        txt_uid = gr.Textbox(label="User ID", value="0")
        txt_lang = gr.Dropdown(["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
    inp = gr.Textbox(lines=3, label="Message")
    out = gr.Textbox(lines=6, label="Reply")
    # Inputs are (message, user id, language) — matching gradio_text_fn's signature.
    gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
181
+
182
# Mount Gradio app at root; the /api/* routes registered above keep working.
app = gr.mount_gradio_app(app, demo, path="/")

# ---------- Run server (HF Spaces uses this entrypoint) ----------
if __name__ == "__main__":
    # HF Spaces provides PORT; fall back to Gradio's conventional 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)