Princeaka's picture
Update app.py
b2130a6 verified
raw
history blame
7.57 kB
# app.py -- HF-ready single-server FastAPI + Gradio mounted app (no double server conflict)
import asyncio
import inspect
import os
import shutil
import tempfile
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
import gradio as gr
import uvicorn
# Import your multimodal module
from multimodal_module import MultiModalChatModule
# Instantiate AI module: one shared instance used by every endpoint and the UI.
AI = MultiModalChatModule()
# ---------- Helpers ----------
# Directory where uploaded files are staged before being handed to the AI
# module. Override with the UPLOAD_TMP_DIR environment variable; the default
# is the previously hard-coded /tmp.
TMP_DIR = os.environ.get("UPLOAD_TMP_DIR", "/tmp")
os.makedirs(TMP_DIR, exist_ok=True)
class FileWrapper:
    """Simple path wrapper for AI methods.

    Exposes a single ``download_to_drive`` coroutine that copies a file
    already on local disk to a destination path. In this module it wraps
    paths produced by ``save_upload_to_tmp``.
    """

    def __init__(self, path: str):
        # Local source path that download_to_drive copies from.
        self._path = path

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to *dst_path*.

        The blocking copy runs in a worker thread so the event loop stays
        responsive. ``asyncio.to_thread`` replaces the older
        ``get_event_loop().run_in_executor`` pattern and matches ``call_ai``.
        """
        await asyncio.to_thread(shutil.copyfile, self._path, dst_path)
async def save_upload_to_tmp(up: UploadFile) -> str:
    """Save FastAPI UploadFile under TMP_DIR and return the new path.

    Only the last component of the client-supplied filename is used. Names
    like ``../../app.py`` or ``/etc/passwd`` therefore cannot write outside
    TMP_DIR.

    Each staged file gets a unique ``upload_`` prefix. Concurrent uploads
    that share a name (e.g. an image and its mask both named ``image.png``)
    no longer overwrite each other. The original name, including its
    extension, is kept at the end of the path.

    Raises:
        ValueError: if no file was uploaded or the filename is unusable.
    """
    if not up or not up.filename:
        raise ValueError("No file uploaded")
    # Treat backslashes as separators too, then keep only the final component.
    safe_name = os.path.basename(up.filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError("Invalid upload filename")
    data = await up.read()
    # mkstemp atomically creates a fresh file, so two requests never share a path.
    fd, dest = tempfile.mkstemp(prefix="upload_", suffix="_" + safe_name, dir=TMP_DIR)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
    except OSError:
        # Don't leave a partially written file behind.
        try:
            os.remove(dest)
        except OSError:
            pass
        raise
    return dest
async def call_ai(fn, *args, **kwargs):
    """Run AI method whether it's sync or async.

    Coroutine functions are awaited directly. Any other callable runs in a
    worker thread via ``asyncio.to_thread``, so blocking model code does not
    stall the event loop.

    If that call returns an awaitable, it is awaited as well. This covers a
    sync wrapper that returns a coroutine, or an object with an async
    ``__call__``. Callers therefore always receive a concrete result, never
    an un-awaited coroutine.

    Raises:
        AttributeError: if *fn* is None (the AI module lacks the method).
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    result = await asyncio.to_thread(fn, *args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result
# ---------- FastAPI ----------
app = FastAPI(title="Multimodal Module API")
# CORS (if you call this from the browser).
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable, e.g. "https://a.example,https://b.example".
# The default "*" keeps the previous allow-everything behaviour; set explicit
# origins for production instead of editing code.
from fastapi.middleware.cors import CORSMiddleware
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ----------------- API endpoints -----------------
@app.post("/api/predict")
async def api_predict(
    inputs: str = Form(...),
    user_id: Optional[int] = Form(0),
    lang: str = Form("en"),
):
    """Text prediction endpoint with a Gradio-style response envelope.

    Uses ``AI.generate_response`` when present, otherwise
    ``AI.process_text``. Success returns ``{"data": [reply]}``; any failure
    returns ``{"error": message}`` with HTTP 500.
    """
    try:
        fallback = getattr(AI, "process_text", None)
        handler = getattr(AI, "generate_response", fallback)
        answer = await call_ai(handler, inputs, int(user_id), lang)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    return {"data": [answer]}
@app.post("/api/text")
async def api_text(
    text: str = Form(...),
    user_id: Optional[int] = Form(0),
    lang: str = Form("en"),
):
    """Plain text chat endpoint.

    Uses ``AI.generate_response`` when present, otherwise
    ``AI.process_text``. Success returns ``{"status": "ok", "reply": ...}``;
    any failure returns ``{"error": message}`` with HTTP 500.
    """
    try:
        fallback = getattr(AI, "process_text", None)
        handler = getattr(AI, "generate_response", fallback)
        answer = await call_ai(handler, text, int(user_id), lang)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
    return {"status": "ok", "reply": answer}
@app.post("/api/voice")
async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
    """Run ``AI.process_voice_message`` on an uploaded audio file.

    The upload is staged on disk, wrapped in FileWrapper and passed to the AI
    method. Its result is returned as JSON. Any failure returns
    ``{"error": message}`` with HTTP 500.

    The staged file is always deleted afterwards, so repeated requests do not
    accumulate files in TMP_DIR.
    """
    path = None
    try:
        path = await save_upload_to_tmp(audio_file)
        fn = getattr(AI, "process_voice_message", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Assumes the AI method no longer needs the staged file once it returns.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
@app.post("/api/voice_reply")
async def api_voice_reply(
    user_id: Optional[int] = Form(0),
    reply_text: str = Form(...),
    fmt: str = Form("ogg"),
):
    """Synthesize *reply_text* as audio via ``AI.generate_voice_reply``.

    Success returns ``{"status": "ok", "file": <handler result>}``; any
    failure returns ``{"error": message}`` with HTTP 500.
    """
    try:
        synthesize = getattr(AI, "generate_voice_reply", None)
        produced = await call_ai(synthesize, reply_text, int(user_id), fmt)
    except Exception as err:
        return JSONResponse({"error": str(err)}, status_code=500)
    return {"status": "ok", "file": produced}
@app.post("/api/image_caption")
async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
    """Caption an uploaded image with ``AI.process_image_message``.

    Success returns ``{"status": "ok", "caption": ...}``; any failure returns
    ``{"error": message}`` with HTTP 500.

    The staged upload is always deleted afterwards, so repeated requests do
    not accumulate files in TMP_DIR.
    """
    path = None
    try:
        path = await save_upload_to_tmp(image_file)
        fn = getattr(AI, "process_image_message", None)
        caption = await call_ai(fn, FileWrapper(path), int(user_id))
        return {"status": "ok", "caption": caption}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Assumes the AI method no longer needs the staged file once it returns.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
@app.post("/api/generate_image")
async def api_generate_image(
    user_id: Optional[int] = Form(0),
    prompt: str = Form(...),
    width: int = Form(512),
    height: int = Form(512),
    steps: int = Form(30),
):
    """Text-to-image endpoint backed by ``AI.generate_image_from_text``.

    Success returns ``{"status": "ok", "file": <handler result>}``; any
    failure returns ``{"error": message}`` with HTTP 500.
    """
    try:
        generator = getattr(AI, "generate_image_from_text", None)
        image_ref = await call_ai(generator, prompt, int(user_id), width, height, steps)
    except Exception as err:
        return JSONResponse({"error": str(err)}, status_code=500)
    return {"status": "ok", "file": image_ref}
@app.post("/api/edit_image")
async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
    """Inpaint/edit an uploaded image via ``AI.edit_image_inpaint``.

    The image (and optional mask) are staged on disk, wrapped in FileWrapper,
    and passed as ``(image, mask_or_None, prompt, user_id)``. Success returns
    ``{"status": "ok", "file": out_path}``; any failure returns
    ``{"error": message}`` with HTTP 500.

    Staged uploads are deleted afterwards, except a path the AI method
    returned as its output.
    """
    img_path = None
    mask_path = None
    out_path = None
    try:
        img_path = await save_upload_to_tmp(image_file)
        if mask_file:
            mask_path = await save_upload_to_tmp(mask_file)
        fn = getattr(AI, "edit_image_inpaint", None)
        mask = FileWrapper(mask_path) if mask_path else None
        out_path = await call_ai(fn, FileWrapper(img_path), mask, prompt, int(user_id))
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Assumes the AI method no longer needs the staged inputs once it
        # returns; never delete a file we are handing back to the client.
        for staged in (img_path, mask_path):
            if staged is not None and staged != out_path:
                try:
                    os.remove(staged)
                except OSError:
                    pass
@app.post("/api/video")
async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
    """Run ``AI.process_video`` on an uploaded video and return its result as JSON.

    Any failure returns ``{"error": message}`` with HTTP 500. The staged
    upload is always deleted afterwards, so repeated requests do not
    accumulate (potentially large) files in TMP_DIR.
    """
    path = None
    try:
        path = await save_upload_to_tmp(video_file)
        fn = getattr(AI, "process_video", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Assumes the AI method no longer needs the staged file once it returns.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
@app.post("/api/file")
async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
    """Run ``AI.process_file`` on an uploaded file and return its result as JSON.

    Any failure returns ``{"error": message}`` with HTTP 500. The staged
    upload is always deleted afterwards, so repeated requests do not
    accumulate files in TMP_DIR.
    """
    path = None
    try:
        path = await save_upload_to_tmp(file_obj)
        fn = getattr(AI, "process_file", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Assumes the AI method no longer needs the staged file once it returns.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
def _code_complete_call_args(fn, user_id: int, prompt: str, max_tokens: int):
    """Choose ``(args, kwargs)`` for calling ``AI.code_complete``.

    Two call forms are supported:
    - ``code_complete(user_id, prompt, max_tokens)`` (primary)
    - ``code_complete(prompt, max_tokens=...)`` (legacy)

    The form is chosen by inspecting the signature before calling. The old
    approach called first and retried on ``TypeError``. That also fired on
    TypeErrors raised inside the model code, which ran the model twice and
    hid the real error.
    """
    primary = ((user_id, prompt, max_tokens), {})
    legacy = ((prompt,), {"max_tokens": max_tokens})
    if fn is None:
        # call_ai will raise the "not implemented" AttributeError.
        return primary
    try:
        sig = inspect.signature(fn)
    except (TypeError, ValueError):
        # No introspectable signature; use the primary form.
        return primary
    first = next(iter(sig.parameters), None)
    if first == "prompt":
        return legacy
    if first == "user_id":
        return primary
    try:
        sig.bind(user_id, prompt, max_tokens)
    except TypeError:
        return legacy
    return primary


@app.post("/api/code")
async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
    """Code completion endpoint backed by ``AI.code_complete``.

    Success returns ``{"status": "ok", "code": ...}``; any failure returns
    ``{"error": message}`` with HTTP 500.
    """
    try:
        fn = getattr(AI, "code_complete", None)
        args, kwargs = _code_complete_call_args(fn, int(user_id), prompt, max_tokens)
        result = await call_ai(fn, *args, **kwargs)
        return {"status": "ok", "code": result}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# ---------- Minimal Gradio UI ----------
async def gradio_text_fn(text, user_id, lang):
    """Gradio click handler: send *text* to the AI text handler and return the reply.

    Uses ``AI.generate_response`` when present, otherwise ``AI.process_text``.
    An empty *user_id* textbox is treated as 0.

    This is a coroutine so Gradio awaits it on its own event loop. The old
    synchronous version drove a loop by hand with
    ``asyncio.get_event_loop().run_until_complete(...)``. That raises
    RuntimeError when the calling thread has no event loop, or when a loop is
    already running in that thread.
    """
    fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
    if fn is None:
        return "Error: text handler not implemented"
    return await call_ai(fn, text, int(user_id or 0), lang)
# Minimal single-page UI: user id + language on one row, then message,
# reply box and a Send button wired to gradio_text_fn.
with gr.Blocks(title="Multimodal Bot (UI)") as demo:
    gr.Markdown("# 🧠 Multimodal Bot — UI")
    with gr.Row():
        txt_uid = gr.Textbox(label="User ID", value="0")
        txt_lang = gr.Dropdown(
            choices=["en", "zh", "ja", "ko", "es", "fr", "de", "it"],
            value="en",
            label="Language",
        )
    inp = gr.Textbox(lines=3, label="Message")
    out = gr.Textbox(lines=6, label="Reply")
    gr.Button("Send").click(
        fn=gradio_text_fn,
        inputs=[inp, txt_uid, txt_lang],
        outputs=out,
    )

# Mount the UI at the site root; the /api/* routes were registered on `app`
# earlier and are rebound together with the UI here.
app = gr.mount_gradio_app(app, demo, path="/")
# ---------- Run ----------
if __name__ == "__main__":
    # Port comes from $PORT (set by the hosting platform), default 7860.
    port = int(os.environ.get("PORT", 7860))
    # Pass the app object, not the "app:app" import string. With a string,
    # uvicorn imports this file a second time as module "app" (it is
    # currently running as "__main__"). That re-runs all module-level code,
    # including MultiModalChatModule(), and loads the models twice.
    # An object is fine here because reload is disabled.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)