Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,26 +1,25 @@
|
|
1 |
-
# app.py
|
2 |
import os
|
3 |
import shutil
|
4 |
import asyncio
|
5 |
import inspect
|
|
|
6 |
from typing import Optional
|
7 |
from fastapi import FastAPI, UploadFile, File, Form
|
8 |
from fastapi.responses import JSONResponse
|
|
|
9 |
import gradio as gr
|
10 |
import uvicorn
|
11 |
|
12 |
-
# Import your multimodal module
|
13 |
from multimodal_module import MultiModalChatModule
|
14 |
-
|
15 |
-
# Instantiate AI module
|
16 |
AI = MultiModalChatModule()
|
17 |
|
18 |
-
#
|
19 |
TMP_DIR = "/tmp"
|
20 |
os.makedirs(TMP_DIR, exist_ok=True)
|
21 |
|
22 |
class FileWrapper:
|
23 |
-
"""Simple path wrapper for AI methods."""
|
24 |
def __init__(self, path: str):
|
25 |
self._path = path
|
26 |
|
@@ -29,7 +28,6 @@ class FileWrapper:
|
|
29 |
await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
|
30 |
|
31 |
async def save_upload_to_tmp(up: UploadFile) -> str:
|
32 |
-
"""Save FastAPI UploadFile to /tmp and return path."""
|
33 |
if not up or not up.filename:
|
34 |
raise ValueError("No file uploaded")
|
35 |
dest = os.path.join(TMP_DIR, up.filename)
|
@@ -39,27 +37,35 @@ async def save_upload_to_tmp(up: UploadFile) -> str:
|
|
39 |
return dest
|
40 |
|
41 |
async def call_ai(fn, *args, **kwargs):
|
42 |
-
"""Run AI method whether it's sync or async."""
|
43 |
if fn is None:
|
44 |
raise AttributeError("Requested AI method not implemented")
|
45 |
if inspect.iscoroutinefunction(fn):
|
46 |
return await fn(*args, **kwargs)
|
47 |
return await asyncio.to_thread(lambda: fn(*args, **kwargs))
|
48 |
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
app = FastAPI(title="Multimodal Module API")
|
51 |
|
52 |
-
#
|
53 |
-
from fastapi.middleware.cors import CORSMiddleware
|
54 |
app.add_middleware(
|
55 |
CORSMiddleware,
|
56 |
-
allow_origins=["*"],
|
57 |
allow_credentials=True,
|
58 |
allow_methods=["*"],
|
59 |
allow_headers=["*"],
|
60 |
)
|
61 |
|
62 |
-
#
|
63 |
@app.post("/api/predict")
|
64 |
async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
|
65 |
try:
|
@@ -161,7 +167,7 @@ async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), ma
|
|
161 |
except Exception as e:
|
162 |
return JSONResponse({"error": str(e)}, status_code=500)
|
163 |
|
164 |
-
#
|
165 |
def gradio_text_fn(text, user_id, lang):
|
166 |
fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
|
167 |
if fn is None:
|
@@ -178,10 +184,11 @@ with gr.Blocks(title="Multimodal Bot (UI)") as demo:
|
|
178 |
out = gr.Textbox(lines=6, label="Reply")
|
179 |
gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
|
180 |
|
181 |
-
# Mount Gradio
|
182 |
app = gr.mount_gradio_app(app, demo, path="/")
|
183 |
|
184 |
-
#
|
185 |
if __name__ == "__main__":
|
186 |
-
port = int(os.environ.get("PORT",
|
|
|
187 |
uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
|
|
|
1 |
+
# app.py — Multimodal API + Gradio UI with automatic free port selection
|
2 |
import os
|
3 |
import shutil
|
4 |
import asyncio
|
5 |
import inspect
|
6 |
+
import socket
|
7 |
from typing import Optional
|
8 |
from fastapi import FastAPI, UploadFile, File, Form
|
9 |
from fastapi.responses import JSONResponse
|
10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
11 |
import gradio as gr
|
12 |
import uvicorn
|
13 |
|
14 |
+
# ------------------- Import your multimodal module -------------------
|
15 |
from multimodal_module import MultiModalChatModule
|
|
|
|
|
16 |
AI = MultiModalChatModule()
|
17 |
|
18 |
+
# ------------------- Helpers -------------------
|
19 |
TMP_DIR = "/tmp"
|
20 |
os.makedirs(TMP_DIR, exist_ok=True)
|
21 |
|
22 |
class FileWrapper:
|
|
|
23 |
def __init__(self, path: str):
|
24 |
self._path = path
|
25 |
|
|
|
28 |
await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
|
29 |
|
30 |
async def save_upload_to_tmp(up: UploadFile) -> str:
|
|
|
31 |
if not up or not up.filename:
|
32 |
raise ValueError("No file uploaded")
|
33 |
dest = os.path.join(TMP_DIR, up.filename)
|
|
|
37 |
return dest
|
38 |
|
39 |
async def call_ai(fn, *args, **kwargs):
|
|
|
40 |
if fn is None:
|
41 |
raise AttributeError("Requested AI method not implemented")
|
42 |
if inspect.iscoroutinefunction(fn):
|
43 |
return await fn(*args, **kwargs)
|
44 |
return await asyncio.to_thread(lambda: fn(*args, **kwargs))
|
45 |
|
46 |
+
def find_free_port(start_port=7860, max_tries=50):
|
47 |
+
"""Find an available port starting from start_port."""
|
48 |
+
port = start_port
|
49 |
+
for _ in range(max_tries):
|
50 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
51 |
+
if s.connect_ex(("0.0.0.0", port)) != 0:
|
52 |
+
return port
|
53 |
+
port += 1
|
54 |
+
raise OSError("No free port found")
|
55 |
+
|
56 |
+
# ------------------- FastAPI -------------------
|
57 |
app = FastAPI(title="Multimodal Module API")
|
58 |
|
59 |
+
# Allow all origins (adjust for production)
|
|
|
60 |
app.add_middleware(
|
61 |
CORSMiddleware,
|
62 |
+
allow_origins=["*"],
|
63 |
allow_credentials=True,
|
64 |
allow_methods=["*"],
|
65 |
allow_headers=["*"],
|
66 |
)
|
67 |
|
68 |
+
# ------------------- API Endpoints -------------------
|
69 |
@app.post("/api/predict")
|
70 |
async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
|
71 |
try:
|
|
|
167 |
except Exception as e:
|
168 |
return JSONResponse({"error": str(e)}, status_code=500)
|
169 |
|
170 |
+
# ------------------- Gradio UI -------------------
|
171 |
def gradio_text_fn(text, user_id, lang):
|
172 |
fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
|
173 |
if fn is None:
|
|
|
184 |
out = gr.Textbox(lines=6, label="Reply")
|
185 |
gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
|
186 |
|
187 |
+
# Mount Gradio app into FastAPI
|
188 |
app = gr.mount_gradio_app(app, demo, path="/")
|
189 |
|
190 |
+
# ------------------- Run -------------------
|
191 |
if __name__ == "__main__":
|
192 |
+
port = int(os.environ.get("PORT", find_free_port()))
|
193 |
+
print(f"🚀 Starting server on port {port}")
|
194 |
uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
|