import os
import shutil
import asyncio
from typing import Optional
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
import uvicorn
import socket
from multimodal_module import MultiModalChatModule
# Initialize AI module
AI = MultiModalChatModule()
# ---------------------------
# Utility
# ---------------------------
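# Minimal file wrapper handed to the AI module's process_* methods: it exposes
# an async download_to_drive() backed by a plain local file path.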
class GradioFileWrapper:
    def __init__(self, file_path):
        self._path = file_path

    async def download_to_drive(self, dst_path: str):
        # Copy the wrapped file without blocking the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
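# Run a coroutine to completion from Gradio's synchronous callbacks.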
def run_async(coro):
    return asyncio.run(coro)
def get_free_port(default=7860):
    """Return the default port if it is free, otherwise an OS-assigned free port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("0.0.0.0", default))
            return default
        except OSError:
            s.bind(("0.0.0.0", 0))
            return s.getsockname()[1]
# ---------------------------
# FastAPI API for external apps
# ---------------------------
api = FastAPI()
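# Text chat: form-encoded fields in, JSON reply out. Example call (assuming the
# default port 7860):
#   curl -X POST http://localhost:7860/api/text_chat -F "text=hello" -F "lang=en"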
@api.post("/api/text_chat")
async def api_text_chat(
    user_id: Optional[int] = Form(0),
    text: str = Form(...),
    lang: str = Form("en")
):
    try:
        reply = await AI.generate_response(text, int(user_id), lang)
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}
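# The media endpoints below share one pattern: persist the upload under /tmp,
# wrap the local path in GradioFileWrapper, and pass it to the matching AI method.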
@api.post("/api/image_caption")
async def api_image_caption(user_id: Optional[int] = Form(0), image: UploadFile = File(...)):
    try:
        temp_path = f"/tmp/{image.filename}"
        with open(temp_path, "wb") as f:
            f.write(await image.read())
        wrapper = GradioFileWrapper(temp_path)
        caption = await AI.process_image_message(wrapper, int(user_id))
        return {"caption": caption}
    except Exception as e:
        return {"error": str(e)}
@api.post("/api/voice_process")
async def api_voice_process(user_id: Optional[int] = Form(0), audio: UploadFile = File(...)):
    try:
        temp_path = f"/tmp/{audio.filename}"
        with open(temp_path, "wb") as f:
            f.write(await audio.read())
        wrapper = GradioFileWrapper(temp_path)
        reply = await AI.process_voice_message(wrapper, int(user_id))
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}
@api.post("/api/video_process")
async def api_video_process(user_id: Optional[int] = Form(0), video: UploadFile = File(...)):
    try:
        temp_path = f"/tmp/{video.filename}"
        with open(temp_path, "wb") as f:
            f.write(await video.read())
        wrapper = GradioFileWrapper(temp_path)
        reply = await AI.process_video_message(wrapper, int(user_id))
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}
@api.post("/api/file_process")
async def api_file_process(user_id: Optional[int] = Form(0), file: UploadFile = File(...)):
    try:
        temp_path = f"/tmp/{file.filename}"
        with open(temp_path, "wb") as f:
            f.write(await file.read())
        wrapper = GradioFileWrapper(temp_path)
        reply = await AI.process_file_message(wrapper, int(user_id))
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}
# ---------------------------
# Gradio UI
# ---------------------------
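# The UI covers the text and image flows; voice, video, and file processing
# remain reachable through the API endpoints above.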
with gr.Blocks(title="Multimodal Bot") as demo:
    gr.Markdown("# 🧠 Multimodal Bot\nInteract via text, voice, images, video, or files.")

    with gr.Tab("💬 Text Chat"):
        user_id_txt = gr.Textbox(label="User ID", placeholder="0")
        lang_sel = gr.Dropdown(choices=["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
        txt_in = gr.Textbox(label="Your message", lines=4)
        txt_out = gr.Textbox(label="Bot reply", lines=6)
        gr.Button("Send").click(
            lambda uid, txt, lang: run_async(AI.generate_response(txt, int(uid or 0), lang)),
            [user_id_txt, txt_in, lang_sel],
            txt_out,
        )

    with gr.Tab("🖼 Image Captioning"):
        user_id_img = gr.Textbox(label="User ID", placeholder="0")
        img_in = gr.Image(type="filepath", label="Upload an image")
        img_out = gr.Textbox(label="Caption")
        gr.Button("Caption").click(
            lambda uid, img: run_async(AI.process_image_message(GradioFileWrapper(img), int(uid or 0))),
            [user_id_img, img_in],
            img_out,
        )
# ---------------------------
# Mount Gradio UI to FastAPI
# ---------------------------
api = gr.mount_gradio_app(api, demo, path="/")
if __name__ == "__main__":
    port = get_free_port()
    uvicorn.run(api, host="0.0.0.0", port=port)