GOGO198 commited on
Commit
2e36b27
·
verified ·
1 Parent(s): dbdb4ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -23
app.py CHANGED
@@ -80,7 +80,7 @@ def predict(vector):
80
  print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
81
 
82
  try:
83
- # API调用会收到二维数组
84
  if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
85
  # 已经是正确的二维数组格式
86
  pass
@@ -90,12 +90,12 @@ def predict(vector):
90
  else:
91
  error_msg = "错误:输入格式无效"
92
  print(error_msg)
93
- return error_msg
94
 
95
  if len(vector) != 1 or len(vector[0]) != 768:
96
  error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
97
  print(error_msg)
98
- return error_msg
99
 
100
  # 添加实际处理逻辑
101
  vector_array = np.array(vector, dtype=np.float32)
@@ -106,31 +106,74 @@ def predict(vector):
106
  try:
107
  result = metadata.iloc[I[0][i]]
108
  confidence = 1/(1+D[0][i])
109
- results.append(f"匹配结果 {i+1}: {result['source']} | 置信度: {confidence:.2f}")
 
 
 
110
  except Exception as e:
111
  print(f"结果处理错误: {str(e)}")
112
- results.append(f"结果 {i+1}: 数据获取失败")
113
 
114
  print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
115
 
116
- # 返回Gradio期望的格式
117
- return json.dumps({
118
- "data": [
119
- json.dumps({
120
- "results": results
121
- })
122
- ]
123
- })
124
 
125
  except Exception as e:
126
  import traceback
127
  error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
128
  print(error_msg)
129
- return json.dumps({
130
- "error": "处理错误,请重试或联系管理员"
131
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- # 创建Blocks应用
134
  with gr.Blocks() as demo:
135
  gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
136
  gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
@@ -142,7 +185,7 @@ with gr.Blocks() as demo:
142
  label="输入向量 (768维)",
143
  value=[[0.1]*768]
144
  )
145
- output = gr.Textbox(label="智能回答", lines=5)
146
 
147
  submit_btn = gr.Button("生成回答")
148
  submit_btn.click(
@@ -153,6 +196,21 @@ with gr.Blocks() as demo:
153
 
154
  # 启动应用
155
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # 验证资源
157
  print("="*50)
158
  print("Space启动完成 | 准备接收请求")
@@ -160,8 +218,6 @@ if __name__ == "__main__":
160
  print(f"元数据记录数: {len(metadata)}")
161
  print("="*50)
162
 
163
- # API端点自动创建在 /run/predict
164
- demo.launch(
165
- server_name="0.0.0.0",
166
- server_port=7860
167
- )
 
80
  print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
81
 
82
  try:
83
+ # 处理不同输入格式
84
  if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
85
  # 已经是正确的二维数组格式
86
  pass
 
90
  else:
91
  error_msg = "错误:输入格式无效"
92
  print(error_msg)
93
+ return {"error": error_msg}
94
 
95
  if len(vector) != 1 or len(vector[0]) != 768:
96
  error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
97
  print(error_msg)
98
+ return {"error": error_msg}
99
 
100
  # 添加实际处理逻辑
101
  vector_array = np.array(vector, dtype=np.float32)
 
106
  try:
107
  result = metadata.iloc[I[0][i]]
108
  confidence = 1/(1+D[0][i])
109
+ results.append({
110
+ "source": result['source'],
111
+ "confidence": round(float(confidence), 2)
112
+ })
113
  except Exception as e:
114
  print(f"结果处理错误: {str(e)}")
115
+ results.append({"error": f"结果 {i+1}: 数据获取失败"})
116
 
117
  print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
118
 
119
+ # 返回标准化的JSON格式
120
+ return {
121
+ "status": "success",
122
+ "results": results
123
+ }
 
 
 
124
 
125
  except Exception as e:
126
  import traceback
127
  error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
128
  print(error_msg)
129
+ return {
130
+ "status": "error",
131
+ "message": "处理错误,请重试或联系管理员"
132
+ }
133
+
134
+ # 创建FastAPI应用 (替代Gradio Blocks)
135
+ import fastapi
136
+ from fastapi import FastAPI, Request
137
+ from fastapi.responses import JSONResponse
138
+ from fastapi.middleware.cors import CORSMiddleware
139
+
140
+ app = FastAPI()
141
+
142
+ # 添加CORS支持
143
+ app.add_middleware(
144
+ CORSMiddleware,
145
+ allow_origins=["*"],
146
+ allow_credentials=True,
147
+ allow_methods=["*"],
148
+ allow_headers=["*"],
149
+ )
150
+
151
+ @app.post("/predict")
152
+ async def api_predict(request: Request):
153
+ """API预测端点"""
154
+ try:
155
+ data = await request.json()
156
+ vector = data.get("vector")
157
+
158
+ if not vector or not isinstance(vector, list):
159
+ return JSONResponse(
160
+ status_code=400,
161
+ content={"status": "error", "message": "无效输入: 需要向量数据"}
162
+ )
163
+
164
+ result = predict(vector)
165
+ return JSONResponse(content=result)
166
+
167
+ except Exception as e:
168
+ return JSONResponse(
169
+ status_code=500,
170
+ content={
171
+ "status": "error",
172
+ "message": f"服务器内部错误: {str(e)}"
173
+ }
174
+ )
175
 
176
+ # 保留Gradio界面
177
  with gr.Blocks() as demo:
178
  gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
179
  gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
 
185
  label="输入向量 (768维)",
186
  value=[[0.1]*768]
187
  )
188
+ output = gr.JSON(label="智能回答")
189
 
190
  submit_btn = gr.Button("生成回答")
191
  submit_btn.click(
 
196
 
197
  # 启动应用
198
  if __name__ == "__main__":
199
+ import uvicorn
200
+ import threading
201
+
202
+ # 启动FastAPI服务
203
+ def run_fastapi():
204
+ uvicorn.run(app, host="0.0.0.0", port=7860)
205
+
206
+ # 启动Gradio界面
207
+ def run_gradio():
208
+ demo.launch(
209
+ server_name="0.0.0.0",
210
+ server_port=7861,
211
+ share=False
212
+ )
213
+
214
  # 验证资源
215
  print("="*50)
216
  print("Space启动完成 | 准备接收请求")
 
218
  print(f"元数据记录数: {len(metadata)}")
219
  print("="*50)
220
 
221
+ # 启动两个服务
222
+ threading.Thread(target=run_fastapi, daemon=True).start()
223
+ run_gradio()