GOGO198 commited on
Commit
e931572
·
verified ·
1 Parent(s): ebc728e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -46
app.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
  import faiss
7
  from huggingface_hub import hf_hub_download
8
  import time
9
- import sys
10
  import json
11
 
12
  # 创建安全缓存目录
@@ -15,15 +14,16 @@ os.makedirs(CACHE_DIR, exist_ok=True)
15
 
16
  # 减少内存占用
17
  os.environ["OMP_NUM_THREADS"] = "2"
18
- os.environ["TOKENIZERS_PARALLELISM"] = "false" # 防止tokenizer内存泄漏
19
 
20
  # 全局变量
21
  index = None
22
  metadata = None
23
 
24
- # 加载资源函数(保持不变)
25
  def load_resources():
26
  """加载所有必要资源(768维专用)"""
 
 
27
  # 清理残留锁文件
28
  lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
29
  for lock_file in lock_files:
@@ -31,13 +31,10 @@ def load_resources():
31
  os.remove(os.path.join(CACHE_DIR, lock_file))
32
  print(f"🧹 清理锁文件: {lock_file}")
33
  except: pass
34
-
35
- global index, metadata
36
-
37
- # 仅当资源未加载时才初始化
38
  if index is None or metadata is None:
39
  print("🔄 正在加载所有资源...")
40
-
41
  # 加载FAISS索引(768维)
42
  if index is None:
43
  print("📥 正在下载FAISS索引...")
@@ -50,20 +47,14 @@ def load_resources():
50
  )
51
  index = faiss.read_index(INDEX_PATH)
52
 
53
- # 验证索引维度
54
  if index.d != 768:
55
- raise ValueError(f"❌ 索引维度错误:预期768维,实际{index.d}维")
56
-
57
- # if index and not index.is_trained:
58
- # print("🔧 训练量化索引...")
59
- # index.train(np.random.rand(10000, 768).astype('float32'))
60
- # print("✅ 索引训练完成")
61
-
62
  print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
63
  except Exception as e:
64
  print(f"❌ FAISS索引加载失败: {str(e)}")
65
  raise
66
-
67
  # 加载元数据
68
  if metadata is None:
69
  print("📄 正在下载元数据...")
@@ -86,9 +77,8 @@ def predict(vector):
86
  """处理768维向量输入并返回答案"""
87
  start_time = time.time()
88
  print(f"输入向量维度: {np.array(vector).shape}")
89
-
90
  try:
91
- # 验证输入格式
92
  if not isinstance(vector, list) or len(vector) == 0:
93
  error_msg = "错误:输入格式无效"
94
  print(error_msg)
@@ -98,11 +88,10 @@ def predict(vector):
98
  error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
99
  print(error_msg)
100
  return error_msg
101
-
102
- # 添加实际处理逻辑
103
  vector_array = np.array(vector, dtype=np.float32)
104
  D, I = index.search(vector_array, k=3)
105
-
106
  results = []
107
  for i in range(3):
108
  try:
@@ -112,10 +101,10 @@ def predict(vector):
112
  except Exception as e:
113
  print(f"结果处理错误: {str(e)}")
114
  results.append(f"结果 {i+1}: 数据获取失败")
115
-
116
  print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
117
  return json.dumps({
118
- "results": results # 确保嵌套结构合法
119
  })
120
 
121
  except Exception as e:
@@ -124,7 +113,7 @@ def predict(vector):
124
  print(error_msg)
125
  return "处理错误,请重试或联系管理员"
126
 
127
- # **合并后的 Blocks 实例**
128
  with gr.Blocks() as demo:
129
  gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
130
  gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
@@ -134,7 +123,7 @@ with gr.Blocks() as demo:
134
  headers=["向量值"],
135
  type="array",
136
  label="输入向量 (768维)",
137
- value=[[0.1]*768] # 768维默认值
138
  )
139
  output = gr.Textbox(label="智能回答", lines=5)
140
 
@@ -145,36 +134,29 @@ with gr.Blocks() as demo:
145
  outputs=output
146
  )
147
 
148
- # 显式暴露 API 端点(关键修改)
149
- #demo.express_api(predict, input=vector_input, output=output, api_name="predict")
150
-
151
- # 确保 API 输入输出与界面组件完全一致
152
- api = gr.Interface(
153
- fn=predict,
154
- inputs=gr.Dataframe( # 与 vector_input 完全相同
155
- headers=["向量值"],
156
- type="array",
157
- value=[[0.1]*768]
158
- ),
159
- outputs=gr.Textbox(), # 与 output 类型一致
160
- api_name="predict" # 显式声明 API 名称
161
- )
162
-
163
- # 启动应用
164
  if __name__ == "__main__":
