Princeaka committed on
Commit
02f51d7
·
verified ·
1 Parent(s): b6444d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -164
app.py CHANGED
@@ -1,170 +1,194 @@
1
- # app.py — FastAPI + Gradio (External API + UI)
2
  import os
3
- os.environ["CUDA_VISIBLE_DEVICES"] = "" # Disable GPU
4
- os.environ["MPLBACKEND"] = "Agg" # Non-interactive matplotlib
5
- os.environ["IMAGEIO_FFMPEG_EXE"] = "/usr/bin/ffmpeg" # Explicit path
6
- import shutil
7
  import asyncio
8
- import inspect
9
- from typing import Optional
10
-
11
- from fastapi import FastAPI, UploadFile, File, Form
12
- from fastapi.middleware.cors import CORSMiddleware
13
- from fastapi.responses import JSONResponse
14
  import gradio as gr
15
-
16
  from multimodal_module import MultiModalChatModule
17
 
18
- # Instantiate AI module
19
- AI = MultiModalChatModule()
20
-
21
- TMP_DIR = "/tmp"
22
- os.makedirs(TMP_DIR, exist_ok=True)
23
-
24
- # --- File wrapper ---
25
- class FileWrapper:
26
- def __init__(self, path: str):
27
- self._path = path
28
- async def download_to_drive(self, dst_path: str):
29
- loop = asyncio.get_event_loop()
30
- await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
31
-
32
- # --- Save uploaded file ---
33
- async def save_upload(up: UploadFile) -> str:
34
- if not up or not up.filename:
35
- raise ValueError("No file uploaded")
36
- dest = os.path.join(TMP_DIR, up.filename)
37
- data = await up.read()
38
- with open(dest, "wb") as f:
39
- f.write(data)
40
- return dest
41
-
42
- # --- Call AI (sync or async) ---
43
- async def call_ai(fn, *args, **kwargs):
44
- if fn is None:
45
- raise AttributeError("Requested AI method not implemented")
46
- if inspect.iscoroutinefunction(fn):
47
- return await fn(*args, **kwargs)
48
- return await asyncio.to_thread(lambda: fn(*args, **kwargs))
49
-
50
- # === FASTAPI APP ===
51
- app = FastAPI(title="Multimodal API")
52
-
53
- app.add_middleware(
54
- CORSMiddleware,
55
- allow_origins=["*"], # change for production
56
- allow_credentials=True,
57
- allow_methods=["*"],
58
- allow_headers=["*"],
59
- )
60
-
61
- # --- API Endpoints ---
62
- @app.post("/api/text")
63
- async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
64
- try:
65
- fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
66
- reply = await call_ai(fn, text, int(user_id), lang)
67
- return {"status": "ok", "reply": reply}
68
- except Exception as e:
69
- return JSONResponse({"error": str(e)}, status_code=500)
70
-
71
- @app.post("/api/voice")
72
- async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
73
- try:
74
- path = await save_upload(audio_file)
75
- fn = getattr(AI, "process_voice_message", None)
76
- result = await call_ai(fn, FileWrapper(path), int(user_id))
77
- return {"status": "ok", "result": result}
78
- except Exception as e:
79
- return JSONResponse({"error": str(e)}, status_code=500)
80
-
81
- @app.post("/api/voice_reply")
82
- async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
83
- try:
84
- fn = getattr(AI, "generate_voice_reply", None)
85
- result = await call_ai(fn, reply_text, int(user_id), fmt)
86
- return {"status": "ok", "file": result}
87
- except Exception as e:
88
- return JSONResponse({"error": str(e)}, status_code=500)
89
-
90
- @app.post("/api/image_caption")
91
- async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
92
- try:
93
- path = await save_upload(image_file)
94
- fn = getattr(AI, "process_image_message", None)
95
- caption = await call_ai(fn, FileWrapper(path), int(user_id))
96
- return {"status": "ok", "caption": caption}
97
- except Exception as e:
98
- return JSONResponse({"error": str(e)}, status_code=500)
99
-
100
- @app.post("/api/generate_image")
101
- async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
102
- try:
103
- fn = getattr(AI, "generate_image_from_text", None)
104
- out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
105
- return {"status": "ok", "file": out_path}
106
- except Exception as e:
107
- return JSONResponse({"error": str(e)}, status_code=500)
108
-
109
- @app.post("/api/edit_image")
110
- async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
111
- try:
112
- img_path = await save_upload(image_file)
113
- mask_path = None
114
- if mask_file:
115
- mask_path = await save_upload(mask_file)
116
- fn = getattr(AI, "edit_image_inpaint", None)
117
- out_path = await call_ai(fn, FileWrapper(img_path), FileWrapper(mask_path) if mask_path else None, prompt, int(user_id))
118
- return {"status": "ok", "file": out_path}
119
- except Exception as e:
120
- return JSONResponse({"error": str(e)}, status_code=500)
121
-
122
- @app.post("/api/video")
123
- async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
124
- try:
125
- path = await save_upload(video_file)
126
- fn = getattr(AI, "process_video", None)
127
- result = await call_ai(fn, FileWrapper(path), int(user_id))
128
- return {"status": "ok", "result": result}
129
- except Exception as e:
130
- return JSONResponse({"error": str(e)}, status_code=500)
131
-
132
- @app.post("/api/file")
133
- async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
134
- try:
135
- path = await save_upload(file_obj)
136
- fn = getattr(AI, "process_file", None)
137
- result = await call_ai(fn, FileWrapper(path), int(user_id))
138
- return {"status": "ok", "result": result}
139
- except Exception as e:
140
- return JSONResponse({"error": str(e)}, status_code=500)
141
-
142
- @app.post("/api/code")
143
- async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
144
- try:
145
- fn = getattr(AI, "code_complete", None)
146
- try:
147
- result = await call_ai(fn, int(user_id), prompt, max_tokens)
148
- except TypeError:
149
- result = await call_ai(fn, prompt, max_tokens=max_tokens)
150
- return {"status": "ok", "code": result}
151
- except Exception as e:
152
- return JSONResponse({"error": str(e)}, status_code=500)
153
-
154
- # === GRADIO UI ===
155
- def gradio_text_fn(text, user_id, lang):
156
- fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
157
- loop = asyncio.get_event_loop()
158
- return loop.run_until_complete(call_ai(fn, text, int(user_id or 0), lang))
159
-
160
- with gr.Blocks(title="Multimodal Bot") as demo:
161
- gr.Markdown("# 🧠 Multimodal Bot — UI")
162
  with gr.Row():
