# Source: Hugging Face file view of app.py - author Princeaka, commit a657484 ("Update app.py"), 9.58 kB
# app.py - Hybrid Gradio + FastAPI wrapper for multimodal_module.py
import os
import shutil
import asyncio
import json
from typing import Optional
import gradio as gr
from fastapi import FastAPI, Request
from multimodal_module import MultiModalChatModule
# Single shared MultiModalChatModule instance, created at import time; every
# callback in this file routes its request through it.
AI = MultiModalChatModule()
# ============================================================
# Helper: File wrapper for Gradio uploads
# ============================================================
class GradioFileWrapper:
    """Adapt a Gradio upload to an object exposing ``download_to_drive``.

    The wrapped value may be a filesystem path (``str`` or ``os.PathLike``),
    an object carrying the path in a ``.name`` attribute, or a mapping with a
    ``"name"`` key. The resolved path is stored in ``self._path``.
    """

    def __init__(self, gr_file):
        """Resolve the on-disk path of ``gr_file``.

        Raises:
            ValueError: if no path can be extracted from ``gr_file``
                (for example when it is ``None``).
        """
        # Path-like objects must be handled before the ``.name`` lookup:
        # ``pathlib.Path.name`` is only the final path component, not the
        # full path, so it would silently point at the wrong file.
        if isinstance(gr_file, (str, os.PathLike)):
            self._path = os.fspath(gr_file)
            return
        try:
            self._path = gr_file.name
        except AttributeError:
            try:
                self._path = gr_file["name"]
            except (KeyError, IndexError, TypeError):
                raise ValueError("Unsupported file object from Gradio") from None

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to ``dst_path`` without blocking the loop.

        The copy runs via ``shutil.copyfile`` in the loop's default executor.
        """
        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
# ============================================================
# Async-safe helper
# ============================================================
def run_async(coro):
    """Run the coroutine ``coro`` to completion from synchronous code.

    Works both when the calling thread has no running event loop and when it
    is called from inside a running loop (e.g. from an ``async`` request
    handler), where ``loop.run_until_complete`` would raise
    ``RuntimeError: This event loop is already running``.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever ``coro`` returns. Exceptions raised by ``coro`` propagate.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread. asyncio.run creates a fresh loop,
        # runs the coroutine and then closes the loop (and its default
        # executor), so repeated calls do not accumulate open loops.
        return asyncio.run(coro)

    # A loop is already running in this thread and cannot be re-entered.
    # Run the coroutine on its own loop in a helper thread and wait for it.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
# ============================================================
# Callback functions (used by Gradio & API)
# ============================================================
def text_chat(user_id: Optional[int], text: str, lang: str = "en"):
    """Return the reply produced by ``AI.generate_response`` for ``text``.

    An empty or missing ``user_id`` is treated as user 0; any other value is
    converted with ``int``. Every failure is reported as an ``"Error: ..."``
    string instead of being raised.
    """
    try:
        numeric_id = 0 if not user_id else int(user_id)
        return run_async(AI.generate_response(text, numeric_id, lang))
    except Exception as e:
        return f"Error: {e}"
