File size: 4,662 Bytes
e003942
 
 
 
 
 
d6333ba
 
 
 
e003942
 
d6333ba
e003942
 
d6333ba
 
 
e003942
d6333ba
 
 
 
e003942
 
 
 
 
 
d6333ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e003942
d6333ba
 
e003942
d6333ba
e003942
d6333ba
 
e003942
d6333ba
 
 
 
 
 
e003942
d6333ba
e003942
d6333ba
 
e003942
d6333ba
 
 
 
 
 
e003942
d6333ba
e003942
d6333ba
 
e003942
d6333ba
 
 
 
 
 
e003942
d6333ba
e003942
d6333ba
 
e003942
d6333ba
 
 
 
 
 
e003942
d6333ba
e003942
d6333ba
e003942
d6333ba
 
e003942
 
 
d6333ba
 
e003942
 
d6333ba
 
 
 
 
 
e003942
d6333ba
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import asyncio
import os
import shutil
import socket
import tempfile
from typing import Optional

import gradio as gr
import uvicorn
from fastapi import FastAPI, UploadFile, Form

from multimodal_module import MultiModalChatModule

# Initialize the shared AI module once at import time; every FastAPI endpoint
# and Gradio callback below calls methods on this single instance.
AI = MultiModalChatModule()

# ---------------------------
# Utility
# ---------------------------
class GradioFileWrapper:
    """Expose a local file path through an async ``download_to_drive`` method.

    Instances are passed to ``MultiModalChatModule.process_*_message``; the
    module presumably calls ``await wrapper.download_to_drive(dst)`` to get
    its own copy of the file (NOTE(review): confirm against multimodal_module).
    """

    def __init__(self, file_path):
        # Path of the source file on local disk (a temp upload or a Gradio
        # "filepath" value).
        self._path = file_path

    async def download_to_drive(self, dst_path: str):
        """Copy the wrapped file to *dst_path* without blocking the event loop.

        The blocking ``shutil.copyfile`` runs in the loop's default executor.
        ``get_running_loop`` is the loop accessor recommended inside
        coroutines (``get_event_loop`` has policy-dependent behavior).
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)

def run_async(coro):
    """Run *coro* to completion and hand back its result.

    Bridges the synchronous Gradio callbacks to the awaitable
    ``MultiModalChatModule`` methods. ``asyncio.run`` creates a brand-new
    event loop for each call and closes it afterwards, so this must not be
    called from a thread that already has a running loop.
    """
    outcome = asyncio.run(coro)
    return outcome

def get_free_port(default=7860, host="0.0.0.0"):
    """Return *default* if it can be bound on *host*, else an OS-chosen free port.

    This is a best-effort probe: the socket is closed before returning, so
    another process could still take the port before the server binds it.

    Args:
        default: Preferred port. ``0`` means "any free port".
        host: Interface to probe; should match the address the server binds.

    Returns:
        A concrete port number (never ``0``, even when ``default`` is ``0``).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((host, default))
        except OSError:
            pass
        else:
            # getsockname() reports the real port, which matters when
            # default == 0 and the OS picked one.
            return probe.getsockname()[1]

    # Preferred port is busy. Use a fresh socket rather than re-binding the
    # one whose bind just failed, and let the OS pick an ephemeral port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind((host, 0))
        return probe.getsockname()[1]

# ---------------------------
# FastAPI API for external apps
# ---------------------------
# FastAPI application exposing the /api/* endpoints for external clients; the
# Gradio UI is mounted onto this same app further down and uvicorn serves it.
api = FastAPI()

@api.post("/api/text_chat")
async def api_text_chat(
    user_id: Optional[int] = Form(0),
    text: str = Form(...),
    lang: str = Form("en")
):
    """Generate a chat reply for a text message.

    Form fields: ``user_id`` (defaults to 0), ``text`` (required) and ``lang``
    (defaults to ``"en"``). Responds with ``{"reply": ...}`` on success or
    ``{"error": <message>}`` if generating the reply raises.
    """
    try:
        answer = await AI.generate_response(text, int(user_id), lang)
    except Exception as exc:
        return {"error": str(exc)}
    return {"reply": answer}

# Uploads are streamed to disk in chunks of this size instead of being read
# into memory all at once (videos can be large).
_UPLOAD_CHUNK_SIZE = 1024 * 1024


