Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
|
|
12 |
from fastapi.middleware.cors import CORSMiddleware
|
13 |
import uvicorn
|
14 |
import threading
|
|
|
15 |
|
16 |
# 创建安全缓存目录
|
17 |
CACHE_DIR = "/home/user/cache"
|
@@ -136,6 +137,24 @@ threading.Thread(target=refresh_index, daemon=True).start()
|
|
136 |
# 确保资源在API调用前加载
|
137 |
load_resources()
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
def predict(vector):
|
140 |
try:
|
141 |
print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")
|
@@ -164,56 +183,45 @@ def predict(vector):
|
|
164 |
for i in range(k):
|
165 |
try:
|
166 |
idx = I[0][i]
|
167 |
-
|
168 |
-
# # 安全处理距离值
|
169 |
-
# distance = D[0][i]
|
170 |
-
# if not np.isfinite(distance) or distance < 0:
|
171 |
-
# print(f"警告: 无效距离值 {distance},使用默认值0")
|
172 |
-
# distance = 0.0
|
173 |
-
|
174 |
-
# # 安全处理置信度值
|
175 |
-
# try:
|
176 |
-
# confidence = 1 / (1 + distance)
|
177 |
-
# except ZeroDivisionError:
|
178 |
-
# confidence = 1.0
|
179 |
-
|
180 |
-
# if not np.isfinite(confidence) or confidence < 0:
|
181 |
-
# print(f"警告: 无效置信度 {confidence},使用默认值0.5")
|
182 |
-
# confidence = 0.5
|
183 |
-
|
184 |
distance = D[0][i]
|
185 |
|
186 |
-
#
|
187 |
-
if not np.isfinite(distance):
|
188 |
-
distance = 100.0 #
|
189 |
-
|
190 |
-
#
|
191 |
-
confidence =
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
193 |
result = {
|
194 |
"source": metadata.iloc[idx]["source"],
|
195 |
"content": metadata.iloc[idx].get("content", ""),
|
196 |
-
"confidence":
|
197 |
-
"distance":
|
198 |
}
|
199 |
results.append(result)
|
200 |
except Exception as e:
|
201 |
-
|
202 |
results.append({
|
203 |
-
"error":
|
204 |
"confidence": 0.5,
|
205 |
"distance": 0.0
|
206 |
})
|
207 |
|
208 |
return {
|
209 |
"status": "success",
|
210 |
-
"results": results
|
211 |
}
|
212 |
except Exception as e:
|
213 |
# 返回错误响应
|
214 |
return {
|
215 |
"status": "error",
|
216 |
-
"message": f"服务器内部错误: {str(e)}"
|
|
|
217 |
}
|
218 |
|
219 |
# 创建FastAPI应用
|
|
|
12 |
from fastapi.middleware.cors import CORSMiddleware
|
13 |
import uvicorn
|
14 |
import threading
|
15 |
+
import math
|
16 |
|
17 |
# 创建安全缓存目录
|
18 |
CACHE_DIR = "/home/user/cache"
|
|
|
137 |
# 确保资源在API调用前加载
|
138 |
load_resources()
|
139 |
|
140 |
+
def sanitize_floats(obj):
|
141 |
+
if isinstance(obj, float):
|
142 |
+
if math.isnan(obj) or math.isinf(obj):
|
143 |
+
return 0.0 # 替换非法值为默认值
|
144 |
+
return obj
|
145 |
+
elif isinstance(obj, dict):
|
146 |
+
return {k: sanitize_floats(v) for k, v in obj.items()}
|
147 |
+
elif isinstance(obj, list):
|
148 |
+
return [sanitize_floats(x) for x in obj]
|
149 |
+
else:
|
150 |
+
return obj
|
151 |
+
|
152 |
+
# 在返回结果前调用清理器
|
153 |
+
return {
|
154 |
+
"status": "success",
|
155 |
+
"results": sanitize_floats(results) # 深度清理
|
156 |
+
}
|
157 |
+
|
158 |
def predict(vector):
|
159 |
try:
|
160 |
print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")
|
|
|
183 |
for i in range(k):
|
184 |
try:
|
185 |
idx = I[0][i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
distance = D[0][i]
|
187 |
|
188 |
+
# 修复1:处理非法浮点数
|
189 |
+
if not np.isfinite(distance) or distance < 0:
|
190 |
+
distance = 100.0 # 设置为安全阈值
|
191 |
+
|
192 |
+
# 修复2:安全计算置信度 (0-1范围)
|
193 |
+
confidence = 1 / (1 + distance)
|
194 |
+
confidence = max(0.0, min(1.0, confidence)) # 钳制到[0,1]
|
195 |
+
|
196 |
+
# 修复3:强制转换为合法浮点
|
197 |
+
distance = float(distance)
|
198 |
+
confidence = float(confidence)
|
199 |
+
|
200 |
result = {
|
201 |
"source": metadata.iloc[idx]["source"],
|
202 |
"content": metadata.iloc[idx].get("content", ""),
|
203 |
+
"confidence": confidence,
|
204 |
+
"distance": distance
|
205 |
}
|
206 |
results.append(result)
|
207 |
except Exception as e:
|
208 |
+
# 确保异常结果也符合JSON规范
|
209 |
results.append({
|
210 |
+
"error": str(e),
|
211 |
"confidence": 0.5,
|
212 |
"distance": 0.0
|
213 |
})
|
214 |
|
215 |
return {
|
216 |
"status": "success",
|
217 |
+
"results": sanitize_floats(results)
|
218 |
}
|
219 |
except Exception as e:
|
220 |
# 返回错误响应
|
221 |
return {
|
222 |
"status": "error",
|
223 |
+
"message": f"服务器内部错误: {str(e)}",
|
224 |
+
"details": sanitize_floats({"trace": traceback.format_exc()})
|
225 |
}
|
226 |
|
227 |
# 创建FastAPI应用
|