Spaces:
Running
Running
import asyncio
import contextlib
import inspect
import os
import shutil
import socket
import uuid
from typing import Optional

import gradio as gr
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from multimodal_module import MultiModalChatModule
# Shared AI backend; every API endpoint and the Gradio UI look up their
# handler methods on this one instance.
AI = MultiModalChatModule()
# Directory where uploads are staged before being handed to the AI module.
# Override with the UPLOAD_TMP_DIR environment variable; the default keeps the
# previous hard-coded "/tmp" (which does not exist on e.g. Windows).
TMP_DIR = os.environ.get("UPLOAD_TMP_DIR", "/tmp")
os.makedirs(TMP_DIR, exist_ok=True)
# --- Helpers ---
class FileWrapper:
    """Wrap a local file path behind an async ``download_to_drive`` method.

    Uploaded files are passed to the AI handlers through this object; the only
    operation it offers is copying the wrapped file to a destination path.
    """

    def __init__(self, path: str):
        # Path of the already-saved local file that copies are made from.
        self._path = path

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._path!r})"

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to ``dst_path`` without blocking the event loop.

        The copy runs in a worker thread via ``asyncio.to_thread`` (as
        ``call_ai`` does for sync handlers) instead of fetching the loop with
        ``asyncio.get_event_loop()`` from inside a coroutine.
        """
        await asyncio.to_thread(shutil.copyfile, self._path, dst_path)
async def save_upload_to_tmp(up: UploadFile) -> str:
    """Save an uploaded file under TMP_DIR and return the saved path.

    The client-supplied filename is untrusted. It is reduced to its base name
    so values like ``../../etc/passwd`` or ``/etc/passwd`` cannot escape
    TMP_DIR (``os.path.join`` discards TMP_DIR for an absolute second
    argument). It is then prefixed with a random token so concurrent uploads
    that share a name do not overwrite each other. The original name, and so
    its extension, is kept at the end.

    Raises:
        ValueError: if no file, or a file without a filename, was uploaded.
    """
    if not up or not up.filename:
        raise ValueError("No file uploaded")
    # Normalise Windows-style separators before taking the base name.
    base_name = os.path.basename(up.filename.replace("\\", "/"))
    dest = os.path.join(TMP_DIR, f"{uuid.uuid4().hex}_{base_name}")
    data = await up.read()

    def _write() -> None:
        with open(dest, "wb") as f:
            f.write(data)

    # Large uploads (e.g. video) would otherwise block the event loop on disk I/O.
    await asyncio.to_thread(_write)
    return dest
async def call_ai(fn, *args, **kwargs):
    """Invoke an AI handler that may be sync or async and return its result.

    Coroutine functions are awaited directly. Any other callable runs in a
    worker thread (``asyncio.to_thread``) so blocking work does not stall the
    event loop. If that callable returns an awaitable (for instance a plain
    wrapper around an async method), the awaitable is awaited as well instead
    of being handed back unresolved.

    Raises:
        AttributeError: if ``fn`` is None, i.e. the AI object lacks the method.
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    result = await asyncio.to_thread(fn, *args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result
def get_free_port():
    """Return a TCP port number that was free at the moment of the call.

    Binds to port 0 so the OS assigns an unused port, reads it back, and
    closes the socket. The ``with`` block closes the socket even if ``bind``
    fails. Another process may still take the port before the caller binds
    it, so treat the result as a best-effort hint.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
# --- FastAPI app ---
app = FastAPI(title="Multimodal Module API")
# Allowed CORS origins as a comma-separated list, e.g.
# "https://example.com,https://admin.example.com". The default "*" keeps the
# previous allow-any-origin behaviour; set CORS_ALLOW_ORIGINS in production.
# NOTE(review): "*" is combined with allow_credentials=True here; Starlette's
# CORSMiddleware docs say wildcards should not be used when credentials are
# allowed. Restrict origins before relying on cookies or auth headers.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- API Endpoints ---
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/text")
async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """Return the AI's text reply for ``text``.

    Calls ``AI.generate_response`` when it exists, otherwise
    ``AI.process_text``, as ``(text, user_id, lang)``. Responds with
    ``{"status": "ok", "reply": ...}``. Any failure, including a missing
    handler, becomes ``{"error": ...}`` with HTTP 500.
    """
    try:
        fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        # Optional[int] admits None; treat it like the default user 0.
        reply = await call_ai(fn, text, int(user_id or 0), lang)
        return {"status": "ok", "reply": reply}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/voice")
async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
    """Run ``AI.process_voice_message`` on an uploaded audio file.

    The upload is staged under TMP_DIR and passed as a ``FileWrapper`` with
    the user id. The handler's result is returned as JSON. Failures become
    ``{"error": ...}`` with HTTP 500. The staged copy is removed afterwards.
    """
    path = None
    try:
        path = await save_upload_to_tmp(audio_file)
        fn = getattr(AI, "process_voice_message", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # NOTE(review): assumes the handler is done with the staged file once
        # it returns (FileWrapper only exposes it through a copy).
        if path is not None:
            with contextlib.suppress(OSError):
                os.remove(path)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/voice_reply")
async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
    """Call ``AI.generate_voice_reply(reply_text, user_id, fmt)``.

    Responds with ``{"status": "ok", "file": <handler result>}``. Failures,
    including a missing handler, become ``{"error": ...}`` with HTTP 500.
    """
    try:
        fn = getattr(AI, "generate_voice_reply", None)
        # Optional[int] admits None; treat it like the default user 0.
        result = await call_ai(fn, reply_text, int(user_id or 0), fmt)
        return {"status": "ok", "file": result}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/image_caption")
async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
    """Run ``AI.process_image_message`` on an uploaded image.

    The upload is staged under TMP_DIR and passed as a ``FileWrapper`` with
    the user id. Responds with ``{"status": "ok", "caption": <result>}``.
    Failures become ``{"error": ...}`` with HTTP 500. The staged copy is
    removed afterwards.
    """
    path = None
    try:
        path = await save_upload_to_tmp(image_file)
        fn = getattr(AI, "process_image_message", None)
        caption = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return {"status": "ok", "caption": caption}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # NOTE(review): assumes the handler is done with the staged file once
        # it returns (FileWrapper only exposes it through a copy).
        if path is not None:
            with contextlib.suppress(OSError):
                os.remove(path)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/generate_image")
async def api_generate_image(
    user_id: Optional[int] = Form(0),
    prompt: str = Form(...),
    width: int = Form(512),
    height: int = Form(512),
    steps: int = Form(30),
):
    """Call ``AI.generate_image_from_text(prompt, user_id, width, height, steps)``.

    Responds with ``{"status": "ok", "file": <handler result>}``. Failures,
    including a missing handler, become ``{"error": ...}`` with HTTP 500.
    """
    try:
        fn = getattr(AI, "generate_image_from_text", None)
        # Optional[int] admits None; treat it like the default user 0.
        out_path = await call_ai(fn, prompt, int(user_id or 0), width, height, steps)
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/edit_image")
async def api_edit_image(
    user_id: Optional[int] = Form(0),
    image_file: UploadFile = File(...),
    mask_file: Optional[UploadFile] = File(None),
    prompt: str = Form(""),
):
    """Call ``AI.edit_image_inpaint(image, mask, prompt, user_id)`` on uploads.

    The image, and the mask when one is supplied, are staged under TMP_DIR and
    passed as ``FileWrapper`` objects; a missing mask is passed as ``None``.
    Responds with ``{"status": "ok", "file": <handler result>}``. Failures
    become ``{"error": ...}`` with HTTP 500. Staged copies are removed
    afterwards.
    """
    img_path = None
    mask_path = None
    try:
        img_path = await save_upload_to_tmp(image_file)
        # A mask part sent without a filename (e.g. an empty file input) is
        # treated as "no mask" rather than failing the whole request.
        if mask_file and mask_file.filename:
            mask_path = await save_upload_to_tmp(mask_file)
        fn = getattr(AI, "edit_image_inpaint", None)
        out_path = await call_ai(
            fn,
            FileWrapper(img_path),
            FileWrapper(mask_path) if mask_path else None,
            prompt,
            int(user_id or 0),
        )
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # NOTE(review): assumes the handler is done with the staged files once
        # it returns (FileWrapper only exposes them through a copy).
        for staged in (img_path, mask_path):
            if staged is not None:
                with contextlib.suppress(OSError):
                    os.remove(staged)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/video")
async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
    """Run ``AI.process_video`` on an uploaded video file.

    The upload is staged under TMP_DIR and passed as a ``FileWrapper`` with
    the user id. The handler's result is returned as JSON. Failures become
    ``{"error": ...}`` with HTTP 500. The staged copy is removed afterwards.
    """
    path = None
    try:
        path = await save_upload_to_tmp(video_file)
        fn = getattr(AI, "process_video", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # NOTE(review): assumes the handler is done with the staged file once
        # it returns (FileWrapper only exposes it through a copy).
        if path is not None:
            with contextlib.suppress(OSError):
                os.remove(path)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/file")
async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
    """Run ``AI.process_file`` on an arbitrary uploaded file.

    The upload is staged under TMP_DIR and passed as a ``FileWrapper`` with
    the user id. The handler's result is returned as JSON. Failures become
    ``{"error": ...}`` with HTTP 500. The staged copy is removed afterwards.
    """
    path = None
    try:
        path = await save_upload_to_tmp(file_obj)
        fn = getattr(AI, "process_file", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # NOTE(review): assumes the handler is done with the staged file once
        # it returns (FileWrapper only exposes it through a copy).
        if path is not None:
            with contextlib.suppress(OSError):
                os.remove(path)
# Route path mirrors the handler name; without a route decorator FastAPI never
# served this handler. NOTE(review): adjust if clients expect another path.
@app.post("/api/code")
async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
    """Complete code for ``prompt`` via ``AI.code_complete``.

    Two handler call forms are supported: ``(user_id, prompt, max_tokens)``
    and ``(prompt, max_tokens=...)``. The form is picked by checking which
    argument list binds to the handler's signature. Responds with
    ``{"status": "ok", "code": <result>}``. Failures become
    ``{"error": ...}`` with HTTP 500.
    """
    try:
        fn = getattr(AI, "code_complete", None)
        uid = int(user_id or 0)
        # Decide up front instead of retrying on TypeError: a TypeError raised
        # *inside* the handler used to re-run the model with different
        # arguments and hide the real error behind a misleading one.
        with_user_id = True
        if fn is not None:
            try:
                inspect.signature(fn).bind(uid, prompt, max_tokens)
            except TypeError:
                # The handler cannot accept (user_id, prompt, max_tokens).
                with_user_id = False
            except ValueError:
                # No introspectable signature: keep the first-choice form.
                pass
        if with_user_id:
            result = await call_ai(fn, uid, prompt, max_tokens)
        else:
            result = await call_ai(fn, prompt, max_tokens=max_tokens)
        return {"status": "ok", "code": result}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# --- Gradio UI ---
def gradio_text_fn(text, user_id, lang):
    """Gradio click handler: return the AI's text reply for ``text``.

    Uses ``AI.generate_response`` when it exists, otherwise
    ``AI.process_text``. Returns an error string if neither is available.

    Gradio runs synchronous handlers in a worker thread. In such a thread
    ``asyncio.get_event_loop()`` has no loop to return and raises
    RuntimeError. ``asyncio.run`` instead creates and closes a private loop
    for this one call.
    NOTE(review): async AI methods that depend on resources bound to the
    server's main loop would not work from this private loop.
    """
    fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
    if fn is None:
        return "Error: text handler not implemented"
    return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))
# Text-chat UI: user id + language selector, a message box, a reply box,
# and a Send button wired to gradio_text_fn.
with gr.Blocks(title="Multimodal Bot (UI)") as demo:
    gr.Markdown("# 🧠 Multimodal Bot — UI")
    # User id and reply language sit side by side.
    with gr.Row():
        uid = gr.Textbox(label="User ID", value="0")
        lang = gr.Dropdown(
            choices=["en", "zh", "ja", "ko", "es", "fr", "de", "it"],
            value="en",
            label="Language",
        )
    inp = gr.Textbox(lines=3, label="Message")
    out = gr.Textbox(lines=6, label="Reply")
    send_btn = gr.Button("Send")
    # Inputs are passed to gradio_text_fn positionally as (text, user_id, lang).
    send_btn.click(gradio_text_fn, inputs=[inp, uid, lang], outputs=out)
# Mount the Gradio UI (`demo`) onto the FastAPI app at the site root and
# rebind `app` to the application that mount_gradio_app returns.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
# --- Local Run Only ---
if __name__ == "__main__":
    import uvicorn

    # An empty PORT is treated like an unset one, and a free port is probed
    # only when no PORT is provided (previously int("") crashed, and the probe
    # always ran).
    port = int(os.environ.get("PORT") or get_free_port())
    # reload=True needs an import string. Derive the module name from this
    # file instead of hard-coding "app", so the script works under any name.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=port, reload=True)