# NOTE: stray Hugging Face Spaces page text ("Spaces: / Running / Running") removed — it was scrape residue, not code.
# app.py -- HF-ready single-server FastAPI + Gradio mounted app (fixed)
# Standard library.
import asyncio
import inspect
import os
import shutil
from typing import Optional

# Third-party.
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
import gradio as gr
import uvicorn

# Project-local multimodal backend.
from multimodal_module import MultiModalChatModule

# Single shared instance of the AI module, used by every endpoint below.
AI = MultiModalChatModule()

# ---------- Helpers ----------
# Scratch directory where uploaded files are written before processing.
TMP_DIR = "/tmp"
os.makedirs(TMP_DIR, exist_ok=True)
class FileWrapper:
    """Simple path wrapper compatible with your existing code expectations.

    Exposes an async ``download_to_drive`` so the AI module can treat a
    local file like a remotely-fetched one (Telegram-bot-style API).
    """

    def __init__(self, path: str):
        # Source path of the already-saved local file.
        self._path = path

    async def download_to_drive(self, dst_path: str) -> None:
        """Copy the wrapped file to *dst_path* without blocking the event loop.

        Fix: ``asyncio.get_event_loop()`` inside a coroutine is deprecated
        (and an error on modern Python); ``asyncio.to_thread`` runs the copy
        on the running loop's default executor directly.
        """
        await asyncio.to_thread(shutil.copyfile, self._path, dst_path)
async def save_upload_to_tmp(up: UploadFile) -> str:
    """Save a FastAPI UploadFile under TMP_DIR and return the saved path.

    Security fix: the client controls ``up.filename``, so it is reduced to
    its basename before joining — a crafted name like ``../../etc/passwd``
    could otherwise escape TMP_DIR.

    Raises:
        ValueError: if the upload is missing or carries no filename.
    """
    if not up or not up.filename:
        raise ValueError("UploadFile missing or has no filename")
    # basename() strips any directory components (path-traversal guard).
    # NOTE(review): same-named concurrent uploads still overwrite each other.
    dest = os.path.join(TMP_DIR, os.path.basename(up.filename))
    data = await up.read()  # async read: do not block the event loop
    with open(dest, "wb") as f:
        f.write(data)
    return dest
async def call_ai(fn, *args, **kwargs):
    """Invoke an AI handler safely from async code.

    A coroutine function is awaited directly; a plain callable is run in a
    worker thread so it cannot block the event loop.

    Args:
        fn: Handler to invoke, or ``None`` when the AI module lacks it.
        *args: Positional arguments forwarded to ``fn`` unchanged.
        **kwargs: Keyword arguments forwarded to ``fn`` unchanged.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        AttributeError: if ``fn`` is ``None`` (method not implemented).
    """
    if fn is None:
        raise AttributeError("Requested AI method is not implemented in multimodal_module")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    # to_thread forwards *args/**kwargs itself — the lambda wrapper was redundant.
    return await asyncio.to_thread(fn, *args, **kwargs)
# ---------- FastAPI app ----------
app = FastAPI(title="Multimodal Module API")

# Optional: allow CORS if external web apps will call this server.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wildcard origins — tighten in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ----------------- API endpoints -----------------
@app.post("/predict")
async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """
    HuggingFace-style /predict compatibility.
    Form field 'inputs' used as text.

    Fix: the handler was defined but never registered on the app (no route
    decorator), so it was unreachable; the /predict path matches the
    HF-compat intent stated above.
    """
    try:
        # Prefer generate_response; fall back to process_text when absent.
        fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        reply = await call_ai(fn, inputs, int(user_id), lang)
        # HF-style responses wrap the payload in a "data" array.
        return {"data": [reply]}
    except Exception as e:
        # Boundary handler: surface any failure as a JSON 500 payload.
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/text")  # route was missing; path inferred from function name — confirm
async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
    """Plain text chat endpoint; returns {"status": "ok", "reply": ...}."""
    try:
        # Prefer generate_response; fall back to process_text when absent.
        fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
        reply = await call_ai(fn, text, int(user_id), lang)
        return {"status": "ok", "reply": reply}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/voice")  # route was missing; path inferred from function name — confirm