163
- uid = gr.Textbox(label="User ID", value="0")
164
- lang = gr.Dropdown(["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
165
- inp = gr.Textbox(lines=3, label="Message")
166
- out = gr.Textbox(lines=6, label="Reply")
167
- gr.Button("Send").click(gradio_text_fn, [inp, uid, lang], out)
168
-
169
- # Mount Gradio under /ui
170
- app = gr.mount_gradio_app(app, demo, path="/ui")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
2
  import asyncio
3
+ import shutil
 
 
 
 
 
4
  import gradio as gr
 
5
  from multimodal_module import MultiModalChatModule
6
 
7
+ # Optional: keep model cache persistent across restarts
8
+ os.makedirs("model_cache", exist_ok=True)
9
+ os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
10
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
11
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
12
+
13
+ mm = MultiModalChatModule()
14
+
15
+ # --- Small wrapper so Gradio file paths work with your module's .download_to_drive API ---
16
+ class _GradioFile:
17
+ def __init__(self, path: str | None):
18
+ self.path = path
19
+ async def download_to_drive(self, dest: str):
20
+ if not self.path:
21
+ raise ValueError("No file path provided.")
22
+ shutil.copy(self.path, dest)
23
+
24
+ # -------------------------
25
+ # TEXT CHAT
26
+ # -------------------------
27
+ async def chat_fn(user_id: str, message: str, lang: str):
28
+ uid = int(user_id or "1")
29
+ message = message or ""
30
+ lang = (lang or "en").strip()
31
+ return await mm.generate_response(message, uid, lang=lang)
32
+
33
+ # -------------------------
34
+ # TTS (generate_voice_reply)
35
+ # -------------------------
36
+ async def tts_fn(user_id: str, text: str, fmt: str):
37
+ uid = int(user_id or "1")
38
+ out_path = await mm.generate_voice_reply(text or "", user_id=uid, fmt=fmt)
39
+ # Gradio expects the file path for Audio/Image outputs
40
+ return out_path
41
+
42
+ # -------------------------
43
+ # VOICE -> TEXT (+emotion)
44
+ # -------------------------
45
+ async def voice_fn(user_id: str, audio_path: str | None):
46
+ uid = int(user_id or "1")
47
+ if not audio_path:
48
+ return {"text": "", "language": "en", "emotion": "no_audio", "is_speech": False}
49
+ result = await mm.process_voice_message(_GradioFile(audio_path), user_id=uid)
50
+ return result
51
+
52
+ # -------------------------
53
+ # IMAGE: caption
54
+ # -------------------------
55
+ async def img_caption_fn(user_id: str, image_path: str | None):
56
+ uid = int(user_id or "1")
57
+ if not image_path:
58
+ return "No image provided."
59
+ caption = await mm.process_image_message(_GradioFile(image_path), user_id=uid)
60
+ return caption
61
+
62
+ # -------------------------
63
+ # IMAGE: text2img
64
+ # -------------------------
65
+ async def img_generate_fn(user_id: str, prompt: str, width: int, height: int, steps: int):
66
+ uid = int(user_id or "1")
67
+ img_path = await mm.generate_image_from_text(prompt or "", user_id=uid, width=width, height=height, steps=steps)
68
+ return img_path
69
+
70
+ # -------------------------
71
+ # IMAGE: inpaint
72
+ # -------------------------
73
+ async def img_inpaint_fn(user_id: str, image_path: str | None, mask_path: str | None, prompt: str):
74
+ uid = int(user_id or "1")
75
+ if not image_path:
76
+ return None
77
+ out_path = await mm.edit_image_inpaint(
78
+ _GradioFile(image_path),
79
+ _GradioFile(mask_path) if mask_path else None,
80
+ prompt=prompt or "",
81
+ user_id=uid,
82
+ )
83
+ return out_path
84
+
85
+ # -------------------------
86
+ # VIDEO: process
87
+ # -------------------------
88
+ async def video_fn(user_id: str, video_path: str | None, max_frames: int):
89
+ uid = int(user_id or "1")
90
+ if not video_path:
91
+ return {"duration": 0, "fps": 0, "transcription": "", "captions": []}
92
+ result = await mm.process_video(_GradioFile(video_path), user_id=uid, max_frames=max_frames)
93
+ return result
94
+
95
+ # -------------------------
96
+ # FILE: process (pdf/docx/txt/csv)
97
+ # -------------------------
98
+ async def file_fn(user_id: str, file_path: str | None):
99
+ uid = int(user_id or "1")
100
+ if not file_path:
101
+ return {"summary": "", "length": 0, "type": ""}
102
+ result = await mm.process_file(_GradioFile(file_path), user_id=uid)
103
+ return result
104
+
105
+ # -------------------------
106
+ # CODE: complete
107
+ # -------------------------
108
+ async def code_complete_fn(prompt: str, max_tokens: int, temperature: float):
109
+ return await mm.code_complete(prompt or "", max_tokens=max_tokens, temperature=temperature)
110
+
111
+ # -------------------------
112
+ # CODE: execute (DANGEROUS)
113
+ # -------------------------
114
+ async def code_exec_fn(code: str, timeout: int):
115
+ # Your module already time-limits; still, treat as unsafe
116
+ result = await mm.execute_python_code(code or "", timeout=timeout)
117
+ # Present nicely
118
+ if "error" in result:
119
+ return f"ERROR: {result['error']}"
120
+ out = []
121
+ if result.get("stdout"):
122
+ out.append(f"[stdout]\n{result['stdout']}")
123
+ if result.get("stderr"):
124
+ out.append(f"[stderr]\n{result['stderr']}")
125
+ return "\n".join(out).strip() or "(no output)"
126
+
127
+ with gr.Blocks(title="Multimodal Space") as demo:
128
+ gr.Markdown("# 🔮 Multimodal Space")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  with gr.Row():
130
+ user_id = gr.Textbox(label="User ID", value="1", scale=1)
131
+ lang = gr.Textbox(label="Language code (e.g., en, fr, es)", value="en", scale=1)
132
+
133
+ with gr.Tab("💬 Chat"):
134
+ msg_in = gr.Textbox(label="Message")
135
+ msg_out = gr.Textbox(label="Response", interactive=False)
136
+ gr.Button("Send").click(chat_fn, [user_id, msg_in, lang], msg_out)
137
+
138
+ with gr.Tab("🗣️ Voice → Text (+ Emotion)"):
139
+ audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload/record voice (ogg/wav/mp3)")
140
+ voice_json = gr.JSON(label="Result")
141
+ gr.Button("Transcribe & Analyze").click(voice_fn, [user_id, audio_in], voice_json)
142
+
143
+ with gr.Tab("🔊 TTS"):
144
+ tts_text = gr.Textbox(label="Text to speak")
145
+ tts_fmt = gr.Dropdown(choices=["ogg", "wav", "mp3"], value="ogg", label="Format")
146
+ tts_audio = gr.Audio(label="Generated Audio", interactive=False)
147
+ gr.Button("Generate Voice Reply").click(tts_fn, [user_id, tts_text, tts_fmt], tts_audio)
148
+
149
+ with gr.Tab("🖼️ Image Caption"):
150
+ img_in = gr.Image(type="filepath", label="Image")
151
+ caption_out = gr.Textbox(label="Caption", interactive=False)
152
+ gr.Button("Caption").click(img_caption_fn, [user_id, img_in], caption_out)
153
+
154
+ with gr.Tab("🎨 Text → Image"):
155
+ ti_prompt = gr.Textbox(label="Prompt")
156
+ with gr.Row():
157
+ ti_w = gr.Slider(256, 768, value=512, step=64, label="Width")
158
+ ti_h = gr.Slider(256, 768, value=512, step=64, label="Height")
159
+ ti_steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
160
+ ti_out = gr.Image(label="Generated Image", interactive=False, type="filepath")
161
+ gr.Button("Generate").click(img_generate_fn, [user_id, ti_prompt, ti_w, ti_h, ti_steps], ti_out)
162
+
163
+ with gr.Tab("🩹 Inpaint"):
164
+ base_img = gr.Image(type="filepath", label="Base image")
165
+ mask_img = gr.Image(type="filepath", label="Mask (white = keep, black = edit)", optional=True)
166
+ inpaint_prompt = gr.Textbox(label="Prompt")
167
+ inpaint_out = gr.Image(label="Edited Image", interactive=False, type="filepath")
168
+ gr.Button("Inpaint").click(img_inpaint_fn, [user_id, base_img, mask_img, inpaint_prompt], inpaint_out)
169
+
170
+ with gr.Tab("🎞️ Video"):
171
+ vid_in = gr.Video(label="Video file")
172
+ max_frames = gr.Slider(1, 12, value=4, step=1, label="Max keyframes to sample")
173
+ vid_json = gr.JSON(label="Result (duration/fps/transcript/captions)")
174
+ gr.Button("Process Video").click(video_fn, [user_id, vid_in, max_frames], vid_json)
175
+
176
+ with gr.Tab("📄 File"):
177
+ file_in = gr.File(label="Upload file (pdf/docx/txt/csv)", type="filepath")
178
+ file_json = gr.JSON(label="Summary")
179
+ gr.Button("Process File").click(file_fn, [user_id, file_in], file_json)
180
+
181
+ with gr.Tab("👨‍💻 Code"):
182
+ cc_prompt = gr.Textbox(label="Completion prompt")
183
+ cc_tokens = gr.Slider(16, 1024, value=256, step=16, label="Max tokens")
184
+ cc_temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
185
+ cc_out = gr.Code(label="Completion")
186
+ gr.Button("Complete").click(code_complete_fn, [cc_prompt, cc_tokens, cc_temp], cc_out)
187
+
188
+ ce_code = gr.Code(label="Execute Python (sandboxed, time-limited)")
189
+ ce_timeout = gr.Slider(1, 10, value=5, step=1, label="Timeout (s)")
190
+ ce_out = gr.Code(label="Exec output")
191
+ gr.Button("Run Code").click(code_exec_fn, [ce_code, ce_timeout], ce_out)
192
+
193
+ # Make API-callable and Space-visible
194
+ demo.queue(concurrency_count=2, max_size=32).launch(server_name="0.0.0.0")