Spaces:
Running
Running
# app.py — FastAPI + Gradio (External API + UI)
import asyncio
import inspect
import logging
import os
import shutil
import uuid
from typing import Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import gradio as gr

from multimodal_module import MultiModalChatModule
# One backend instance shared by every HTTP route and by the Gradio UI.
AI = MultiModalChatModule()

# Uploads are written into this directory before being handed to the backend;
# create it up front so the first write into it cannot fail on a missing directory.
TMP_DIR = "/tmp"
os.makedirs(name=TMP_DIR, exist_ok=True)
# --- File wrapper ---
class FileWrapper:
    """Expose a local file through an async ``download_to_drive`` method.

    Handlers in this module pass instances of this class to the AI backend
    instead of raw paths; ``download_to_drive(dst)`` copies the wrapped file
    to ``dst``.
    """

    def __init__(self, path: str):
        self._path = path  # source file on local disk

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._path!r})"

    async def download_to_drive(self, dst_path: str):
        """Copy the wrapped file to *dst_path* without blocking the event loop.

        Copying a file onto itself is treated as a no-op rather than raising
        ``shutil.SameFileError``.
        """
        try:
            # asyncio.to_thread runs the blocking copy in a worker thread.
            await asyncio.to_thread(shutil.copyfile, self._path, dst_path)
        except shutil.SameFileError:
            pass  # the content is already at the destination
# --- Save uploaded file ---
async def save_upload(up: UploadFile) -> str:
    """Write an uploaded file into ``TMP_DIR`` and return the stored path.

    The client-supplied filename is untrusted. It is reduced to its final path
    component, so values like ``../../x`` or ``/etc/x`` cannot escape
    ``TMP_DIR``. It is also prefixed with a random hex token, so concurrent
    uploads sharing a name do not overwrite each other. The original extension
    is kept.

    NOTE(review): nothing visible in this module deletes these files after use;
    consider cleaning them up once the backend call has finished.

    Raises:
        ValueError: if no file (or a file without a name) was uploaded.
    """
    if not up or not up.filename:
        raise ValueError("No file uploaded")
    # Treat Windows separators as separators too, then keep only the last component.
    safe_name = os.path.basename(up.filename.replace("\\", "/")) or "upload"
    dest = os.path.join(TMP_DIR, f"{uuid.uuid4().hex}_{safe_name}")
    data = await up.read()

    def _write() -> None:
        with open(dest, "wb") as f:
            f.write(data)

    # Disk I/O runs in a worker thread so large uploads do not stall the event loop.
    await asyncio.to_thread(_write)
    return dest
# --- Call AI (sync or async) ---
async def call_ai(fn, *args, **kwargs):
    """Invoke an AI backend callable with the given arguments and return its result.

    Coroutine functions are awaited directly. Any other callable runs in a
    worker thread via ``asyncio.to_thread``, so blocking work does not stall
    the event loop. If that call hands back an awaitable (for example, an async
    method wrapped by a plain decorator), the awaitable is awaited too.

    Raises:
        AttributeError: if *fn* is ``None`` (the backend lacks the method).
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    result = await asyncio.to_thread(fn, *args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result
# === FASTAPI APP ===
app = FastAPI(title="Multimodal API")

# Allowed CORS origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable, so production can restrict them without a code change. When it is unset
# or empty, any origin ("*") is allowed.
# NOTE(review): the FastAPI CORS docs say allow_origins, allow_methods and
# allow_headers must be listed explicitly (not ["*"]) when allow_credentials=True;
# set explicit values if browser clients need to send credentials.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- API Endpoints ---
# NOTE(review): the /api/text path is derived from the handler name; confirm it
# matches what external clients call.
@app.post("/api/text")
async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """Return the AI's text reply to *text*.

    Uses ``AI.generate_response`` when available, otherwise ``AI.process_text``;
    either one is called as ``(text, user_id, lang)``.
    Responds ``{"status": "ok", "reply": ...}``, or HTTP 500 ``{"error": ...}``.
    """
    try:
        # Lazy fallback: process_text is only looked up if generate_response is missing.
        fn = getattr(AI, "generate_response", None) or getattr(AI, "process_text", None)
        reply = await call_ai(fn, text, int(user_id or 0), lang)
        return {"status": "ok", "reply": reply}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_text failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/voice")
async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
    """Process an uploaded audio file with ``AI.process_voice_message``.

    The upload is saved via ``save_upload`` and passed to the backend as a
    ``FileWrapper`` together with the user id.
    Responds ``{"status": "ok", "result": ...}``, or HTTP 500 ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "process_voice_message", None)
        path = await save_upload(audio_file)
        result = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return {"status": "ok", "result": result}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_voice failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/voice_reply")
async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
    """Synthesise *reply_text* with ``AI.generate_voice_reply(text, user_id, fmt)``.

    Responds ``{"status": "ok", "file": <backend result>}``, or HTTP 500
    ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "generate_voice_reply", None)
        result = await call_ai(fn, reply_text, int(user_id or 0), fmt)
        return {"status": "ok", "file": result}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_voice_reply failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/image_caption")
async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
    """Caption an uploaded image with ``AI.process_image_message``.

    The upload is saved via ``save_upload`` and passed as a ``FileWrapper``
    together with the user id.
    Responds ``{"status": "ok", "caption": ...}``, or HTTP 500 ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "process_image_message", None)
        path = await save_upload(image_file)
        caption = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return {"status": "ok", "caption": caption}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_image_caption failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/generate_image")
