Spaces:
Running
Running
import os | |
import shutil | |
import asyncio | |
from typing import Optional | |
import gradio as gr | |
from fastapi import FastAPI, UploadFile, Form | |
import uvicorn | |
import socket | |
from multimodal_module import MultiModalChatModule | |
# Single shared backend instance: every API handler and UI callback in this
# module routes its request through this object.
AI = MultiModalChatModule()
# ---------------------------
# Utility
# ---------------------------
class GradioFileWrapper:
    """Expose a local file path through an async ``download_to_drive`` method.

    Instances are handed to the AI module's ``process_*_message`` methods in
    place of whatever file handle they normally receive.
    NOTE(review): the method name suggests a python-telegram-bot style
    ``File`` object is expected; confirm against MultiModalChatModule.
    """

    def __init__(self, file_path):
        # Path of the already-saved local file to hand out on request.
        self._path = file_path

    async def download_to_drive(self, dst_path: str):
        """Copy the wrapped file to *dst_path* without blocking the event loop."""
        # get_running_loop() is the documented choice inside a coroutine;
        # get_event_loop() has policy-dependent behavior.
        loop = asyncio.get_running_loop()
        # shutil.copyfile is blocking disk I/O, so run it in the default executor.
        await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)

    def __repr__(self):
        return f"{type(self).__name__}({self._path!r})"
def run_async(coro):
    """Drive *coro* to completion on a brand-new event loop and return its result.

    Meant for synchronous callers such as the Gradio button handlers. Like
    ``asyncio.run``, it raises ``RuntimeError`` when called from a thread that
    already has a running event loop.
    """
    outcome = asyncio.run(coro)
    return outcome
