Update app.py
Browse files
app.py
CHANGED
@@ -81,68 +81,45 @@ def load_resources():
|
|
81 |
load_resources()
|
82 |
|
83 |
def predict(vector):
|
84 |
-
"""处理768维向量输入并返回答案"""
|
85 |
-
start_time = time.time()
|
86 |
-
print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
|
87 |
-
|
88 |
try:
|
89 |
-
#
|
90 |
-
if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
|
91 |
-
# 已经是正确的二维数组格式
|
92 |
-
pass
|
93 |
-
elif isinstance(vector, list) and all(isinstance(x, float) for x in vector):
|
94 |
-
# 一维列表输入 (API调用)
|
95 |
-
#vector = [vector]
|
96 |
-
vector = np.array(vector).reshape(1, -1).astype('float32')
|
97 |
-
else:
|
98 |
-
error_msg = "错误:输入格式无效"
|
99 |
-
print(error_msg)
|
100 |
-
return {"error": error_msg}
|
101 |
-
|
102 |
-
if len(vector) != 1 or len(vector[0]) != 768:
|
103 |
-
error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
|
104 |
-
print(error_msg)
|
105 |
-
return {"error": error_msg}
|
106 |
-
|
107 |
-
# 添加实际处理逻辑
|
108 |
-
#vector_array = np.array(vector, dtype=np.float32)
|
109 |
-
# 转换格式
|
110 |
query_vector = np.array(vector).astype('float32').reshape(1, -1)
|
|
|
|
|
111 |
D, I = index.search(query_vector, k=3)
|
112 |
-
|
|
|
113 |
results = []
|
114 |
for i in range(3):
|
115 |
try:
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
"
|
121 |
-
|
|
|
|
|
122 |
except Exception as e:
|
123 |
print(f"结果处理错误: {str(e)}")
|
124 |
-
results.append({
|
125 |
-
|
126 |
-
|
|
|
127 |
|
128 |
-
# 返回标准化的JSON格式
|
129 |
return {
|
130 |
-
|
131 |
-
|
132 |
-
"source": [metadata.iloc[i]["source"] for i in I[0]],
|
133 |
-
"content": [metadata.iloc[i]["content"] for i in I[0]],
|
134 |
-
"confidence": [float(1/(1+d)) for d in D[0]]
|
135 |
}
|
136 |
-
|
137 |
except Exception as e:
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
# 创建FastAPI应用
|
147 |
app = FastAPI()
|
148 |
|
|
|
81 |
load_resources()
|
82 |
|
83 |
def predict(vector):
|
|
|
|
|
|
|
|
|
84 |
try:
|
85 |
+
# 确保向量格式正确
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
query_vector = np.array(vector).astype('float32').reshape(1, -1)
|
87 |
+
|
88 |
+
# FAISS搜索
|
89 |
D, I = index.search(query_vector, k=3)
|
90 |
+
|
91 |
+
# 构建结果 - 使用安全的列名访问
|
92 |
results = []
|
93 |
for i in range(3):
|
94 |
try:
|
95 |
+
idx = I[0][i]
|
96 |
+
result = {
|
97 |
+
"source": metadata.iloc[idx]["source"],
|
98 |
+
# 安全访问content字段
|
99 |
+
"content": metadata.iloc[idx].get("content", ""),
|
100 |
+
"confidence": float(1/(1+D[0][i]))
|
101 |
+
}
|
102 |
+
results.append(result)
|
103 |
except Exception as e:
|
104 |
print(f"结果处理错误: {str(e)}")
|
105 |
+
results.append({
|
106 |
+
"error": f"结果 {i+1}: 数据获取失败",
|
107 |
+
"confidence": 0.0
|
108 |
+
})
|
109 |
|
|
|
110 |
return {
|
111 |
+
"status": "success",
|
112 |
+
"results": results
|
|
|
|
|
|
|
113 |
}
|
|
|
114 |
except Exception as e:
|
115 |
+
return JSONResponse(
|
116 |
+
status_code=500,
|
117 |
+
content={
|
118 |
+
"status": "error",
|
119 |
+
"message": f"服务器内部错误: {str(e)}"
|
120 |
+
}
|
121 |
+
)
|
122 |
+
|
123 |
# 创建FastAPI应用
|
124 |
app = FastAPI()
|
125 |
|