def voice_process(user_id: Optional[int], audio_file):
    """Pass an uploaded audio file to ``AI.process_voice_message``.

    Returns the result serialized as indented, non-ASCII-escaped JSON text, or
    an ``"Error: ..."`` string if any step (including serialization) fails.
    An empty or missing ``user_id`` is treated as user 0.
    """
    try:
        numeric_id = 0 if not user_id else int(user_id)
        upload = GradioFileWrapper(audio_file)
        outcome = run_async(AI.process_voice_message(upload, numeric_id))
        return json.dumps(outcome, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"Error: {e}"
def generate_voice(user_id: Optional[int], reply_text: str, fmt: str = "ogg"):
    """Synthesize ``reply_text`` via ``AI.generate_voice_reply``.

    Args:
        user_id: Optional user identifier; empty or missing means user 0.
        reply_text: Text to speak.
        fmt: Requested audio format, passed through to the AI module.

    Returns:
        The single value produced by ``AI.generate_voice_reply``
        (presumably a path to the audio file - confirm against the module).

    Raises:
        gr.Error: on any failure. This callback is bound to one audio output,
            so a failure cannot be reported by returning text or a tuple;
            raising ``gr.Error`` lets Gradio show the message to the user.
    """
    try:
        uid = int(user_id) if user_id else 0
        return run_async(AI.generate_voice_reply(reply_text, uid, fmt))
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
def image_caption(user_id: Optional[int], image_file):
    """Return what ``AI.process_image_message`` produces for an uploaded image.

    An empty or missing ``user_id`` is treated as user 0. Failures are
    reported as an ``"Error: ..."`` string rather than raised.
    """
    try:
        numeric_id = 0 if not user_id else int(user_id)
        upload = GradioFileWrapper(image_file)
        return run_async(AI.process_image_message(upload, numeric_id))
    except Exception as e:
        return f"Error: {e}"
def generate_image(user_id: Optional[int], prompt: str, width: int = 512, height: int = 512, steps: int = 30):
    """Generate an image from ``prompt`` via ``AI.generate_image_from_text``.

    Args:
        user_id: Optional user identifier; empty or missing means user 0.
        prompt: Text prompt for the image.
        width: Output width, passed through as ``width=``.
        height: Output height, passed through as ``height=``.
        steps: Step count, passed through as ``steps=``.

    Returns:
        The value produced by ``AI.generate_image_from_text``
        (presumably a path to the image file - confirm against the module).

    Raises:
        gr.Error: on any failure. The result feeds an image output, where a
            returned error string would be treated as a file path; raising
            ``gr.Error`` shows the actual message instead.
    """
    try:
        uid = int(user_id) if user_id else 0
        # UI sliders may deliver numeric values as floats; the parameters are
        # declared as ints, so normalize before handing them on.
        return run_async(
            AI.generate_image_from_text(
                prompt,
                uid,
                width=int(width),
                height=int(height),
                steps=int(steps),
            )
        )
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
def edit_image(user_id: Optional[int], image_file, mask_file, prompt: str = ""):
    """Edit an uploaded image via ``AI.edit_image_inpaint``.

    Args:
        user_id: Optional user identifier; empty or missing means user 0.
        image_file: Uploaded image to edit.
        mask_file: Optional uploaded mask; ``None`` is passed to the AI
            module when no mask was supplied.
        prompt: Text prompt passed through to the AI module.

    Returns:
        The value produced by ``AI.edit_image_inpaint``
        (presumably a path to the edited image - confirm against the module).

    Raises:
        gr.Error: on any failure. The result feeds an image output, where a
            returned error string would be treated as a file path; raising
            ``gr.Error`` shows the actual message instead.
    """
    try:
        uid = int(user_id) if user_id else 0
        img_w = GradioFileWrapper(image_file)
        mask_w = GradioFileWrapper(mask_file) if mask_file else None
        return run_async(AI.edit_image_inpaint(img_w, mask_w, prompt, uid))
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
def process_video(user_id: Optional[int], video_file):
    """Pass an uploaded video to ``AI.process_video`` and return JSON text.

    The result is serialized with a two-space indent and without ASCII
    escaping. An empty or missing ``user_id`` is treated as user 0. Failures
    are reported as an ``"Error: ..."`` string rather than raised.
    """
    try:
        numeric_id = 0 if not user_id else int(user_id)
        upload = GradioFileWrapper(video_file)
        outcome = run_async(AI.process_video(upload, numeric_id))
        return json.dumps(outcome, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"Error: {e}"
def process_file(user_id: Optional[int], file_obj):
    """Pass an uploaded file to ``AI.process_file`` and return JSON text.

    The result is serialized with a two-space indent and without ASCII
    escaping. An empty or missing ``user_id`` is treated as user 0. Failures
    are reported as an ``"Error: ..."`` string rather than raised.
    """
    try:
        numeric_id = 0 if not user_id else int(user_id)
        upload = GradioFileWrapper(file_obj)
        outcome = run_async(AI.process_file(upload, numeric_id))
        return json.dumps(outcome, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"Error: {e}"
def code_complete(user_id: Optional[int], prompt: str, max_tokens: int = 512):
    """Return the completion produced by ``AI.code_complete`` for ``prompt``.

    Args:
        user_id: Accepted for signature parity with the other callbacks; it
            is not forwarded to ``AI.code_complete`` and is therefore not
            parsed or validated here.
        prompt: Code prompt to complete.
        max_tokens: Passed through as ``max_tokens=``.

    Returns:
        The AI module's output, or an ``"Error: ..."`` string on failure.
    """
    try:
        return run_async(AI.code_complete(prompt, max_tokens=max_tokens))
    except Exception as e:
        return f"Error: {e}"
# ============================================================
# FastAPI public API
# ============================================================
api = FastAPI()


@api.post("/api/predict")
async def api_predict(request: Request):
    """Text-chat endpoint: JSON in, JSON out.

    Reads ``user_id`` (default 0), ``text`` (default "") and ``lang``
    (default "en") from the JSON request body and returns
    ``{"status": "ok", "reply": ...}``. Any exception raised while handling
    the request is reported as ``{"status": "error", "message": ...}``.
    """
    try:
        data = await request.json()
        user_id = data.get("user_id", 0)
        text = data.get("text", "")
        lang = data.get("lang", "en")
        # text_chat is a blocking, synchronous call that drives an event loop
        # of its own. Hand it to a worker thread so this handler's loop is
        # neither blocked for the duration of the call nor re-entered.
        loop = asyncio.get_running_loop()
        reply = await loop.run_in_executor(None, text_chat, user_id, text, lang)
        return {"status": "ok", "reply": reply}
    except Exception as e:
        return {"status": "error", "message": str(e)}
# ============================================================
# Gradio UI
# ============================================================
# One tab per capability. Each tab wires its inputs to the matching callback
# defined above and shows the callback's return value in a single output.
# Tab labels use the intended emoji (previously stored as mis-decoded bytes).
with gr.Blocks(title="Multimodal Bot (Gradio)") as demo:
    gr.Markdown("# 🧠 Multimodal Bot\nInteract via text, voice, images, video, or files.")

    with gr.Tab("💬 Text Chat"):
        with gr.Row():
            user_id_txt = gr.Textbox(label="User ID (optional)", placeholder="0")
            lang_sel = gr.Dropdown(choices=["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
        txt_in = gr.Textbox(label="Your message", lines=4)
        txt_out = gr.Textbox(label="Bot reply", lines=6)
        gr.Button("Send").click(text_chat, [user_id_txt, txt_in, lang_sel], txt_out)

    with gr.Tab("🎤 Voice (Transcribe + Emotion)"):
        user_id_voice = gr.Textbox(label="User ID (optional)", placeholder="0")
        voice_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or upload voice (.ogg/.wav)")
        voice_out = gr.Textbox(label="Result JSON")
        gr.Button("Process Voice").click(voice_process, [user_id_voice, voice_in], voice_out)

    with gr.Tab("🔊 Voice Reply (TTS)"):
        user_id_vr = gr.Textbox(label="User ID (optional)", placeholder="0")
        vr_text = gr.Textbox(label="Text to speak", lines=4)
        vr_fmt = gr.Dropdown(choices=["ogg","wav","mp3"], value="ogg", label="Format")
        vr_audio = gr.Audio(label="Generated Voice")
        gr.Button("Generate Voice").click(generate_voice, [user_id_vr, vr_text, vr_fmt], vr_audio)

    with gr.Tab("🖼️ Image Caption"):
        user_id_img = gr.Textbox(label="User ID (optional)", placeholder="0")
        img_in = gr.Image(type="filepath", label="Upload Image")
        img_out = gr.Textbox(label="Caption")
        gr.Button("Caption Image").click(image_caption, [user_id_img, img_in], img_out)

    with gr.Tab("🎨 Image Generate"):
        user_id_gi = gr.Textbox(label="User ID (optional)", placeholder="0")
        prompt_in = gr.Textbox(label="Prompt", lines=3)
        width = gr.Slider(256, 1024, 512, step=64, label="Width")
        height = gr.Slider(256, 1024, 512, step=64, label="Height")
        steps = gr.Slider(10, 50, 30, step=5, label="Steps")
        gen_out = gr.Image(type="filepath", label="Generated image")
        gr.Button("Generate").click(generate_image, [user_id_gi, prompt_in, width, height, steps], gen_out)

    with gr.Tab("✏️ Image Edit (Inpaint)"):
        user_id_ie = gr.Textbox(label="User ID (optional)", placeholder="0")
        edit_img = gr.Image(type="filepath", label="Image to edit")
        edit_mask = gr.Image(type="filepath", label="Mask (optional)")
        edit_prompt = gr.Textbox(label="Prompt", lines=2)
        edit_out = gr.Image(type="filepath", label="Edited image")
        gr.Button("Edit Image").click(edit_image, [user_id_ie, edit_img, edit_mask, edit_prompt], edit_out)

    with gr.Tab("🎥 Video"):
        user_id_vid = gr.Textbox(label="User ID (optional)", placeholder="0")
        vid_in = gr.Video(label="Upload video")
        vid_out = gr.Textbox(label="Result JSON")
        gr.Button("Process Video").click(process_video, [user_id_vid, vid_in], vid_out)

    with gr.Tab("📄 Files (PDF/DOCX/TXT)"):
        user_id_file = gr.Textbox(label="User ID (optional)", placeholder="0")
        file_in = gr.File(label="Upload file")
        file_out = gr.Textbox(label="Result JSON")
        gr.Button("Process File").click(process_file, [user_id_file, file_in], file_out)

    with gr.Tab("💻 Code Generation"):
        user_id_code = gr.Textbox(label="User ID (optional)", placeholder="0")
        code_prompt = gr.Textbox(label="Code prompt", lines=6)
        code_out = gr.Textbox(label="Generated code", lines=12)
        # max_tokens is not wired to a UI input, so code_complete runs with
        # its default value.
        gr.Button("Generate Code").click(code_complete, [user_id_code, code_prompt], code_out)

    gr.Markdown("----\nThis Space runs your exact `multimodal_module.py`. First requests may take longer due to model loading.")
# ============================================================
# Launch both API + Gradio
# ============================================================
import uvicorn
from threading import Thread
def start_api():
    """Serve the FastAPI app ``api`` with uvicorn on 0.0.0.0:8000.

    The call runs the server in the calling thread and is intended to be used
    as a thread target.
    """
    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(api, **bind_options)
# Run the FastAPI server on a daemon thread so it does not keep the process
# alive once the Gradio server (launched below in this thread) exits.
api_thread = Thread(target=start_api, daemon=True)
api_thread.start()

# Enable request queuing, then serve the UI on all interfaces. The port comes
# from the PORT environment variable and falls back to 7860.
demo.queue()
gradio_port = int(os.environ.get("PORT", 7860))
demo.launch(server_name="0.0.0.0", server_port=gradio_port)