Princeaka commited on
Commit
ea54d29
·
verified ·
1 Parent(s): 1de5d77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -32
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py -- HF-ready single-server FastAPI + Gradio mounted app
2
  import os
3
  import shutil
4
  import asyncio
@@ -25,28 +25,29 @@ class FileWrapper:
25
  self._path = path
26
 
27
  async def download_to_drive(self, dst_path: str) -> None:
28
- # keep API similar to your GradioFileWrapper
29
  loop = asyncio.get_event_loop()
30
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
31
 
32
- def save_upload_to_tmp(up: UploadFile) -> str:
33
- """Save FastAPI UploadFile to /tmp and return path."""
34
- assert up and up.filename, "UploadFile missing filename"
 
35
  dest = os.path.join(TMP_DIR, up.filename)
36
- # overwrite if exists
37
  with open(dest, "wb") as f:
38
- f.write(up.file.read())
39
  return dest
40
 
41
  async def call_ai(fn, *args, **kwargs):
42
  """
43
  Call AI functions safely: if fn is async, await it; if sync, run in thread.
44
- This avoids blocking the event loop.
45
  """
 
 
46
  if inspect.iscoroutinefunction(fn):
47
  return await fn(*args, **kwargs)
48
  else:
49
- # run sync function in a thread to avoid blocking
50
  return await asyncio.to_thread(lambda: fn(*args, **kwargs))
51
 
52
  # ---------- FastAPI app ----------
@@ -56,7 +57,7 @@ app = FastAPI(title="Multimodal Module API")
56
  from fastapi.middleware.cors import CORSMiddleware
57
  app.add_middleware(
58
  CORSMiddleware,
59
- allow_origins=["*"], # change to specific domains for production
60
  allow_credentials=True,
61
  allow_methods=["*"],
62
  allow_headers=["*"],
@@ -71,7 +72,8 @@ async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0),
71
  Form field 'inputs' used as text.
72
  """
73
  try:
74
- reply = await call_ai(getattr(AI, "generate_response", getattr(AI, "process_text", None)), inputs, int(user_id), lang)
 
75
  # HF-style returns "data" array
76
  return {"data": [reply]}
77
  except Exception as e:
@@ -80,7 +82,8 @@ async def api_predict(inputs: str = Form(...), user_id: Optional[int] = Form(0),
80
  @app.post("/api/text")
81
  async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
82
  try:
83
- reply = await call_ai(getattr(AI, "generate_response", getattr(AI, "process_text", None)), text, int(user_id), lang)
 
84
  return {"status": "ok", "reply": reply}
85
  except Exception as e:
86
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -91,8 +94,9 @@ async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = F
91
  Upload audio file (multipart/form-data). Returns whatever your AI.process_voice_message returns (JSON/dict).
92
  """
93
  try:
94
- path = save_upload_to_tmp(audio_file)
95
- result = await call_ai(getattr(AI, "process_voice_message", None), FileWrapper(path), int(user_id))
 
96
  return JSONResponse(result)
97
  except Exception as e:
