Update app.py
Browse files
app.py
CHANGED
@@ -25,6 +25,59 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
25 |
index = None
|
26 |
metadata = None
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def load_resources():
|
29 |
"""加载所有必要资源(768维专用)"""
|
30 |
global index, metadata
|
@@ -77,27 +130,45 @@ def load_resources():
|
|
77 |
print(f"❌ 元数据加载失败: {str(e)}")
|
78 |
raise
|
79 |
|
|
|
|
|
|
|
80 |
# 确保资源在API调用前加载
|
81 |
load_resources()
|
82 |
|
83 |
def predict(vector):
|
84 |
try:
|
|
|
|
|
85 |
# 确保向量格式正确
|
86 |
query_vector = np.array(vector).astype('float32').reshape(1, -1)
|
87 |
|
88 |
-
#
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
# 构建结果
|
92 |
results = []
|
93 |
-
for i in range(
|
94 |
try:
|
95 |
idx = I[0][i]
|
96 |
result = {
|
97 |
"source": metadata.iloc[idx]["source"],
|
98 |
-
# 安全访问content字段
|
99 |
"content": metadata.iloc[idx].get("content", ""),
|
100 |
-
"confidence": float(1/(1+D[0][i]))
|
|
|
101 |
}
|
102 |
results.append(result)
|
103 |
except Exception as e:
|
@@ -165,6 +236,6 @@ if __name__ == "__main__":
|
|
165 |
print(f"索引维度: {index.d}")
|
166 |
print(f"元数据记录数: {len(metadata)}")
|
167 |
print("="*50)
|
168 |
-
|
169 |
# 只启动FastAPI服务
|
170 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
25 |
index = None
|
26 |
metadata = None
|
27 |
|
28 |
+
# 新增全局变量
|
29 |
+
last_updated = 0
|
30 |
+
index_refresh_interval = 300 # 5分钟刷新一次
|
31 |
+
|
32 |
+
# 新增索引刷新函数
|
33 |
+
def refresh_index():
|
34 |
+
global index, metadata, last_updated
|
35 |
+
|
36 |
+
while True:
|
37 |
+
try:
|
38 |
+
# 检查是否有更新
|
39 |
+
current_time = time.time()
|
40 |
+
if current_time - last_updated > index_refresh_interval:
|
41 |
+
print("🔄 检查索引更新...")
|
42 |
+
|
43 |
+
# 获取最新元数据
|
44 |
+
METADATA_PATH = hf_hub_download(
|
45 |
+
repo_id="GOGO198/GOGO_rag_index",
|
46 |
+
filename="metadata.csv",
|
47 |
+
cache_dir=CACHE_DIR,
|
48 |
+
token=os.getenv("HF_TOKEN"),
|
49 |
+
force_download=True # 强制更新
|
50 |
+
)
|
51 |
+
|
52 |
+
# 检查文件修改时间
|
53 |
+
file_mtime = os.path.getmtime(METADATA_PATH)
|
54 |
+
if file_mtime > last_updated:
|
55 |
+
print("📥 检测到新索引,重新加载...")
|
56 |
+
|
57 |
+
# 重新加载索引
|
58 |
+
INDEX_PATH = hf_hub_download(
|
59 |
+
repo_id="GOGO198/GOGO_rag_index",
|
60 |
+
filename="faiss_index.bin",
|
61 |
+
cache_dir=CACHE_DIR,
|
62 |
+
token=os.getenv("HF_TOKEN"),
|
63 |
+
force_download=True
|
64 |
+
)
|
65 |
+
new_index = faiss.read_index(INDEX_PATH)
|
66 |
+
new_metadata = pd.read_csv(METADATA_PATH)
|
67 |
+
|
68 |
+
# 原子替换
|
69 |
+
index = new_index
|
70 |
+
metadata = new_metadata
|
71 |
+
last_updated = file_mtime
|
72 |
+
|
73 |
+
print(f"✅ 索引更新完成 | 记录数: {len(metadata)}")
|
74 |
+
|
75 |
+
except Exception as e:
|
76 |
+
print(f"索引更新失败: {str(e)}")
|
77 |
+
|
78 |
+
# 每30秒检查一次
|
79 |
+
time.sleep(30)
|
80 |
+
|
81 |
def load_resources():
|
82 |
"""加载所有必要资源(768维专用)"""
|
83 |
global index, metadata
|
|
|
130 |
print(f"❌ 元数据加载失败: {str(e)}")
|
131 |
raise
|
132 |
|
133 |
+
# 启动索引刷新线程
|
134 |
+
threading.Thread(target=refresh_index, daemon=True).start()
|
135 |
+
|
136 |
# 确保资源在API调用前加载
|
137 |
load_resources()
|
138 |
|
139 |
def predict(vector):
|
140 |
try:
|
141 |
+
print(f"接收向量: {vector[:3]}... (长度: {len(vector)})")
|
142 |
+
|
143 |
# 确保向量格式正确
|
144 |
query_vector = np.array(vector).astype('float32').reshape(1, -1)
|
145 |
|
146 |
+
# 动态结果数量 (最大不超过总文档数)
|
147 |
+
k = min(3, index.ntotal)
|
148 |
+
if k == 0:
|
149 |
+
return {
|
150 |
+
"status": "success",
|
151 |
+
"results": [],
|
152 |
+
"message": "索引为空"
|
153 |
+
}
|
154 |
+
|
155 |
+
print(f"执行FAISS搜索 | k={k}")
|
156 |
+
D, I = index.search(query_vector, k=k)
|
157 |
+
|
158 |
+
# 打印搜索结果
|
159 |
+
print(f"搜索结果索引: {I[0]}")
|
160 |
+
print(f"搜索距离: {D[0]}")
|
161 |
|
162 |
+
# 构建结果
|
163 |
results = []
|
164 |
+
for i in range(k):
|
165 |
try:
|
166 |
idx = I[0][i]
|
167 |
result = {
|
168 |
"source": metadata.iloc[idx]["source"],
|
|
|
169 |
"content": metadata.iloc[idx].get("content", ""),
|
170 |
+
"confidence": float(1/(1+D[0][i])),
|
171 |
+
"distance": float(D[0][i])
|
172 |
}
|
173 |
results.append(result)
|
174 |
except Exception as e:
|
|
|
236 |
print(f"索引维度: {index.d}")
|
237 |
print(f"元数据记录数: {len(metadata)}")
|
238 |
print("="*50)
|
239 |
+
|
240 |
# 只启动FastAPI服务
|
241 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|