def get_free_port(default=7860, host="0.0.0.0"):
    """Return *default* if it can be bound on *host*, otherwise a free port.

    The probe socket is closed before returning, so another process could
    still take the port before the server binds it. This is a best-effort
    check, not a reservation.

    Args:
        default: Preferred port. ``0`` asks the OS for any free port.
        host: Interface to probe; should match the address the server binds.

    Returns:
        The actual port number that was bindable (never 0).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, default))
        except OSError:
            # Preferred port is taken: let the OS pick an ephemeral one.
            s.bind((host, 0))
        # Read the port back instead of echoing *default*, so default=0
        # reports the real OS-assigned port rather than 0.
        return s.getsockname()[1]
# ---------------------------
# FastAPI API for external apps
# ---------------------------
# Root ASGI application. The Gradio UI is mounted onto it at "/" later in
# this module, and the resulting app is what uvicorn serves.
api = FastAPI()
async def api_text_chat(
    user_id: Optional[int] = Form(0),
    text: str = Form(...),
    lang: str = Form("en"),
):
    """Generate a chat reply for *text* in language *lang*.

    Returns ``{"reply": ...}`` on success. If the AI module raises, returns
    ``{"error": message}`` instead of raising.

    NOTE(review): no route decorator is visible on this handler, so it does
    not appear to be registered on ``api``. Confirm how it is exposed.
    """
    try:
        # Optional[int] permits None; fall back to user 0 like the Gradio UI
        # does, instead of failing inside int().
        reply = await AI.generate_response(text, int(user_id or 0), lang)
    except Exception as e:
        return {"error": str(e)}
    return {"reply": reply}
def _safe_upload_name(filename: Optional[str]) -> str:
    """Reduce a client-supplied filename to a bare, non-special basename.

    Directory components (including ``../``) are dropped so an upload cannot
    be written outside its temp directory. Empty, ``.`` and ``..`` names fall
    back to ``"upload"``.
    """
    name = os.path.basename(filename or "")
    return name if name not in ("", ".", "..") else "upload"


async def _process_upload(upload, user_id, method: str, result_key: str, kind: str) -> dict:
    """Save *upload* to a private temp dir, run an AI method on it, then clean up.

    Each request gets its own directory from ``tempfile.mkdtemp``, so
    concurrent uploads with the same filename cannot overwrite each other.
    The directory is removed afterwards, so uploads do not pile up on disk.

    Args:
        upload: The ``UploadFile`` from the request, or ``None`` if absent.
        user_id: Caller-supplied user id; ``None`` is treated as 0.
        method: Name of the ``AI`` coroutine method to call; it is invoked as
            ``method(GradioFileWrapper(path), user_id)``.
        result_key: Key under which the method's result is returned.
        kind: Human-readable upload kind used in the "missing file" error.

    Returns:
        ``{result_key: result}`` on success, ``{"error": message}`` otherwise.
    """
    import tempfile  # only the upload endpoints need it

    if upload is None:
        return {"error": f"No {kind} file was uploaded"}
    tmp_dir = None
    try:
        tmp_dir = tempfile.mkdtemp(prefix="mm_upload_")
        tmp_path = os.path.join(tmp_dir, _safe_upload_name(upload.filename))
        with open(tmp_path, "wb") as f:
            f.write(await upload.read())
        handler = getattr(AI, method)
        # NOTE(review): assumes the AI module is done with the file once its
        # coroutine returns, because the directory is deleted right after.
        result = await handler(GradioFileWrapper(tmp_path), int(user_id or 0))
        return {result_key: result}
    except Exception as e:
        return {"error": str(e)}
    finally:
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)


async def api_image_caption(user_id: Optional[int] = Form(0), image: UploadFile = None):
    """Caption an uploaded image: ``{"caption": ...}`` or ``{"error": ...}``.

    NOTE(review): no route decorator is visible, so this handler does not
    appear to be registered on ``api``. Confirm how it is exposed.
    """
    return await _process_upload(image, user_id, "process_image_message", "caption", "image")


async def api_voice_process(user_id: Optional[int] = Form(0), audio: UploadFile = None):
    """Process an uploaded voice/audio file: ``{"reply": ...}`` or ``{"error": ...}``.

    NOTE(review): no route decorator is visible; confirm how it is exposed.
    """
    return await _process_upload(audio, user_id, "process_voice_message", "reply", "audio")


async def api_video_process(user_id: Optional[int] = Form(0), video: UploadFile = None):
    """Process an uploaded video: ``{"reply": ...}`` or ``{"error": ...}``.

    NOTE(review): no route decorator is visible; confirm how it is exposed.
    """
    return await _process_upload(video, user_id, "process_video_message", "reply", "video")


async def api_file_process(user_id: Optional[int] = Form(0), file: UploadFile = None):
    """Process an arbitrary uploaded file: ``{"reply": ...}`` or ``{"error": ...}``.

    NOTE(review): no route decorator is visible; confirm how it is exposed.
    """
    return await _process_upload(file, user_id, "process_file_message", "reply", "file")
# ---------------------------
# Gradio UI
# ---------------------------
def _ui_text_chat(uid, txt, lang):
    """Handle the "Send" button: return the bot's reply to *txt*.

    A blank User ID box is treated as user 0. Kept synchronous on purpose:
    Gradio runs sync handlers in a worker thread, so ``run_async`` gets its
    own event loop there. Any blocking work inside the AI coroutines then
    does not stall the server's main loop.
    """
    return run_async(AI.generate_response(txt, int(uid or 0), lang))


def _ui_image_caption(uid, img):
    """Handle the "Caption" button: return a caption for the image at *img*.

    *img* is the file path produced by the ``type="filepath"`` image
    component. It is ``None`` when nothing has been uploaded, which would
    otherwise reach ``shutil.copyfile`` as ``None``.
    """
    if img is None:
        return "Please upload an image first."
    return run_async(AI.process_image_message(GradioFileWrapper(img), int(uid or 0)))


with gr.Blocks(title="Multimodal Bot") as demo:
    gr.Markdown("# 🧠 Multimodal Bot\nInteract via text, voice, images, video, or files.")
    with gr.Tab("💬 Text Chat"):
        user_id_txt = gr.Textbox(label="User ID", placeholder="0")
        lang_sel = gr.Dropdown(
            choices=["en", "zh", "ja", "ko", "es", "fr", "de", "it"],
            value="en",
            label="Language",
        )
        txt_in = gr.Textbox(label="Your message", lines=4)
        txt_out = gr.Textbox(label="Bot reply", lines=6)
        gr.Button("Send").click(_ui_text_chat, [user_id_txt, txt_in, lang_sel], txt_out)
    with gr.Tab("🖼 Image Captioning"):
        user_id_img = gr.Textbox(label="User ID", placeholder="0")
        img_in = gr.Image(type="filepath", label="Upload an image")
        img_out = gr.Textbox(label="Caption")
        gr.Button("Caption").click(_ui_image_caption, [user_id_img, img_in], img_out)
# ---------------------------
# Mount Gradio UI to FastAPI
# ---------------------------
# Serve the Gradio UI from the root path of the FastAPI app; the returned app
# replaces `api` and is what the entry point hands to uvicorn.
# NOTE(review): a mount at "/" is matched in registration order, so any API
# routes presumably need to be added to `api` before this line. Confirm.
api = gr.mount_gradio_app(app=api, blocks=demo, path="/")
if __name__ == "__main__":
    # Listen on all interfaces; get_free_port prefers 7860 and falls back to
    # an OS-assigned port when that one is busy.
    uvicorn.run(api, host="0.0.0.0", port=get_free_port())