GOGO198 committed on
Commit
a9e075d
·
verified ·
1 Parent(s): 4120f6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -65
app.py CHANGED
@@ -1,15 +1,14 @@
1
- import gradio as gr
2
  import numpy as np
3
  import os
4
  import torch
5
  import pandas as pd
6
- from sentence_transformers import SentenceTransformer
7
- from huggingface_hub import hf_hub_download
8
  import faiss
 
 
9
  import time
10
- import pathlib
11
 
12
- # 创建安全缓存目录(在用户目录下)
13
  CACHE_DIR = "/home/user/cache"
14
  os.makedirs(CACHE_DIR, exist_ok=True)
15
 
@@ -17,105 +16,97 @@ os.makedirs(CACHE_DIR, exist_ok=True)
17
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
18
  torch.set_num_threads(1)
19
 
20
- # 初始化空模型
21
- model = None
22
  index = None
23
  metadata = None
24
- tokenizer = None
25
- retriever = None
26
 
27
  def load_resources():
28
- """按需加载资源"""
29
- global model, index, metadata, tokenizer, retriever
30
 
31
- # 仅当需要时加载
32
- if model is None:
33
- print("正在加载句子嵌入模型...")
34
- token = os.getenv("HF_TOKEN")
35
- model = SentenceTransformer("all-MiniLM-L6-v2", use_auth_token=token, cache_folder=CACHE_DIR)
36
- print("句子模型加载完成")
37
-
 
 
 
 
 
 
 
38
  if index is None:
39
  print("正在下载FAISS索引...")
40
  INDEX_PATH = hf_hub_download(
41
  repo_id="GOGO198/GOGO_rag_index",
42
  filename="faiss_index.bin",
43
  cache_dir=CACHE_DIR,
44
- use_auth_token=os.getenv("HF_TOKEN")
45
  )
46
  index = faiss.read_index(INDEX_PATH)
47
- print("FAISS索引加载完成")
48
-
 
 
 
 
 
49
  if metadata is None:
50
  print("正在下载元数据...")
51
  METADATA_PATH = hf_hub_download(
52
  repo_id="GOGO198/GOGO_rag_index",
53
  filename="metadata.csv",
54
  cache_dir=CACHE_DIR,
55
- use_auth_token=os.getenv("HF_TOKEN")
56
  )
57
  metadata = pd.read_csv(METADATA_PATH)
58
  print("元数据加载完成")
59
 
60
  def predict(vector):
61
- """处理向量输入并返回答案"""
62
  try:
63
- # start_time = time.time()
64
- # load_resources() # 确保资源已加载
65
-
66
- # # 转换为numpy数组
67
- # vector = np.array(vector, dtype=np.float32).reshape(1, -1)
68
 
69
- # # 检索相关文档
70
- # docs = retriever.retrieve(vector)
71
-
72
- # # 提取前3个相关文档
73
- # context = "\n".join([doc["text"] for doc in docs[:3]])
74
-
75
- # # 生成答案 (使用更轻量级的生成模型)
76
- # inputs = tokenizer(
77
- # f"基于以下信息回答问题: {context}\n问题: 用户查询向量",
78
- # return_tensors="pt"
79
- # )
80
-
81
- # # 使用轻量级生成模型
82
- # from transformers import AutoModelForCausalLM
83
- # generator = AutoModelForCausalLM.from_pretrained("gpt2")
84
- # outputs = generator.generate(
85
- # inputs["input_ids"],
86
- # max_length=200,
87
- # num_return_sequences=1
88
- # )
89
-
90
- # answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
91
-
92
- # print(f"处理时间: {time.time() - start_time:.2f}秒")
93
- # return answer
94
-
95
- # 如果遇到资源瓶颈,使用纯检索方案1
96
  vector = np.array(vector, dtype=np.float32).reshape(1, -1)
 
 
97
 
98
  # FAISS 搜索
99
  D, I = index.search(vector, k=3)
100
 
101
  # 获取最相关结果
102
- result = metadata.iloc[I[0][0]]
103
- return f"最相关结果: {result['title']}\n描述: {result['description'][:100]}..."
 
 
 
 
 
 
 
104
  except Exception as e:
105
  return f"处理错误: {str(e)}"
106
 
107
  # 创建简化接口
108
  with gr.Blocks() as demo:
109
- gr.Markdown("## 🛍️ 电商智能客服系统 (轻量版)")
 
110
 
111
  with gr.Row():
112
  vector_input = gr.Dataframe(
113
  headers=["向量值"],
114
  type="array",
115
- label="输入向量 (384维)",
116
- value=[[0.1]*384] # 默认值
117
  )
118
- output = gr.Textbox(label="智能回答")
119
 
120
  submit_btn = gr.Button("生成回答")
