Princeaka commited on
Commit
720a169
Β·
verified Β·
1 Parent(s): 20927ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -38
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py β€” Multimodal API + Gradio UI with automatic free port selection
2
  import os
3
  import shutil
4
  import asyncio
@@ -11,18 +11,16 @@ from fastapi.middleware.cors import CORSMiddleware
11
  import gradio as gr
12
  import uvicorn
13
 
14
- # ------------------- Import your multimodal module -------------------
15
  from multimodal_module import MultiModalChatModule
16
  AI = MultiModalChatModule()
17
 
18
- # ------------------- Helpers -------------------
19
  TMP_DIR = "/tmp"
20
  os.makedirs(TMP_DIR, exist_ok=True)
21
 
 
22
  class FileWrapper:
23
  def __init__(self, path: str):
24
  self._path = path
25
-
26
  async def download_to_drive(self, dst_path: str) -> None:
27
  loop = asyncio.get_event_loop()
28
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
@@ -43,37 +41,18 @@ async def call_ai(fn, *args, **kwargs):
43
  return await fn(*args, **kwargs)
44
  return await asyncio.to_thread(lambda: fn(*args, **kwargs))
45
 
46
- def find_free_port(start_port=7860, max_tries=50):
47
- """Find an available port starting from start_port."""
48
- port = start_port
49
- for _ in range(max_tries):
50
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
51
- if s.connect_ex(("0.0.0.0", port)) != 0:
52
- return port
53
- port += 1
54
- raise OSError("No free port found")
55
-
56
- # ------------------- FastAPI -------------------
57
  app = FastAPI(title="Multimodal Module API")
58
 
59
- # Allow all origins (adjust for production)
60
  app.add_middleware(
61
  CORSMiddleware,
62
- allow_origins=["*"],
63
  allow_credentials=True,
64
  allow_methods=["*"],
65
  allow_headers=["*"],
66
  )
67
 
68
- # ------------------- API Endpoints -------------------
69
- @app.post("/api/predict")
70
- async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
71
- try:
72
- fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
73
- reply = await call_ai(fn, inputs, int(user_id), lang)
74
- return {"data": [reply]}
75
- except Exception as e:
76
- return JSONResponse({"error": str(e)}, status_code=500)
77
 
78
  @app.post("/api/text")
79
  async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
@@ -114,7 +93,13 @@ async def api_image_caption(user_id: Optional[int] = Form(0), image_file: Upload
114
  return JSONResponse({"error": str(e)}, status_code=500)
115
 
116
  @app.post("/api/generate_image")
117
- async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
 
 
 
 
 
 
118
  try:
119
  fn = getattr(AI, "generate_image_from_text", None)
120
  out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
@@ -123,14 +108,25 @@ async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = For
123
  return JSONResponse({"error": str(e)}, status_code=500)
124
 
125
  @app.post("/api/edit_image")
126
- async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
 
 
 
 
 
127
  try:
128
  img_path = await save_upload_to_tmp(image_file)
129
  mask_path = None
130
  if mask_file:
131
  mask_path = await save_upload_to_tmp(mask_file)
132
  fn = getattr(AI, "edit_image_inpaint", None)
133
- out_path = await call_ai(fn, FileWrapper(img_path), FileWrapper(mask_path) if mask_path else None, prompt, int(user_id))
 
 
 
 
 
 
134
  return {"status": "ok", "file": out_path}
135
  except Exception as e:
136
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -167,7 +163,7 @@ async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), ma
167
  except Exception as e:
168
  return JSONResponse({"error": str(e)}, status_code=500)
169
 
170
- # ------------------- Gradio UI -------------------
171
  def gradio_text_fn(text, user_id, lang):
172
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
173
  if fn is None:
@@ -178,17 +174,22 @@ def gradio_text_fn(text, user_id, lang):
178
  with gr.Blocks(title="Multimodal Bot (UI)") as demo:
179
  gr.Markdown("# 🧠 Multimodal Bot β€” UI")
180
  with gr.Row():
181
- txt_uid = gr.Textbox(label="User ID", value="0")
182
- txt_lang = gr.Dropdown(["en","zh","ja","ko","es","fr","de","it"], value="en", label="Language")
183
  inp = gr.Textbox(lines=3, label="Message")
184
  out = gr.Textbox(lines=6, label="Reply")
185
- gr.Button("Send").click(gradio_text_fn, [inp, txt_uid, txt_lang], out)
186
 
187
- # Mount Gradio app into FastAPI
188
  app = gr.mount_gradio_app(app, demo, path="/")
