Update app.py
Browse files
app.py
CHANGED
@@ -80,7 +80,7 @@ def predict(vector):
|
|
80 |
print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
|
81 |
|
82 |
try:
|
83 |
-
#
|
84 |
if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
|
85 |
# 已经是正确的二维数组格式
|
86 |
pass
|
@@ -90,12 +90,12 @@ def predict(vector):
|
|
90 |
else:
|
91 |
error_msg = "错误:输入格式无效"
|
92 |
print(error_msg)
|
93 |
-
return error_msg
|
94 |
|
95 |
if len(vector) != 1 or len(vector[0]) != 768:
|
96 |
error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
|
97 |
print(error_msg)
|
98 |
-
return error_msg
|
99 |
|
100 |
# 添加实际处理逻辑
|
101 |
vector_array = np.array(vector, dtype=np.float32)
|
@@ -106,31 +106,74 @@ def predict(vector):
|
|
106 |
try:
|
107 |
result = metadata.iloc[I[0][i]]
|
108 |
confidence = 1/(1+D[0][i])
|
109 |
-
results.append(
|
|
|
|
|
|
|
110 |
except Exception as e:
|
111 |
print(f"结果处理错误: {str(e)}")
|
112 |
-
results.append(f"结果 {i+1}: 数据获取失败")
|
113 |
|
114 |
print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
|
115 |
|
116 |
-
#
|
117 |
-
return
|
118 |
-
"
|
119 |
-
|
120 |
-
|
121 |
-
})
|
122 |
-
]
|
123 |
-
})
|
124 |
|
125 |
except Exception as e:
|
126 |
import traceback
|
127 |
error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
|
128 |
print(error_msg)
|
129 |
-
return
|
130 |
-
"
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
-
#
|
134 |
with gr.Blocks() as demo:
|
135 |
gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
|
136 |
gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
|
@@ -142,7 +185,7 @@ with gr.Blocks() as demo:
|
|
142 |
label="输入向量 (768维)",
|
143 |
value=[[0.1]*768]
|
144 |
)
|
145 |
-
output = gr.
|
146 |
|
147 |
submit_btn = gr.Button("生成回答")
|
148 |
submit_btn.click(
|
@@ -153,6 +196,21 @@ with gr.Blocks() as demo:
|
|
153 |
|
154 |
# 启动应用
|
155 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
# 验证资源
|
157 |
print("="*50)
|
158 |
print("Space启动完成 | 准备接收请求")
|
@@ -160,8 +218,6 @@ if __name__ == "__main__":
|
|
160 |
print(f"元数据记录数: {len(metadata)}")
|
161 |
print("="*50)
|
162 |
|
163 |
-
#
|
164 |
-
|
165 |
-
|
166 |
-
server_port=7860
|
167 |
-
)
|
|
|
80 |
print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
|
81 |
|
82 |
try:
|
83 |
+
# 处理不同输入格式
|
84 |
if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
|
85 |
# 已经是正确的二维数组格式
|
86 |
pass
|
|
|
90 |
else:
|
91 |
error_msg = "错误:输入格式无效"
|
92 |
print(error_msg)
|
93 |
+
return {"error": error_msg}
|
94 |
|
95 |
if len(vector) != 1 or len(vector[0]) != 768:
|
96 |
error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
|
97 |
print(error_msg)
|
98 |
+
return {"error": error_msg}
|
99 |
|
100 |
# 添加实际处理逻辑
|
101 |
vector_array = np.array(vector, dtype=np.float32)
|
|
|
106 |
try:
|
107 |
result = metadata.iloc[I[0][i]]
|
108 |
confidence = 1/(1+D[0][i])
|
109 |
+
results.append({
|
110 |
+
"source": result['source'],
|
111 |
+
"confidence": round(float(confidence), 2)
|
112 |
+
})
|
113 |
except Exception as e:
|
114 |
print(f"结果处理错误: {str(e)}")
|
115 |
+
results.append({"error": f"结果 {i+1}: 数据获取失败"})
|
116 |
|
117 |
print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
|
118 |
|
119 |
+
# 返回标准化的JSON格式
|
120 |
+
return {
|
121 |
+
"status": "success",
|
122 |
+
"results": results
|
123 |
+
}
|
|
|
|
|
|
|
124 |
|
125 |
except Exception as e:
|
126 |
import traceback
|
127 |
error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
|
128 |
print(error_msg)
|
129 |
+
return {
|
130 |
+
"status": "error",
|
131 |
+
"message": "处理错误,请重试或联系管理员"
|
132 |
+
}
|
133 |
+
|
134 |
+
# 创建FastAPI应用 (替代Gradio Blocks)
|
135 |
+
import fastapi
|
136 |
+
from fastapi import FastAPI, Request
|
137 |
+
from fastapi.responses import JSONResponse
|
138 |
+
from fastapi.middleware.cors import CORSMiddleware
|
139 |
+
|
140 |
+
app = FastAPI()
|
141 |
+
|
142 |
+
# 添加CORS支持
|
143 |
+
# CORS: every origin is allowed, so credentialed requests must stay disabled.
# Browsers reject "Access-Control-Allow-Origin: *" on credentialed responses,
# and FastAPI documents that wildcard origins/methods/headers cannot be
# combined with allow_credentials=True.
# NOTE(review): no cookie/auth usage is visible in this file; if credentials
# are ever needed, replace the wildcard with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
150 |
+
|
151 |
+
@app.post("/predict")
async def api_predict(request: Request):
    """Handle ``POST /predict``: validate the JSON body and run ``predict``.

    The body must be a JSON object whose ``vector`` key holds a non-empty
    list; that list is passed unchanged to ``predict``.

    Returns:
        A ``JSONResponse`` whose content is the dict returned by ``predict``
        (or an error dict built here). Status codes:

        * 200 -- ``predict`` succeeded.
        * 400 -- the body is not valid JSON, is not an object, lacks a usable
          ``vector``, or ``predict`` reported a validation error (a top-level
          ``"error"`` key).
        * 500 -- ``predict`` reported ``"status": "error"`` or an unexpected
          exception occurred.
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError / UnicodeDecodeError: the client sent a
            # malformed body, which is a 400 rather than a server fault.
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 请求体必须是有效的JSON"},
            )

        # A JSON array/number/string has no ``.get``; treat it as missing data.
        vector = data.get("vector") if isinstance(data, dict) else None

        if not vector or not isinstance(vector, list):
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "无效输入: 需要向量数据"},
            )

        result = predict(vector)

        # Mirror predict's outcome in the HTTP status so clients do not have
        # to inspect the body to detect failures. The body itself is unchanged.
        status_code = 200
        if isinstance(result, dict):
            if result.get("status") == "error":
                status_code = 500
            elif "error" in result:
                status_code = 400
        return JSONResponse(status_code=status_code, content=result)

    except Exception:
        # Log the details server-side only; echoing the exception text to the
        # client would leak internals (predict already hides them the same way).
        import traceback

        print(f"服务器内部错误: {traceback.format_exc()}")
        return JSONResponse(
            status_code=500,
            content={
                "status": "error",
                "message": "服务器内部错误",
            },
        )
|
175 |
|
176 |
+
# 保留Gradio界面
|
177 |
with gr.Blocks() as demo:
|
178 |
gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
|
179 |
gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
|
|
|
185 |
label="输入向量 (768维)",
|
186 |
value=[[0.1]*768]
|
187 |
)
|
188 |
+
output = gr.JSON(label="智能回答")
|
189 |
|
190 |
submit_btn = gr.Button("生成回答")
|
191 |
submit_btn.click(
|
|
|
196 |
|
197 |
# 启动应用
|
198 |
if __name__ == "__main__":
|
199 |
+
import uvicorn
|
200 |
+
import threading
|
201 |
+
|
202 |
+
# 启动FastAPI服务
|
203 |
+
def run_fastapi():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:7860.

    Intended as a thread target: the caller starts it in a daemon thread.
    """
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
205 |
+
|
206 |
+
# 启动Gradio界面
|
207 |
+
def run_gradio():
    """Launch the Gradio ``demo`` on 0.0.0.0:7861 without a public share link."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7861,
        "share": False,
    }
    demo.launch(**launch_options)
|
213 |
+
|
214 |
# 验证资源
|
215 |
print("="*50)
|
216 |
print("Space启动完成 | 准备接收请求")
|
|
|
218 |
print(f"元数据记录数: {len(metadata)}")
|
219 |
print("="*50)
|
220 |
|
221 |
+
# 启动两个服务
|
222 |
+
threading.Thread(target=run_fastapi, daemon=True).start()
|
223 |
+
run_gradio()
|
|
|
|