async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
    """Generate an image from *prompt* with ``AI.generate_image_from_text``.

    Called as ``(prompt, user_id, width, height, steps)``.
    Responds ``{"status": "ok", "file": <backend result>}``, or HTTP 500
    ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "generate_image_from_text", None)
        out_path = await call_ai(fn, prompt, int(user_id or 0), width, height, steps)
        return {"status": "ok", "file": out_path}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_generate_image failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/edit_image")
async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
    """Inpaint an uploaded image with ``AI.edit_image_inpaint``.

    Called as ``(image, mask_or_None, prompt, user_id)``; image and mask are
    passed as ``FileWrapper`` objects. The mask is optional, and an empty mask
    field (no filename) is treated as "no mask" instead of failing the request.
    Responds ``{"status": "ok", "file": <backend result>}``, or HTTP 500
    ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "edit_image_inpaint", None)
        img_path = await save_upload(image_file)
        mask = None
        if mask_file is not None and mask_file.filename:
            mask = FileWrapper(await save_upload(mask_file))
        out_path = await call_ai(fn, FileWrapper(img_path), mask, prompt, int(user_id or 0))
        return {"status": "ok", "file": out_path}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_edit_image failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/video")
async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
    """Process an uploaded video with ``AI.process_video``.

    The upload is saved via ``save_upload`` and passed as a ``FileWrapper``
    together with the user id.
    Responds ``{"status": "ok", "result": ...}``, or HTTP 500 ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "process_video", None)
        path = await save_upload(video_file)
        result = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return {"status": "ok", "result": result}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_video failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/file")
async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
    """Process an arbitrary uploaded file with ``AI.process_file``.

    The upload is saved via ``save_upload`` and passed as a ``FileWrapper``
    together with the user id.
    Responds ``{"status": "ok", "result": ...}``, or HTTP 500 ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "process_file", None)
        path = await save_upload(file_obj)
        result = await call_ai(fn, FileWrapper(path), int(user_id or 0))
        return {"status": "ok", "result": result}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_file failed")
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/api/code")
async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
    """Run code completion through ``AI.code_complete``.

    Two backend calling conventions are supported: ``(user_id, prompt,
    max_tokens)`` and ``(prompt, max_tokens=...)``. The convention is chosen by
    binding the arguments to the method's signature *before* calling it. A
    ``TypeError`` raised inside the backend is therefore reported as-is instead
    of silently re-running the model with different arguments.
    Responds ``{"status": "ok", "code": ...}``, or HTTP 500 ``{"error": ...}``.
    """
    try:
        fn = getattr(AI, "code_complete", None)
        uid = int(user_id or 0)
        try:
            inspect.signature(fn).bind(uid, prompt, max_tokens)
            prompt_first = False
        except TypeError:
            # Arguments don't fit (or fn is None; call_ai then reports the missing method).
            prompt_first = True
        except ValueError:
            # Signature not introspectable: keep the primary convention.
            prompt_first = False
        if prompt_first:
            result = await call_ai(fn, prompt, max_tokens=max_tokens)
        else:
            result = await call_ai(fn, uid, prompt, max_tokens)
        return {"status": "ok", "code": result}
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logging.getLogger(__name__).exception("api_code failed")
        return JSONResponse({"error": str(e)}, status_code=500)
# === GRADIO UI ===
async def gradio_text_fn(text, user_id, lang):
    """Gradio "Send" handler: return the AI's text reply to *text*.

    Declared ``async`` so Gradio awaits it on the server's event loop. A
    synchronous handler is run by Gradio in a worker thread, where no asyncio
    event loop is set, so ``get_event_loop().run_until_complete`` cannot be
    used there.

    Raises:
        gr.Error: if *user_id* is not an integer (shown to the user in the UI).
    """
    try:
        uid = int(user_id or 0)
    except (TypeError, ValueError):
        raise gr.Error("User ID must be an integer") from None
    # Lazy fallback: process_text is only looked up if generate_response is missing.
    fn = getattr(AI, "generate_response", None) or getattr(AI, "process_text", None)
    return await call_ai(fn, text, uid, lang)
# Minimal text-chat UI: user id + language on one row, then message, reply and a
# Send button wired to gradio_text_fn.
with gr.Blocks(title="Multimodal Bot") as demo:
    gr.Markdown("# 🧠 Multimodal Bot — UI")
    with gr.Row():
        uid = gr.Textbox(label="User ID", value="0")
        lang = gr.Dropdown(["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
    inp = gr.Textbox(lines=3, label="Message")
    out = gr.Textbox(lines=6, label="Reply")
    gr.Button("Send").click(gradio_text_fn, [inp, uid, lang], out)
# Serve the Gradio Blocks UI from the same FastAPI application under /ui;
# mount_gradio_app returns the app with the UI attached, which rebinds `app`.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/ui")