189
 
190
- # ------------------- Run -------------------
 
 
 
 
 
 
 
191
  if __name__ == "__main__":
192
- port = int(os.environ.get("PORT", find_free_port()))
193
- print(f"πŸš€ Starting server on port {port}")
194
- uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
 
1
+ # app.py β€” Full multimodal FastAPI + Gradio with auto free port pick
2
  import os
3
  import shutil
4
  import asyncio
 
11
  import gradio as gr
12
  import uvicorn
13
 
 
14
  from multimodal_module import MultiModalChatModule
15
  AI = MultiModalChatModule()
16
 
 
17
  TMP_DIR = "/tmp"
18
  os.makedirs(TMP_DIR, exist_ok=True)
19
 
20
+ # --- Helpers ---
21
  class FileWrapper:
22
  def __init__(self, path: str):
23
  self._path = path
 
24
  async def download_to_drive(self, dst_path: str) -> None:
25
  loop = asyncio.get_event_loop()
26
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
 
41
  return await fn(*args, **kwargs)
42
  return await asyncio.to_thread(lambda: fn(*args, **kwargs))
43
 
44
+ # --- FastAPI app ---
 
 
 
 
 
 
 
 
 
 
45
  app = FastAPI(title="Multimodal Module API")
46
 
 
47
  app.add_middleware(
48
  CORSMiddleware,
49
+ allow_origins=["*"], # Change this in production!
50
  allow_credentials=True,
51
  allow_methods=["*"],
52
  allow_headers=["*"],
53
  )
54
 
55
+ # --- API Endpoints ---
 
 
 
 
 
 
 
 
56
 
57
  @app.post("/api/text")
58
  async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
 
93
  return JSONResponse({"error": str(e)}, status_code=500)
94
 
95
  @app.post("/api/generate_image")
96
+ async def api_generate_image(
97
+ user_id: Optional[int] = Form(0),
98
+ prompt: str = Form(...),
99
+ width: int = Form(512),
100
+ height: int = Form(512),
101
+ steps: int = Form(30)
102
+ ):
103
  try:
104
  fn = getattr(AI, "generate_image_from_text", None)
105
  out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
 
108
  return JSONResponse({"error": str(e)}, status_code=500)
109
 
110
  @app.post("/api/edit_image")
111
+ async def api_edit_image(
112
+ user_id: Optional[int] = Form(0),
113
+ image_file: UploadFile = File(...),
114
+ mask_file: Optional[UploadFile] = File(None),
115
+ prompt: str = Form("")
116
+ ):
117
  try:
118
  img_path = await save_upload_to_tmp(image_file)
119
  mask_path = None
120
  if mask_file:
121
  mask_path = await save_upload_to_tmp(mask_file)
122
  fn = getattr(AI, "edit_image_inpaint", None)
123
+ out_path = await call_ai(
124
+ fn,
125
+ FileWrapper(img_path),
126
+ FileWrapper(mask_path) if mask_path else None,
127
+ prompt,
128
+ int(user_id)
129
+ )
130
  return {"status": "ok", "file": out_path}
131
  except Exception as e:
132
  return JSONResponse({"error": str(e)}, status_code=500)
 
163
  except Exception as e:
164
  return JSONResponse({"error": str(e)}, status_code=500)
165
 
166
+ # --- Gradio UI ---
167
  def gradio_text_fn(text, user_id, lang):
168
  fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
169
  if fn is None:
 
174
  with gr.Blocks(title="Multimodal Bot (UI)") as demo:
175
  gr.Markdown("# 🧠 Multimodal Bot β€” UI")
176
  with gr.Row():
177
+ uid = gr.Textbox(label="User ID", value="0")
178
+ lang = gr.Dropdown(["en", "zh", "ja", "ko", "es", "fr", "de", "it"], value="en", label="Language")
179
  inp = gr.Textbox(lines=3, label="Message")
180
  out = gr.Textbox(lines=6, label="Reply")
181
+ gr.Button("Send").click(gradio_text_fn, [inp, uid, lang], out)
182
 
 
183
  app = gr.mount_gradio_app(app, demo, path="/")
184
 
185
+ # --- Auto free port finder ---
186
+ def get_free_port():
187
+ s = socket.socket()
188
+ s.bind(("", 0))
189
+ port = s.getsockname()[1]
190
+ s.close()
191
+ return port
192
+
193
  if __name__ == "__main__":
194
+ port = int(os.environ.get("PORT", get_free_port()))
195
+ uvicorn.run("app:app", host="0.0.0.0", port=port)