121
  submit_btn.click(
@@ -127,14 +118,14 @@ with gr.Blocks() as demo:
127
  # 启动应用
128
  if __name__ == "__main__":
129
  # 确保缓存目录存在
 
130
  pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
131
 
132
- # 先加载必要资源
133
  print("启动前预加载资源...")
134
  load_resources()
135
 
136
  demo.launch(
137
  server_name="0.0.0.0",
138
- server_port=7860,
139
- share=False
140
- )
 
1
+ import gradio as gr
2
  import numpy as np
3
  import os
4
  import torch
5
  import pandas as pd
 
 
6
  import faiss
7
+ from huggingface_hub import hf_hub_download
8
+ from transformers import CLIPProcessor, CLIPModel
9
  import time
 
10
 
11
+ # 创建安全缓存目录
12
  CACHE_DIR = "/home/user/cache"
13
  os.makedirs(CACHE_DIR, exist_ok=True)
14
 
 
16
# Keep the process lightweight: cap the CUDA allocator's split size and pin
# PyTorch to a single CPU thread (presumably a low-resource Space host — confirm).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
torch.set_num_threads(1)

# Module-level globals, populated lazily by load_resources().
index = None
metadata = None
clip_model = None
clip_processor = None
24
 
25
def load_resources():
    """Download and initialize every resource needed for 768-dim retrieval.

    Populates the module globals:
      - ``clip_model`` / ``clip_processor``: CLIP vit-large-patch14 (the
        768-dim embedding space the index was built for).
      - ``index``: FAISS index downloaded from the HF Hub.
      - ``metadata``: pandas DataFrame with one row per indexed vector.

    Each resource is loaded at most once; later calls are cheap no-ops.

    Raises:
        ValueError: if the downloaded FAISS index is not 768-dimensional.
    """
    global index, metadata, clip_model, clip_processor

    # Read the token once; None is fine for public repos.
    hf_token = os.getenv("HF_TOKEN")

    # Load the CLIP model that defines the 768-dim embedding space.
    if clip_model is None:
        print("正在加载CLIP模型...")
        clip_model = CLIPModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            cache_dir=CACHE_DIR,
            token=hf_token,  # consistent with the hf_hub_download calls below
        )
        clip_processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-large-patch14",
            cache_dir=CACHE_DIR,
            token=hf_token,
        )
        print("CLIP模型加载完成")

    # Download the FAISS index and validate its dimensionality BEFORE
    # publishing it to the global; otherwise a wrong-dimension index would
    # stay installed and the `index is None` guard would never retry.
    if index is None:
        print("正在下载FAISS索引...")
        INDEX_PATH = hf_hub_download(
            repo_id="GOGO198/GOGO_rag_index",
            filename="faiss_index.bin",
            cache_dir=CACHE_DIR,
            token=hf_token,
        )
        loaded_index = faiss.read_index(INDEX_PATH)
        if loaded_index.d != 768:
            raise ValueError(f"索引维度错误:预期768维,实际{loaded_index.d}维")
        index = loaded_index
        print("FAISS索引加载完成 | 维度: 768")

    # Download the metadata table (one row per indexed vector).
    if metadata is None:
        print("正在下载元数据...")
        METADATA_PATH = hf_hub_download(
            repo_id="GOGO198/GOGO_rag_index",
            filename="metadata.csv",
            cache_dir=CACHE_DIR,
            token=hf_token,
        )
        metadata = pd.read_csv(METADATA_PATH)
        print("元数据加载完成")
69
 
70
def predict(vector):
    """Run a nearest-neighbour search for a 768-dim query vector.

    Args:
        vector: nested list / array coercible to 768 float32 values
            (the Gradio Dataframe component delivers a 2-D array).

    Returns:
        A human-readable string listing up to 3 nearest neighbours with a
        similarity score, or an error-message string. This function never
        raises: every exception is converted to a "处理错误: ..." string so
        the UI always gets text back.
    """
    try:
        start_time = time.time()
        load_resources()  # lazy-load model/index/metadata on first call

        # Coerce to a (1, 768) float32 query and validate dimensionality.
        vector = np.array(vector, dtype=np.float32).reshape(1, -1)
        if vector.shape[1] != 768:
            return f"维度错误:预期768维,收到{vector.shape[1]}维"

        # FAISS nearest-neighbour search (distances D, ids I).
        D, I = index.search(vector, k=3)

        # Build the response. FAISS pads with id -1 when the index holds
        # fewer than k vectors; metadata.iloc[-1] would silently return the
        # LAST row for those, so negative ids must be skipped.
        results = []
        for rank, (idx, dist) in enumerate(zip(I[0], D[0]), start=1):
            if idx < 0:
                continue
            row = metadata.iloc[idx]
            # Map the distance to a (0, 1] score; smaller distance => closer to 1.
            results.append(f"相关结果 {rank}: {row['source']} | 相似度: {1/(1+dist):.2f}")

        response = "\n".join(results)
        print(f"处理时间: {time.time() - start_time:.2f}秒")
        return response

    except Exception as e:
        return f"处理错误: {str(e)}"
96
 
97
  # 创建简化接口
98
  with gr.Blocks() as demo:
99
+ gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
100
+ gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
101
 
102
  with gr.Row():
103
  vector_input = gr.Dataframe(
104
  headers=["向量值"],
105
  type="array",
106
+ label="输入向量 (768维)",
107
+ value=[[0.1]*768] # 768维默认值
108
  )
109
+ output = gr.Textbox(label="智能回答", lines=5)
110
 
111
  submit_btn = gr.Button("生成回答")
112
  submit_btn.click(
 
118
  # 启动应用
119
if __name__ == "__main__":
    # Make sure the cache directory exists before anything downloads into it
    # (idempotent — module import already created it).
    os.makedirs(CACHE_DIR, exist_ok=True)

    # Warm up: fetch the CLIP model, FAISS index and metadata before serving.
    print("启动前预加载资源...")
    load_resources()

    # Listen on all interfaces on the standard Gradio port.
    demo.launch(server_name="0.0.0.0", server_port=7860)