Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
|
|
6 |
import faiss
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
import time
|
9 |
-
import sys
|
10 |
import json
|
11 |
|
12 |
# 创建安全缓存目录
|
@@ -15,15 +14,16 @@ os.makedirs(CACHE_DIR, exist_ok=True)
|
|
15 |
|
16 |
# 减少内存占用
|
17 |
os.environ["OMP_NUM_THREADS"] = "2"
|
18 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
19 |
|
20 |
# 全局变量
|
21 |
index = None
|
22 |
metadata = None
|
23 |
|
24 |
-
# 加载资源函数(保持不变)
|
25 |
def load_resources():
|
26 |
"""加载所有必要资源(768维专用)"""
|
|
|
|
|
27 |
# 清理残留锁文件
|
28 |
lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
|
29 |
for lock_file in lock_files:
|
@@ -31,13 +31,10 @@ def load_resources():
|
|
31 |
os.remove(os.path.join(CACHE_DIR, lock_file))
|
32 |
print(f"🧹 清理锁文件: {lock_file}")
|
33 |
except: pass
|
34 |
-
|
35 |
-
global index, metadata
|
36 |
-
|
37 |
-
# 仅当资源未加载时才初始化
|
38 |
if index is None or metadata is None:
|
39 |
print("🔄 正在加载所有资源...")
|
40 |
-
|
41 |
# 加载FAISS索引(768维)
|
42 |
if index is None:
|
43 |
print("📥 正在下载FAISS索引...")
|
@@ -50,20 +47,14 @@ def load_resources():
|
|
50 |
)
|
51 |
index = faiss.read_index(INDEX_PATH)
|
52 |
|
53 |
-
# 验证索引维度
|
54 |
if index.d != 768:
|
55 |
-
raise ValueError(
|
56 |
-
|
57 |
-
# if index and not index.is_trained:
|
58 |
-
# print("🔧 训练量化索引...")
|
59 |
-
# index.train(np.random.rand(10000, 768).astype('float32'))
|
60 |
-
# print("✅ 索引训练完成")
|
61 |
-
|
62 |
print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
|
63 |
except Exception as e:
|
64 |
print(f"❌ FAISS索引加载失败: {str(e)}")
|
65 |
raise
|
66 |
-
|
67 |
# 加载元数据
|
68 |
if metadata is None:
|
69 |
print("📄 正在下载元数据...")
|
@@ -86,9 +77,8 @@ def predict(vector):
|
|
86 |
"""处理768维向量输入并返回答案"""
|
87 |
start_time = time.time()
|
88 |
print(f"输入向量维度: {np.array(vector).shape}")
|
89 |
-
|
90 |
try:
|
91 |
-
# 验证输入格式
|
92 |
if not isinstance(vector, list) or len(vector) == 0:
|
93 |
error_msg = "错误:输入格式无效"
|
94 |
print(error_msg)
|
@@ -98,11 +88,10 @@ def predict(vector):
|
|
98 |
error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
|
99 |
print(error_msg)
|
100 |
return error_msg
|
101 |
-
|
102 |
-
# 添加实际处理逻辑
|
103 |
vector_array = np.array(vector, dtype=np.float32)
|
104 |
D, I = index.search(vector_array, k=3)
|
105 |
-
|
106 |
results = []
|
107 |
for i in range(3):
|
108 |
try:
|
@@ -112,10 +101,10 @@ def predict(vector):
|
|
112 |
except Exception as e:
|
113 |
print(f"结果处理错误: {str(e)}")
|
114 |
results.append(f"结果 {i+1}: 数据获取失败")
|
115 |
-
|
116 |
print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
|
117 |
return json.dumps({
|
118 |
-
"results": results
|
119 |
})
|
120 |
|
121 |
except Exception as e:
|
@@ -124,7 +113,7 @@ def predict(vector):
|
|
124 |
print(error_msg)
|
125 |
return "处理错误,请重试或联系管理员"
|
126 |
|
127 |
-
#
|
128 |
with gr.Blocks() as demo:
|
129 |
gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
|
130 |
gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
|
@@ -134,7 +123,7 @@ with gr.Blocks() as demo:
|
|
134 |
headers=["向量值"],
|
135 |
type="array",
|
136 |
label="输入向量 (768维)",
|
137 |
-
value=[[0.1]*768]
|
138 |
)
|
139 |
output = gr.Textbox(label="智能回答", lines=5)
|
140 |
|
@@ -145,36 +134,29 @@ with gr.Blocks() as demo:
|
|
145 |
outputs=output
|
146 |
)
|
147 |
|
148 |
-
#
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
value=[[0.1]*768]
|
158 |
-
),
|
159 |
-
outputs=gr.Textbox(), # 与 output 类型一致
|
160 |
-
api_name="predict" # 显式声明 API 名称
|
161 |
-
)
|
162 |
-
|
163 |
-
# 启动应用
|
164 |
if __name__ == "__main__":
|
165 |
if index is None or metadata is None:
|
166 |
load_resources()
|
167 |
|
168 |
-
# 验证
|
169 |
print("="*50)
|
170 |
print("Space启动完成 | 准备接收请求")
|
171 |
print(f"索引维度: {index.d if index else '未加载'}")
|
172 |
print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
|
173 |
print("="*50)
|
174 |
|
175 |
-
#
|
176 |
-
|
177 |
server_name="0.0.0.0",
|
178 |
server_port=7860,
|
179 |
ssr_mode=False
|
180 |
-
)
|
|
|
6 |
import faiss
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
import time
|
|
|
9 |
import json
|
10 |
|
11 |
# 创建安全缓存目录
|
|
|
14 |
|
15 |
# 减少内存占用
|
16 |
os.environ["OMP_NUM_THREADS"] = "2"
|
17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
18 |
|
19 |
# 全局变量
|
20 |
index = None
|
21 |
metadata = None
|
22 |
|
|
|
23 |
def load_resources():
|
24 |
"""加载所有必要资源(768维专用)"""
|
25 |
+
global index, metadata
|
26 |
+
|
27 |
# 清理残留锁文件
|
28 |
lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
|
29 |
for lock_file in lock_files:
|
|
|
31 |
os.remove(os.path.join(CACHE_DIR, lock_file))
|
32 |
print(f"🧹 清理锁文件: {lock_file}")
|
33 |
except: pass
|
34 |
+
|
|
|
|
|
|
|
35 |
if index is None or metadata is None:
|
36 |
print("🔄 正在加载所有资源...")
|
37 |
+
|
38 |
# 加载FAISS索引(768维)
|
39 |
if index is None:
|
40 |
print("📥 正在下载FAISS索引...")
|
|
|
47 |
)
|
48 |
index = faiss.read_index(INDEX_PATH)
|
49 |
|
|
|
50 |
if index.d != 768:
|
51 |
+
raise ValueError("❌ 索引维度错误:预期768维")
|
52 |
+
|
|
|
|
|
|
|
|
|
|
|
53 |
print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
|
54 |
except Exception as e:
|
55 |
print(f"❌ FAISS索引加载失败: {str(e)}")
|
56 |
raise
|
57 |
+
|
58 |
# 加载元数据
|
59 |
if metadata is None:
|
60 |
print("📄 正在下载元数据...")
|
|
|
77 |
"""处理768维向量输入并返回答案"""
|
78 |
start_time = time.time()
|
79 |
print(f"输入向量维度: {np.array(vector).shape}")
|
80 |
+
|
81 |
try:
|
|
|
82 |
if not isinstance(vector, list) or len(vector) == 0:
|
83 |
error_msg = "错误:输入格式无效"
|
84 |
print(error_msg)
|
|
|
88 |
error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
|
89 |
print(error_msg)
|
90 |
return error_msg
|
91 |
+
|
|
|
92 |
vector_array = np.array(vector, dtype=np.float32)
|
93 |
D, I = index.search(vector_array, k=3)
|
94 |
+
|
95 |
results = []
|
96 |
for i in range(3):
|
97 |
try:
|
|
|
101 |
except Exception as e:
|
102 |
print(f"结果处理错误: {str(e)}")
|
103 |
results.append(f"结果 {i+1}: 数据获取失败")
|
104 |
+
|
105 |
print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
|
106 |
return json.dumps({
|
107 |
+
"results": results
|
108 |
})
|
109 |
|
110 |
except Exception as e:
|
|
|
113 |
print(error_msg)
|
114 |
return "处理错误,请重试或联系管理员"
|
115 |
|
116 |
+
# 创建Blocks应用
|
117 |
with gr.Blocks() as demo:
|
118 |
gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
|
119 |
gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
|
|
|
123 |
headers=["向量值"],
|
124 |
type="array",
|
125 |
label="输入向量 (768维)",
|
126 |
+
value=[[0.1]*768]
|
127 |
)
|
128 |
output = gr.Textbox(label="智能回答", lines=5)
|
129 |
|
|
|
134 |
outputs=output
|
135 |
)
|
136 |
|
137 |
+
# 暴露为API
|
138 |
+
demo.expose_api(
|
139 |
+
fn=predict,
|
140 |
+
input=vector_input,
|
141 |
+
output=output,
|
142 |
+
api_name="predict"
|
143 |
+
)
|
144 |
+
|
145 |
+
# 在Blocks内部加载资源
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
if __name__ == "__main__":
|
147 |
if index is None or metadata is None:
|
148 |
load_resources()
|
149 |
|
150 |
+
# 验证API
|
151 |
print("="*50)
|
152 |
print("Space启动完成 | 准备接收请求")
|
153 |
print(f"索引维度: {index.d if index else '未加载'}")
|
154 |
print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
|
155 |
print("="*50)
|
156 |
|
157 |
+
# 启动应用
|
158 |
+
demo.launch(
|
159 |
server_name="0.0.0.0",
|
160 |
server_port=7860,
|
161 |
ssr_mode=False
|
162 |
+
)
|