Princeaka committed on
Commit
20927ba
·
verified ·
1 Parent(s): b2130a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -1,26 +1,25 @@
1
- # app.py -- HF-ready single-server FastAPI + Gradio mounted app (no double server conflict)
2
  import os
3
  import shutil
4
  import asyncio
5
  import inspect
 
6
  from typing import Optional
7
  from fastapi import FastAPI, UploadFile, File, Form
8
  from fastapi.responses import JSONResponse
 
9
  import gradio as gr
10
  import uvicorn
11
 
12
- # Import your multimodal module
13
  from multimodal_module import MultiModalChatModule
14
-
15
- # Instantiate AI module
16
  AI = MultiModalChatModule()
17
 
18
- # ---------- Helpers ----------
19
  TMP_DIR = "/tmp"
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
  class FileWrapper:
23
- """Simple path wrapper for AI methods."""
24
  def __init__(self, path: str):
25
  self._path = path
26
 
@@ -29,7 +28,6 @@ class FileWrapper:
29
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
30
 
31
  async def save_upload_to_tmp(up: UploadFile) -> str:
32
- """Save FastAPI UploadFile to /tmp and return path."""
33
  if not up or not up.filename:
34
  raise ValueError("No file uploaded")
35
  dest = os.path.join(TMP_DIR, up.filename)
@@ -39,27 +37,35 @@ async def save_upload_to_tmp(up: UploadFile) -> str:
39
  return dest
40
 
41
async def call_ai(fn, *args, **kwargs):
    """Invoke an AI method, awaiting it when it is a coroutine function.

    Synchronous callables are offloaded to a worker thread so the event
    loop is never blocked while they run.
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    is_async = inspect.iscoroutinefunction(fn)
    if is_async:
        return await fn(*args, **kwargs)

    def _invoke():
        return fn(*args, **kwargs)

    return await asyncio.to_thread(_invoke)
48
 
49
- # ---------- FastAPI ----------
 
 
 
 
 
 
 
 
 
 
50
  app = FastAPI(title="Multimodal Module API")
51
 
52
- # CORS (if you call this from the browser)
53
- from fastapi.middleware.cors import CORSMiddleware
54
  app.add_middleware(
55
  CORSMiddleware,
56
- allow_origins=["*"], # tighten for prod
57
  allow_credentials=True,
58
  allow_methods=["*"],
59
  allow_headers=["*"],
60
  )
61
 
62
- # ----------------- API endpoints -----------------
63
  @app.post("/api/predict")
64
  async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
65
  try:
@@ -161,7 +167,7 @@ async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), ma
161
  except Exception as e:
162
  return JSONResponse({"error": str(e)}, status_code=500)
163
 
164
- # ---------- Minimal Gradio UI ----------
165
  def gradio_text_fn(text, user_id, lang):
166
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
167
  if fn is None:
@@ -178,10 +184,11 @@ with gr.Blocks(title="Multimodal Bot (UI)") as demo:
178
  out = gr.Textbox(lines=6, label="Reply")
179
  gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
180
 
181
- # Mount Gradio at /
182
  app = gr.mount_gradio_app(app, demo, path="/")
183
 
184
- # ---------- Run ----------
185
  if __name__ == "__main__":
186
- port = int(os.environ.get("PORT", 7860))
 
187
  uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
 
1
+ # app.py Multimodal API + Gradio UI with automatic free port selection
2
  import os
3
  import shutil
4
  import asyncio
5
  import inspect
6
+ import socket
7
  from typing import Optional
8
  from fastapi import FastAPI, UploadFile, File, Form
9
  from fastapi.responses import JSONResponse
10
+ from fastapi.middleware.cors import CORSMiddleware
11
  import gradio as gr
12
  import uvicorn
13
 
14
+ # ------------------- Import your multimodal module -------------------
15
  from multimodal_module import MultiModalChatModule
 
 
16
  AI = MultiModalChatModule()
17
 
18
+ # ------------------- Helpers -------------------
19
  TMP_DIR = "/tmp"
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
22
  class FileWrapper:
 
23
  def __init__(self, path: str):
24
  self._path = path
25
 
 
28
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
29
 
30
  async def save_upload_to_tmp(up: UploadFile) -> str:
 
31
  if not up or not up.filename:
32
  raise ValueError("No file uploaded")
33
  dest = os.path.join(TMP_DIR, up.filename)
 
37
  return dest
38
 
39
async def call_ai(fn, *args, **kwargs):
    """Call an AI method regardless of whether it is sync or async.

    Coroutine functions are awaited directly; plain callables are run in a
    worker thread via ``asyncio.to_thread`` so the event loop stays free.

    Args:
        fn: The AI method to invoke, or ``None`` when not implemented.
        *args, **kwargs: Forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        AttributeError: If ``fn`` is ``None``.
    """
    if fn is None:
        raise AttributeError("Requested AI method not implemented")
    if inspect.iscoroutinefunction(fn):
        return await fn(*args, **kwargs)
    # to_thread forwards *args/**kwargs itself; the lambda wrapper was redundant.
    return await asyncio.to_thread(fn, *args, **kwargs)
45
 
46
def find_free_port(start_port=7860, max_tries=50):
    """Return the first TCP port >= *start_port* that can actually be bound.

    The previous implementation probed ports with ``connect_ex``, which only
    detects ports that have an active listener; a port can refuse connections
    yet still be unbindable (e.g. bound without listening). Attempting a real
    ``bind()`` is the reliable availability test.

    Args:
        start_port: First port number to try.
        max_tries: Number of consecutive ports to probe before giving up.

    Returns:
        The first bindable port number in the probed range.

    Raises:
        OSError: If no free port is found within *max_tries* attempts.
    """
    for port in range(start_port, start_port + max_tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # SO_REUSEADDR lets us reuse ports lingering in TIME_WAIT.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind(("0.0.0.0", port))
            except OSError:
                continue  # port is in use; try the next one
            return port
    raise OSError("No free port found")
55
+
56
+ # ------------------- FastAPI -------------------
57
  app = FastAPI(title="Multimodal Module API")
58
 
59
+ # Allow all origins (adjust for production)
 
60
  app.add_middleware(
61
  CORSMiddleware,
62
+ allow_origins=["*"],
63
  allow_credentials=True,
64
  allow_methods=["*"],
65
  allow_headers=["*"],
66
  )
67
 
68
+ # ------------------- API Endpoints -------------------
69
  @app.post("/api/predict")
70
  async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
71
  try:
 
167
  except Exception as e:
168
  return JSONResponse({"error": str(e)}, status_code=500)
169
 
170
+ # ------------------- Gradio UI -------------------
171
  def gradio_text_fn(text, user_id, lang):
172
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
173
  if fn is None:
 
184
  out = gr.Textbox(lines=6, label="Reply")
185
  gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
186
 
187
+ # Mount Gradio app into FastAPI
188
  app = gr.mount_gradio_app(app, demo, path="/")
189
 
190
+ # ------------------- Run -------------------
191
if __name__ == "__main__":
    # Respect an explicit PORT from the environment; only scan for a free
    # port when none is given. The original passed find_free_port() as the
    # dict.get() default, which evaluated it (probing ports) even when PORT
    # was already set.
    env_port = os.environ.get("PORT")
    port = int(env_port) if env_port else find_free_port()
    print(f"🚀 Starting server on port {port}")
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)