def _safe_upload_name(filename: Optional[str]) -> str:
    """Reduce a client-supplied filename to a bare, non-empty file name.

    ``UploadFile.filename`` is controlled by the client; stripping directory
    components (both ``/`` and ``\\`` separators) prevents names such as
    ``../../etc/passwd`` from escaping the temp directory. The basename, and
    therefore the extension, is kept in case downstream processing uses it.
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        return "upload"
    return name


async def _process_upload(upload, field: str, handler, user_id):
    """Save *upload* to a private temp dir, run *handler* on it, then clean up.

    Args:
        upload: The ``UploadFile`` from the request, or ``None`` if the client
            omitted the form field.
        field: Form field name, used only in the error message.
        handler: One of the ``AI.process_*_message`` coroutine functions,
            called as ``handler(GradioFileWrapper(path), int(user_id))``.
        user_id: Form value, converted with ``int()`` before the call.

    Returns:
        Whatever *handler* returns.

    Raises:
        ValueError: If no file was uploaded. Errors from *handler* propagate.
    """
    if upload is None:
        raise ValueError(f"no '{field}' file was uploaded")
    # A fresh directory per request stops concurrent uploads with the same
    # filename from overwriting each other, and it is removed afterwards so
    # uploads do not accumulate on disk.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = os.path.join(tmp_dir, _safe_upload_name(upload.filename))
        with open(temp_path, "wb") as out:
            while True:
                chunk = await upload.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                out.write(chunk)
        return await handler(GradioFileWrapper(temp_path), int(user_id))


@api.post("/api/image_caption")
async def api_image_caption(user_id: Optional[int] = Form(0), image: UploadFile = None):
    """Caption an uploaded image: ``{"caption": ...}`` or ``{"error": ...}``."""
    try:
        caption = await _process_upload(image, "image", AI.process_image_message, user_id)
        return {"caption": caption}
    except Exception as e:
        return {"error": str(e)}


@api.post("/api/voice_process")
async def api_voice_process(user_id: Optional[int] = Form(0), audio: UploadFile = None):
    """Process an uploaded audio clip: ``{"reply": ...}`` or ``{"error": ...}``."""
    try:
        reply = await _process_upload(audio, "audio", AI.process_voice_message, user_id)
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}


@api.post("/api/video_process")
async def api_video_process(user_id: Optional[int] = Form(0), video: UploadFile = None):
    """Process an uploaded video: ``{"reply": ...}`` or ``{"error": ...}``."""
    try:
        reply = await _process_upload(video, "video", AI.process_video_message, user_id)
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}


@api.post("/api/file_process")
async def api_file_process(user_id: Optional[int] = Form(0), file: UploadFile = None):
    """Process an uploaded file: ``{"reply": ...}`` or ``{"error": ...}``."""
    try:
        reply = await _process_upload(file, "file", AI.process_file_message, user_id)
        return {"reply": reply}
    except Exception as e:
        return {"error": str(e)}

# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks(title="Multimodal Bot") as demo:
    # Emoji are written as \N{...} escapes so the UI text cannot be garbled by
    # a wrong source-file encoding; the previous literals had been mis-decoded
    # into Thai/Windows-1252 characters.
    # NOTE: only text and image tabs exist here; voice, video and file input
    # are reachable through the /api/* endpoints.
    gr.Markdown("# \N{BRAIN} Multimodal Bot\nInteract via text, voice, images, video, or files.")

    # Callbacks stay lambdas: Gradio presumably derives the default API
    # endpoint name from the callback's name, so renaming them could change
    # existing endpoint names (verify before converting to named functions).
    with gr.Tab("\N{SPEECH BALLOON} Text Chat"):
        user_id_txt = gr.Textbox(label="User ID", placeholder="0")
        lang_sel = gr.Dropdown(choices=["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
        txt_in = gr.Textbox(label="Your message", lines=4)
        txt_out = gr.Textbox(label="Bot reply", lines=6)
        # A blank User ID box falls back to user 0.
        gr.Button("Send").click(
            lambda uid, txt, lang: run_async(AI.generate_response(txt, int(uid or 0), lang)),
            [user_id_txt, txt_in, lang_sel],
            txt_out,
        )

    with gr.Tab("\N{FRAME WITH PICTURE} Image Captioning"):
        user_id_img = gr.Textbox(label="User ID", placeholder="0")
        img_in = gr.Image(type="filepath", label="Upload an image")
        img_out = gr.Textbox(label="Caption")
        # img is a local file path, or None/empty when nothing was uploaded;
        # in that case prompt the user instead of failing inside the AI module.
        gr.Button("Caption").click(
            lambda uid, img: (
                run_async(AI.process_image_message(GradioFileWrapper(img), int(uid or 0)))
                if img
                else "Please upload an image first."
            ),
            [user_id_img, img_in],
            img_out,
        )

# ---------------------------
# Mount Gradio UI to FastAPI
# ---------------------------
# Mount the Gradio UI at the root path of the FastAPI app that already holds
# the /api/* routes; the result is rebound to `api`, which uvicorn serves.
api = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app on all interfaces, on port 7860
    # unless that port is already taken (get_free_port then picks another).
    uvicorn.run(api, host="0.0.0.0", port=get_free_port())