165
  if index is None or metadata is None:
166
  load_resources()
167
 
168
- # 验证 API 是否生��
169
  print("="*50)
170
  print("Space启动完成 | 准备接收请求")
171
  print(f"索引维度: {index.d if index else '未加载'}")
172
  print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
173
  print("="*50)
174
 
175
- # 启动应用(确保 API 暴露)
176
- api.launch(
177
  server_name="0.0.0.0",
178
  server_port=7860,
179
  ssr_mode=False
180
- )
 
6
  import faiss
7
  from huggingface_hub import hf_hub_download
8
  import time
 
9
  import json
10
 
11
  # 创建安全缓存目录
 
14
 
15
  # 减少内存占用
16
  os.environ["OMP_NUM_THREADS"] = "2"
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
 
19
  # 全局变量
20
  index = None
21
  metadata = None
22
 
 
23
  def load_resources():
24
  """加载所有必要资源(768维专用)"""
25
+ global index, metadata
26
+
27
  # 清理残留锁文件
28
  lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
29
  for lock_file in lock_files:
 
31
  os.remove(os.path.join(CACHE_DIR, lock_file))
32
  print(f"🧹 清理锁文件: {lock_file}")
33
  except: pass
34
+
 
 
 
35
  if index is None or metadata is None:
36
  print("🔄 正在加载所有资源...")
37
+
38
  # 加载FAISS索引(768维)
39
  if index is None:
40
  print("📥 正在下载FAISS索引...")
 
47
  )
48
  index = faiss.read_index(INDEX_PATH)
49
 
 
50
  if index.d != 768:
51
+ raise ValueError("❌ 索引维度错误:预期768维")
52
+
 
 
 
 
 
53
  print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
54
  except Exception as e:
55
  print(f"❌ FAISS索引加载失败: {str(e)}")
56
  raise
57
+
58
  # 加载元数据
59
  if metadata is None:
60
  print("📄 正在下载元数据...")
 
77
  """处理768维向量输入并返回答案"""
78
  start_time = time.time()
79
  print(f"输入向量维度: {np.array(vector).shape}")
80
+
81
  try:
 
82
  if not isinstance(vector, list) or len(vector) == 0:
83
  error_msg = "错误:输入格式无效"
84
  print(error_msg)
 
88
  error_msg = f"错误:需要1x768的二维数组,收到{len(vector)}x{len(vector[0]) if vector else 0}"
89
  print(error_msg)
90
  return error_msg
91
+
 
92
  vector_array = np.array(vector, dtype=np.float32)
93
  D, I = index.search(vector_array, k=3)
94
+
95
  results = []
96
  for i in range(3):
97
  try:
 
101
  except Exception as e:
102
  print(f"结果处理错误: {str(e)}")
103
  results.append(f"结果 {i+1}: 数据获取失败")
104
+
105
  print(f"处理完成 | 耗时: {time.time()-start_time:.2f}秒")
106
  return json.dumps({
107
+ "results": results
108
  })
109
 
110
  except Exception as e:
 
113
  print(error_msg)
114
  return "处理错误,请重试或联系管理员"
115
 
116
+ # 创建Blocks应用
117
  with gr.Blocks() as demo:
118
  gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
119
  gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
 
123
  headers=["向量值"],
124
  type="array",
125
  label="输入向量 (768维)",
126
+ value=[[0.1]*768]
127
  )
128
  output = gr.Textbox(label="智能回答", lines=5)
129
 
 
134
  outputs=output
135
  )
136
 
137
+ # 暴露为API
138
+ demo.expose_api(
139
+ fn=predict,
140
+ input=vector_input,
141
+ output=output,
142
+ api_name="predict"
143
+ )
144
+
145
+ # 在Blocks内部加载资源
 
 
 
 
 
 
 
146
  if __name__ == "__main__":
147
  if index is None or metadata is None:
148
  load_resources()
149
 
150
+ # 验证API
151
  print("="*50)
152
  print("Space启动完成 | 准备接收请求")
153
  print(f"索引维度: {index.d if index else '未加载'}")
154
  print(f"元数据记录: {len(metadata) if metadata is not None else 0}")
155
  print("="*50)
156
 
157
+ # 启动应用
158
+ demo.launch(
159
  server_name="0.0.0.0",
160
  server_port=7860,
161
  ssr_mode=False
162
+ )