# app.py — Full multimodal FastAPI + Gradio app with automatic free-port selection
"""HTTP API and minimal Gradio UI in front of ``MultiModalChatModule``.

Every ``/api/*`` endpoint accepts form fields (and, where relevant, file
uploads), forwards them to the matching method on the shared ``AI`` object and
returns JSON. Any exception raised while handling a request is reported as
``{"error": "<message>"}`` with HTTP status 500.
"""
import asyncio
import inspect
import os
import shutil
import socket
import uuid
from typing import Optional

import gradio as gr
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from multimodal_module import MultiModalChatModule

AI = MultiModalChatModule()
TMP_DIR = "/tmp"
os.makedirs(TMP_DIR, exist_ok=True)

# Uploads are copied to disk in pieces of this size instead of being read
# into memory in one go.
_UPLOAD_CHUNK_SIZE = 1024 * 1024


# --- Helpers ---
class FileWrapper:
    """Wrap a local file path behind an async ``download_to_drive`` method.

    NOTE(review): the AI module presumably expects file objects that expose
    ``download_to_drive`` (as some bot frameworks do) — confirm against
    ``MultiModalChatModule``.
    """

    def __init__(self, path: str):
        self._path = path

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to ``dst_path`` without blocking the event loop."""
        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine; get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)


async def save_upload_to_tmp(up: UploadFile) -> str:
    """Persist an uploaded file under ``TMP_DIR`` and return its path.

    The client-supplied filename is untrusted: only its final path component
    is kept (so ``../../etc/passwd`` cannot escape ``TMP_DIR``) and a random
    prefix is added so concurrent uploads with the same name do not overwrite
    each other. The original name, including its extension, stays at the end
    of the stored filename.

    Raises:
        ValueError: if no file was uploaded or the filename is unusable.
    """
    if not up or not up.filename:
        raise ValueError("No file uploaded")

    # Treat backslashes as separators too, so Windows-style paths are stripped.
    safe_name = os.path.basename(up.filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError("Invalid upload filename")

    dest = os.path.join(TMP_DIR, f"{uuid.uuid4().hex}_{safe_name}")
    with open(dest, "wb") as f:
        # Stream in chunks rather than holding the whole upload in memory.
        while chunk := await up.read(_UPLOAD_CHUNK_SIZE):
            f.write(chunk)
    return dest


async def call_ai(fn, *args, **kwargs):
    """Invoke an AI handler, whether it is a coroutine function or a plain one.

    Plain callables run in a worker thread so they cannot stall the event loop.

    Raises:
        AttributeError: if ``fn`` is ``None`` (the handler is not implemented).
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)

    result = await asyncio.to_thread(fn, *args, **kwargs)
    # A sync wrapper (e.g. a decorated coroutine) may hand back an awaitable;
    # resolve it instead of returning an un-awaited coroutine object.
    if inspect.isawaitable(result):
        result = await result
    return result


def _as_user_id(user_id) -> int:
    """Convert a form-supplied user id to ``int``, mapping ``None``/empty to 0."""
    return int(user_id or 0)


def _text_handler():
    """Return the AI text handler: ``generate_response``, else ``process_text``, else ``None``."""
    return getattr(AI, "generate_response", getattr(AI, "process_text", None))


def _error_response(exc: Exception) -> JSONResponse:
    """Build the uniform HTTP 500 error payload used by every endpoint."""
    return JSONResponse({"error": str(exc)}, status_code=500)