98
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -100,7 +104,8 @@ async def api_voice(user_id: Optional[int] = Form(0), audio_file: UploadFile = F
100
  @app.post("/api/voice_reply")
101
  async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
102
  try:
103
- result = await call_ai(getattr(AI, "generate_voice_reply", None), reply_text, int(user_id), fmt)
 
104
  return {"status": "ok", "file": result}
105
  except Exception as e:
106
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -108,8 +113,9 @@ async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Fo
108
  @app.post("/api/image_caption")
109
  async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
110
  try:
111
- path = save_upload_to_tmp(image_file)
112
- caption = await call_ai(getattr(AI, "process_image_message", None), FileWrapper(path), int(user_id))
 
113
  return {"status": "ok", "caption": caption}
114
  except Exception as e:
115
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -117,7 +123,8 @@ async def api_image_caption(user_id: Optional[int] = Form(0), image_file: Upload
117
  @app.post("/api/generate_image")
118
  async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
119
  try:
120
- out_path = await call_ai(getattr(AI, "generate_image_from_text", None), prompt, int(user_id), width, height, steps)
 
121
  return {"status": "ok", "file": out_path}
122
  except Exception as e:
123
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -125,11 +132,12 @@ async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = For
125
  @app.post("/api/edit_image")
126
  async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
127
  try:
128
- img_path = save_upload_to_tmp(image_file)
129
  mask_path = None
130
  if mask_file:
131
- mask_path = save_upload_to_tmp(mask_file)
132
- out_path = await call_ai(getattr(AI, "edit_image_inpaint", None), FileWrapper(img_path), FileWrapper(mask_path) if mask_path else None, prompt, int(user_id))
 
133
  return {"status": "ok", "file": out_path}
134
  except Exception as e:
135
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -137,8 +145,9 @@ async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFil
137
  @app.post("/api/video")
138
  async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
139
  try:
140
- path = save_upload_to_tmp(video_file)
141
- result = await call_ai(getattr(AI, "process_video", None), FileWrapper(path), int(user_id))
 
142
  return JSONResponse(result)
143
  except Exception as e:
144
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -146,8 +155,9 @@ async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = F
146
  @app.post("/api/file")
147
  async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
148
  try:
149
- path = save_upload_to_tmp(file_obj)
150
- result = await call_ai(getattr(AI, "process_file", None), FileWrapper(path), int(user_id))
 
151
  return JSONResponse(result)
152
  except Exception as e:
153
  return JSONResponse({"error": str(e)}, status_code=500)
@@ -155,20 +165,26 @@ async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File
155
  @app.post("/api/code")
156
  async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
157
  try:
158
- result = await call_ai(getattr(AI, "code_complete", None), int(user_id), prompt, max_tokens)
159
- # Some modules return string/code, others dict normalize:
 
 
 
 
 
160
  return {"status": "ok", "code": result}
161
  except Exception as e:
162
  return JSONResponse({"error": str(e)}, status_code=500)
163
 
164
  # ---------- Minimal Gradio UI (mounted) ----------
165
  def gradio_text_fn(text, user_id, lang):
166
- # call AI synchronously from Gradio (blocking safe in Gradio)
167
- if inspect.iscoroutinefunction(getattr(AI, "generate_response", getattr(AI, "process_text", None))):
168
- return asyncio.run(call_ai(getattr(AI, "generate_response", getattr(AI, "process_text", None)), text, int(user_id or 0), lang))
 
 
169
  else:
170
- # sync
171
- return getattr(AI, "generate_response", getattr(AI, "process_text", None))(text, int(user_id or 0), lang)
172
 
173
  with gr.Blocks(title="Multimodal Bot (UI)") as demo:
174
  gr.Markdown("# 🧠 Multimodal Bot — UI")
 
1
+ # app.py -- HF-ready single-server FastAPI + Gradio mounted app (fixed)
2
  import os
3
  import shutil
4
  import asyncio
 
25
  self._path = path
26
 
27
  async def download_to_drive(self, dst_path: str) -> None:
 
28
  loop = asyncio.get_event_loop()
29
  await loop.run_in_executor(None, shutil.copyfile, self._path, dst_path)
30
 
31
+ async def save_upload_to_tmp(up: UploadFile) -> str:
32
+ """Save FastAPI UploadFile to /tmp and return path. Uses async read."""
33
+ if not up or not up.filename:
34
+ raise ValueError("UploadFile missing or has no filename")
35
  dest = os.path.join(TMP_DIR, up.filename)
36
+ data = await up.read() # <-- important: async read
37
  with open(dest, "wb") as f:
38
+ f.write(data)
39
  return dest
40
 
41
  async def call_ai(fn, *args, **kwargs):
42
  """
43
  Call AI functions safely: if fn is async, await it; if sync, run in thread.
44
+ If fn is None, raise a clear error.
45
  """
46
+ if fn is None:
47
+ raise AttributeError("Requested AI method is not implemented in multimodal_module")
48
  if inspect.iscoroutinefunction(fn):
49
  return await fn(*args, **kwargs)
50
  else:
 
51
  return await asyncio.to_thread(lambda: fn(*args, **kwargs))
52
 
53
  # ---------- FastAPI app ----------
 
57
  from fastapi.middleware.cors import CORSMiddleware
58
  app.add_middleware(
59
  CORSMiddleware,
60
+ allow_origins=["*"], # tighten in production
61
  allow_credentials=True,
62
  allow_methods=["*"],
63
  allow_headers=["*"],
 
72
  Form field 'inputs' used as text.
73
  """
74
  try:
75
+ fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
76
+ reply = await call_ai(fn, inputs, int(user_id), lang)
77
  # HF-style returns "data" array
78
  return {"data": [reply]}
79
  except Exception as e:
 
82
  @app.post("/api/text")
83
  async def api_text(text: str = Form(...), user_id: Optional[int] = Form(0), lang: str = Form("en")):
84
  try:
85
+ fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
86
+ reply = await call_ai(fn, text, int(user_id), lang)
87
  return {"status": "ok", "reply": reply}
88
  except Exception as e:
89
  return JSONResponse({"error": str(e)}, status_code=500)
 
94
  Upload audio file (multipart/form-data). Returns whatever your AI.process_voice_message returns (JSON/dict).
95
  """
96
  try:
97
+ path = await save_upload_to_tmp(audio_file)
98
+ fn = getattr(AI, "process_voice_message", None)
99
+ result = await call_ai(fn, FileWrapper(path), int(user_id))
100
  return JSONResponse(result)
101
  except Exception as e:
102
  return JSONResponse({"error": str(e)}, status_code=500)
 
104
  @app.post("/api/voice_reply")
105
  async def api_voice_reply(user_id: Optional[int] = Form(0), reply_text: str = Form(...), fmt: str = Form("ogg")):
106
  try:
107
+ fn = getattr(AI, "generate_voice_reply", None)
108
+ result = await call_ai(fn, reply_text, int(user_id), fmt)
109
  return {"status": "ok", "file": result}
110
  except Exception as e:
111
  return JSONResponse({"error": str(e)}, status_code=500)
 
113
  @app.post("/api/image_caption")
114
  async def api_image_caption(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...)):
115
  try:
116
+ path = await save_upload_to_tmp(image_file)
117
+ fn = getattr(AI, "process_image_message", None)
118
+ caption = await call_ai(fn, FileWrapper(path), int(user_id))
119
  return {"status": "ok", "caption": caption}
120
  except Exception as e:
121
  return JSONResponse({"error": str(e)}, status_code=500)
 
123
  @app.post("/api/generate_image")
124
  async def api_generate_image(user_id: Optional[int] = Form(0), prompt: str = Form(...), width: int = Form(512), height: int = Form(512), steps: int = Form(30)):
125
  try:
126
+ fn = getattr(AI, "generate_image_from_text", None)
127
+ out_path = await call_ai(fn, prompt, int(user_id), width, height, steps)
128
  return {"status": "ok", "file": out_path}
129
  except Exception as e:
130
  return JSONResponse({"error": str(e)}, status_code=500)
 
132
  @app.post("/api/edit_image")
133
  async def api_edit_image(user_id: Optional[int] = Form(0), image_file: UploadFile = File(...), mask_file: Optional[UploadFile] = File(None), prompt: str = Form("")):
134
  try:
135
+ img_path = await save_upload_to_tmp(image_file)
136
  mask_path = None
137
  if mask_file:
138
+ mask_path = await save_upload_to_tmp(mask_file)
139
+ fn = getattr(AI, "edit_image_inpaint", None)
140
+ out_path = await call_ai(fn, FileWrapper(img_path), FileWrapper(mask_path) if mask_path else None, prompt, int(user_id))
141
  return {"status": "ok", "file": out_path}
142
  except Exception as e:
143
  return JSONResponse({"error": str(e)}, status_code=500)
 
145
  @app.post("/api/video")
146
  async def api_video(user_id: Optional[int] = Form(0), video_file: UploadFile = File(...)):
147
  try:
148
+ path = await save_upload_to_tmp(video_file)
149
+ fn = getattr(AI, "process_video", None)
150
+ result = await call_ai(fn, FileWrapper(path), int(user_id))
151
  return JSONResponse(result)
152
  except Exception as e:
153
  return JSONResponse({"error": str(e)}, status_code=500)
 
155
  @app.post("/api/file")
156
  async def api_file(user_id: Optional[int] = Form(0), file_obj: UploadFile = File(...)):
157
  try:
158
+ path = await save_upload_to_tmp(file_obj)
159
+ fn = getattr(AI, "process_file", None)
160
+ result = await call_ai(fn, FileWrapper(path), int(user_id))
161
  return JSONResponse(result)
162
  except Exception as e:
163
  return JSONResponse({"error": str(e)}, status_code=500)
 
165
  @app.post("/api/code")
166
  async def api_code(user_id: Optional[int] = Form(0), prompt: str = Form(...), max_tokens: int = Form(512)):
167
  try:
168
+ fn = getattr(AI, "code_complete", None)
169
+ # Many implementations expect (user_id, prompt, max_tokens) or (prompt, max_tokens)
170
+ # Try user-first signature first, fallback to prompt-first
171
+ try:
172
+ result = await call_ai(fn, int(user_id), prompt, max_tokens)
173
+ except TypeError:
174
+ result = await call_ai(fn, prompt, max_tokens=max_tokens)
175
  return {"status": "ok", "code": result}
176
  except Exception as e:
177
  return JSONResponse({"error": str(e)}, status_code=500)
178
 
179
  # ---------- Minimal Gradio UI (mounted) ----------
180
  def gradio_text_fn(text, user_id, lang):
181
+ fn = getattr(AI, "generate_response", getattr(AI, "process_text", None))
182
+ if fn is None:
183
+ return "Error: text handler not implemented in multimodal_module"
184
+ if inspect.iscoroutinefunction(fn):
185
+ return asyncio.run(call_ai(fn, text, int(user_id or 0), lang))
186
  else:
187
+ return fn(text, int(user_id or 0), lang)
 
188
 
189
  with gr.Blocks(title="Multimodal Bot (UI)") as demo:
190
  gr.Markdown("# 🧠 Multimodal Bot — UI")