GOGO198 commited on
Commit
758f12d
·
verified ·
1 Parent(s): 7057632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -50
app.py CHANGED
@@ -81,68 +81,45 @@ def load_resources():
81
  load_resources()
82
 
83
  def predict(vector):
84
- """处理768维向量输入并返回答案"""
85
- start_time = time.time()
86
- print(f"输入向量类型: {type(vector)}, 长度: {len(vector) if hasattr(vector, '__len__') else 'N/A'}")
87
-
88
  try:
89
- # 处理不同输入格式
90
- if isinstance(vector, list) and len(vector) > 0 and isinstance(vector[0], list):
91
- # 已经是正确的二维数组格式
92
- pass
93
- elif isinstance(vector, list) and all(isinstance(x, float) for x in vector):
94
- # 一维列表输入 (API调用)
95
- #vector = [vector]
96
- vector = np.array(vector).reshape(1, -1).astype('float32')
97
- else:
98
- error_msg = "错误:输入格式无效"
99
- print(error_msg)
100
- return {"error": error_msg}
101
-
102
- if len(vector) != 1 or len(vector[0]) != 768:
103
- error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
104
- print(error_msg)
105
- return {"error": error_msg}
106
-
107
- # 添加实际处理逻辑
108
- #vector_array = np.array(vector, dtype=np.float32)
109
- # 转换格式
110
  query_vector = np.array(vector).astype('float32').reshape(1, -1)
 
 
111
  D, I = index.search(query_vector, k=3)
112
-
 
113
  results = []
114
  for i in range(3):
115
  try:
116
- result = metadata.iloc[I[0][i]]
117
- confidence = 1/(1+D[0][i])
118
- results.append({
119
- "source": result['source'],
120
- "confidence": round(float(confidence), 2)
121
- })
 
 
122
  except Exception as e:
123
  print(f"结果处理错误: {str(e)}")
124
- results.append({"error": f"结果 {i+1}: 数据获取失败"})
125
-
126
- print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
 
127
 
128
- # 返回标准化的JSON格式
129
  return {
130
- # "status": "success",
131
- # "results": results
132
- "source": [metadata.iloc[i]["source"] for i in I[0]],
133
- "content": [metadata.iloc[i]["content"] for i in I[0]],
134
- "confidence": [float(1/(1+d)) for d in D[0]]
135
  }
136
-
137
  except Exception as e:
138
- import traceback
139
- error_msg = f"处理错误: {str(e)}\n{traceback.format_exc()}"
140
- print(error_msg)
141
- return {
142
- "status": "error",
143
- "message": "处理错误,请重试或联系管理员"
144
- }
145
-
146
  # 创建FastAPI应用
147
  app = FastAPI()
148
 
 
81
  load_resources()
82
 
83
  def predict(vector):
 
 
 
 
84
  try:
85
+ # 确保向量格式正确
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  query_vector = np.array(vector).astype('float32').reshape(1, -1)
87
+
88
+ # FAISS搜索
89
  D, I = index.search(query_vector, k=3)
90
+
91
+ # 构建结果 - 使用安全的列名访问
92
  results = []
93
  for i in range(3):
94
  try:
95
+ idx = I[0][i]
96
+ result = {
97
+ "source": metadata.iloc[idx]["source"],
98
+ # 安全访问content字段
99
+ "content": metadata.iloc[idx].get("content", ""),
100
+ "confidence": float(1/(1+D[0][i]))
101
+ }
102
+ results.append(result)
103
  except Exception as e:
104
  print(f"结果处理错误: {str(e)}")
105
+ results.append({
106
+ "error": f"结果 {i+1}: 数据获取失败",
107
+ "confidence": 0.0
108
+ })
109
 
 
110
  return {
111
+ "status": "success",
112
+ "results": results
 
 
 
113
  }
 
114
  except Exception as e:
115
+ return JSONResponse(
116
+ status_code=500,
117
+ content={
118
+ "status": "error",
119
+ "message": f"服务器内部错误: {str(e)}"
120
+ }
121
+ )
122
+
123
  # 创建FastAPI应用
124
  app = FastAPI()
125