# --- FastAPI app ---
app = FastAPI(title="Multimodal Module API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change this in production!
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- API Endpoints ---
@app.post("/api/text")
async def api_text(
    text: str = Form(...),
    user_id: Optional[int] = Form(0),
    lang: str = Form("en"),
):
    """Generate a text reply for ``text`` in language ``lang``."""
    try:
        reply = await call_ai(_text_handler(), text, _as_user_id(user_id), lang)
        return {"status": "ok", "reply": reply}
    except Exception as e:
        return _error_response(e)


@app.post("/api/voice")
async def api_voice(
    user_id: Optional[int] = Form(0),
    audio_file: UploadFile = File(...),
):
    """Process an uploaded voice message and return the handler's result as JSON."""
    try:
        path = await save_upload_to_tmp(audio_file)
        fn = getattr(AI, "process_voice_message", None)
        result = await call_ai(fn, FileWrapper(path), _as_user_id(user_id))
        return JSONResponse(result)
    except Exception as e:
        return _error_response(e)


@app.post("/api/voice_reply")
async def api_voice_reply(
    user_id: Optional[int] = Form(0),
    reply_text: str = Form(...),
    fmt: str = Form("ogg"),
):
    """Synthesize ``reply_text`` as audio in format ``fmt``."""
    try:
        fn = getattr(AI, "generate_voice_reply", None)
        result = await call_ai(fn, reply_text, _as_user_id(user_id), fmt)
        return {"status": "ok", "file": result}
    except Exception as e:
        return _error_response(e)


@app.post("/api/image_caption")
async def api_image_caption(
    user_id: Optional[int] = Form(0),
    image_file: UploadFile = File(...),
):
    """Caption an uploaded image."""
    try:
        path = await save_upload_to_tmp(image_file)
        fn = getattr(AI, "process_image_message", None)
        caption = await call_ai(fn, FileWrapper(path), _as_user_id(user_id))
        return {"status": "ok", "caption": caption}
    except Exception as e:
        return _error_response(e)


@app.post("/api/generate_image")
async def api_generate_image(
    user_id: Optional[int] = Form(0),
    prompt: str = Form(...),
    width: int = Form(512),
    height: int = Form(512),
    steps: int = Form(30),
):
    """Generate an image from ``prompt`` with the requested size and step count."""
    try:
        fn = getattr(AI, "generate_image_from_text", None)
        out_path = await call_ai(
            fn, prompt, _as_user_id(user_id), width, height, steps
        )
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return _error_response(e)


@app.post("/api/edit_image")
async def api_edit_image(
    user_id: Optional[int] = Form(0),
    image_file: UploadFile = File(...),
    mask_file: Optional[UploadFile] = File(None),
    prompt: str = Form(""),
):
    """Inpaint an uploaded image, optionally restricted by an uploaded mask."""
    try:
        img_path = await save_upload_to_tmp(image_file)
        # The mask is optional; the handler receives None when it is absent.
        mask = None
        if mask_file:
            mask = FileWrapper(await save_upload_to_tmp(mask_file))
        fn = getattr(AI, "edit_image_inpaint", None)
        out_path = await call_ai(
            fn, FileWrapper(img_path), mask, prompt, _as_user_id(user_id)
        )
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return _error_response(e)


@app.post("/api/video")
async def api_video(
    user_id: Optional[int] = Form(0),
    video_file: UploadFile = File(...),
):
    """Process an uploaded video and return the handler's result as JSON."""
    try:
        path = await save_upload_to_tmp(video_file)
        fn = getattr(AI, "process_video", None)
        result = await call_ai(fn, FileWrapper(path), _as_user_id(user_id))
        return JSONResponse(result)
    except Exception as e:
        return _error_response(e)


@app.post("/api/file")
async def api_file(
    user_id: Optional[int] = Form(0),
    file_obj: UploadFile = File(...),
):
    """Process an arbitrary uploaded file and return the handler's result as JSON."""
    try:
        path = await save_upload_to_tmp(file_obj)
        fn = getattr(AI, "process_file", None)
        result = await call_ai(fn, FileWrapper(path), _as_user_id(user_id))
        return JSONResponse(result)
    except Exception as e:
        return _error_response(e)


@app.post("/api/code")
async def api_code(
    user_id: Optional[int] = Form(0),
    prompt: str = Form(...),
    max_tokens: int = Form(512),
):
    """Complete code for ``prompt``, trying two known handler signatures."""
    try:
        fn = getattr(AI, "code_complete", None)
        try:
            result = await call_ai(fn, _as_user_id(user_id), prompt, max_tokens)
        except TypeError:
            # Fallback for handlers with a (prompt, max_tokens=...) signature.
            # NOTE: a TypeError raised *inside* the handler also lands here and
            # triggers a second call.
            result = await call_ai(fn, prompt, max_tokens=max_tokens)
        return {"status": "ok", "code": result}
    except Exception as e:
        return _error_response(e)


# --- Gradio UI ---
def gradio_text_fn(text, user_id, lang):
    """Synchronous Gradio callback that produces a text reply.

    ``asyncio.run`` is used instead of ``get_event_loop().run_until_complete``:
    a synchronous callback executed in a worker thread has no event loop set,
    so ``get_event_loop()`` would raise ``RuntimeError`` there.
    """
    fn = _text_handler()
    if fn is None:
        return "Error: text handler not implemented"
    return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))


with gr.Blocks(title="Multimodal Bot (UI)") as demo:
    gr.Markdown("# 🧠 Multimodal Bot — UI")
    with gr.Row():
        uid = gr.Textbox(label="User ID", value="0")
        lang = gr.Dropdown(
            ["en", "zh", "ja", "ko", "es", "fr", "de", "it"],
            value="en",
            label="Language",
        )
    inp = gr.Textbox(lines=3, label="Message")
    out = gr.Textbox(lines=6, label="Reply")
    gr.Button("Send").click(gradio_text_fn, [inp, uid, lang], out)

app = gr.mount_gradio_app(app, demo, path="/")


# --- Auto free port finder ---
def get_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    The probing socket is closed before returning, so another process could
    claim the port before the server binds it; this is best-effort only.
    """
    # The context manager closes the socket even if bind() raises.
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]


if __name__ == "__main__":
    # Only probe for a free port when PORT is unset or empty.
    env_port = os.environ.get("PORT")
    port = int(env_port) if env_port else get_free_port()
    # Pass the app object rather than the "app:app" import string: the string
    # form imports this module a second time (constructing the AI module
    # twice) and breaks if the file is renamed.
    uvicorn.run(app, host="0.0.0.0", port=port)