async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = File(...)):
    """
    Upload audio file (multipart/form-data). Returns whatever your
    AI.process_voice_message returns (JSON/dict).
    """
    try:
        # Persist the upload, then hand the AI a path-like FileWrapper.
        path = await save_upload_to_tmp(audio_file)
        fn = getattr(AI, "process_voice_message", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/voice_reply")  # route was missing; path inferred from function name — confirm
async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
    """Synthesize a voice reply for *reply_text*; returns the generated file path."""
    try:
        fn = getattr(AI, "generate_voice_reply", None)
        result = await call_ai(fn, reply_text, int(user_id), fmt)
        return {"status": "ok", "file": result}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/image_caption")  # route was missing; path inferred from function name — confirm
async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
    """Caption an uploaded image via AI.process_image_message."""
    try:
        path = await save_upload_to_tmp(image_file)
        fn = getattr(AI, "process_image_message", None)
        caption = await call_ai(fn, FileWrapper(path), int(user_id))
        return {"status": "ok", "caption": caption}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/generate_image")  # route was missing; path inferred from function name — confirm
async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
    """Text-to-image generation; returns the path of the produced file."""
    try:
        fn = getattr(AI, "generate_image_from_text", None)
        out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/edit_image")  # route was missing; path inferred from function name — confirm
async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
    """Inpaint-edit an uploaded image, with an optional mask upload."""
    try:
        img_path = await save_upload_to_tmp(image_file)
        mask_path = None
        # Explicit None check: presence, not truthiness, decides whether a mask was sent.
        if mask_file is not None:
            mask_path = await save_upload_to_tmp(mask_file)
        fn = getattr(AI, "edit_image_inpaint", None)
        out_path = await call_ai(fn, FileWrapper(img_path), FileWrapper(mask_path) if mask_path else None, prompt, int(user_id))
        return {"status": "ok", "file": out_path}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/video")  # route was missing; path inferred from function name — confirm
async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
    """Process an uploaded video via AI.process_video; returns its dict result."""
    try:
        path = await save_upload_to_tmp(video_file)
        fn = getattr(AI, "process_video", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/file")  # route was missing; path inferred from function name — confirm
async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
    """Process an arbitrary uploaded file via AI.process_file; returns its dict result."""
    try:
        path = await save_upload_to_tmp(file_obj)
        fn = getattr(AI, "process_file", None)
        result = await call_ai(fn, FileWrapper(path), int(user_id))
        return JSONResponse(result)
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/code")  # route was missing; path inferred from function name — confirm
async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
    """Code-completion endpoint; returns {"status": "ok", "code": ...}."""
    try:
        fn = getattr(AI, "code_complete", None)
        # Implementations disagree on the signature: (user_id, prompt, max_tokens)
        # vs (prompt, max_tokens). Try user-first, fall back on TypeError.
        try:
            result = await call_ai(fn, int(user_id), prompt, max_tokens)
        except TypeError:
            result = await call_ai(fn, prompt, max_tokens=max_tokens)
        return {"status": "ok", "code": result}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# ---------- Minimal Gradio UI (mounted) ----------
def gradio_text_fn(text, user_id, lang):
    """Gradio click handler: route a message through the AI text handler.

    Runs in a Gradio worker thread (no running event loop there), so
    ``asyncio.run`` is safe for awaiting async handlers.

    Fixes: a non-numeric User ID textbox value previously raised an uncaught
    ValueError, and handler exceptions leaked as raw Gradio errors; both now
    degrade to a readable message in the Reply box.
    """
    fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
    if fn is None:
        return "Error: text handler not implemented in multimodal_module"
    try:
        uid = int(user_id or 0)
    except (TypeError, ValueError):
        uid = 0  # fall back to anonymous user instead of crashing the UI
    try:
        if inspect.iscoroutinefunction(fn):
            return asyncio.run(call_ai(fn, text, uid, lang))
        return fn(text, uid, lang)
    except Exception as e:
        return f"Error: {e}"
# Declarative UI layout: user id + language on one row, then message in / reply out.
with gr.Blocks(title="Multimodal Bot (UI)") as demo:
    gr.Markdown("# 🧠 Multimodal Bot — UI")
    with gr.Row():
        uid_box = gr.Textbox(label="User ID", value="0")
        lang_dd = gr.Dropdown(["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
    msg_box = gr.Textbox(lines=3, label="Message")
    reply_box = gr.Textbox(lines=6, label="Reply")
    send_btn = gr.Button("Send")
    send_btn.click(gradio_text_fn, [msg_box, uid_box, lang_dd], reply_box)

# Mount the Gradio UI at the site root; the API routes stay on the same app.
app = gr.mount_gradio_app(app, demo, path="/")
# ---------- Run server (HF Spaces uses this entrypoint) ----------
if __name__ == "__main__":
    # HF Spaces injects PORT; 7860 is the conventional local default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))