GOGO198 committed on
Commit 49bc88a · verified · 1 Parent(s): b6472df

Update app.py

Files changed (1)
  1. app.py +115 -60
app.py CHANGED
@@ -1,58 +1,125 @@
 import gradio as gr
 import numpy as np
-from sentence_transformers import SentenceTransformer
-from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 import os
+import torch
+import pandas as pd
+from sentence_transformers import SentenceTransformer
+from transformers import RagTokenizer, RagRetriever  # needed by load_resources()
 from huggingface_hub import hf_hub_download
 import faiss
+import time
 
-# Set memory-swap parameters
+# Reduce memory usage
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
-
-# Set the default number of threads
-torch.set_num_threads(1)  # reduce memory usage
-
-# Use a lighter model and pass a token
-model_name = "all-MiniLM-L6-v2"
-token = os.getenv("HF_TOKEN")  # read the token from the environment
-model = SentenceTransformer(model_name, use_auth_token=token) if token else None
+torch.set_num_threads(1)
+
+# Initialize empty resources
+model = None
+index = None
+metadata = None
+tokenizer = None
+retriever = None
+
+def load_resources():
+    """Load resources on demand."""
+    global model, index, metadata, tokenizer, retriever
+
+    # Load only when needed
+    if model is None:
+        print("Loading sentence-embedding model...")
+        token = os.getenv("HF_TOKEN")
+        model = SentenceTransformer("all-MiniLM-L6-v2", use_auth_token=token)
+        print("Sentence model loaded")
+
+    if index is None:
+        print("Downloading FAISS index...")
+        INDEX_PATH = hf_hub_download(
+            repo_id="GOGO198/GOGO_rag_index",
+            filename="faiss_index.bin",
+            cache_dir="/data",
+            use_auth_token=os.getenv("HF_TOKEN")
+        )
+        index = faiss.read_index(INDEX_PATH)
+        print("FAISS index loaded")
+
+    if metadata is None:
+        print("Downloading metadata...")
+        METADATA_PATH = hf_hub_download(
+            repo_id="GOGO198/GOGO_rag_index",
+            filename="metadata.csv",
+            cache_dir="/data",
+            use_auth_token=os.getenv("HF_TOKEN")
+        )
+        metadata = pd.read_csv(METADATA_PATH)
+        print("Metadata loaded")
+
+    # Lazily load the RAG components
+    if tokenizer is None:
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+
+    if retriever is None:
+        retriever = RagRetriever.from_pretrained(
+            "facebook/rag-sequence-nq",
+            index_name="custom",
+            index_paths=["/data/rag_index.faiss"]  # use the already-loaded index
+        )
 
 def predict(vector):
-    # Load the local index
-    retriever = RagRetriever.from_pretrained(
-        "facebook/rag-sequence-nq",
-        index_name="custom",
-        index_paths=["rag_index.faiss"]
-    )
-
-    # Retrieve relevant documents
-    docs = retriever.retrieve(vector)
-
-    # Generate an answer
-    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-    model = RagSequenceForGeneration.from_pretrained(
-        "facebook/rag-sequence-nq",
-        torch_dtype=torch.float16
-    )
-    inputs = tokenizer.prepare_seq2seq_batch(
-        [vector],
-        return_tensors="pt"
-    )
-    outputs = model.generate(input_ids=inputs["input_ids"])
-    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+    """Process the input vector and return an answer."""
+    try:
+        # start_time = time.time()
+        # load_resources()  # make sure resources are loaded
+
+        # # Convert to a numpy array
+        # vector = np.array(vector, dtype=np.float32).reshape(1, -1)
+
+        # # Retrieve relevant documents
+        # docs = retriever.retrieve(vector)
+
+        # # Take the top 3 relevant documents
+        # context = "\n".join([doc["text"] for doc in docs[:3]])
+
+        # # Generate an answer (with a lighter generation model)
+        # inputs = tokenizer(
+        #     f"Answer based on the following information: {context}\nQuestion: user query vector",
+        #     return_tensors="pt"
+        # )
+
+        # # Use a lightweight generation model
+        # from transformers import AutoModelForCausalLM
+        # generator = AutoModelForCausalLM.from_pretrained("gpt2")
+        # outputs = generator.generate(
+        #     inputs["input_ids"],
+        #     max_length=200,
+        #     num_return_sequences=1
+        # )
+
+        # answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # print(f"Processing time: {time.time() - start_time:.2f}s")
+        # return answer
+
+        # If resources are constrained, use the retrieval-only scheme
+        vector = np.array(vector, dtype=np.float32).reshape(1, -1)
+
+        # FAISS search
+        D, I = index.search(vector, k=3)
+
+        # Take the most relevant hit
+        result = metadata.iloc[I[0][0]]
+        return f"Most relevant result: {result['title']}\nDescription: {result['description'][:100]}..."
+    except Exception as e:
+        return f"Processing error: {str(e)}"
 
-# Create a simpler interface
+# Create the simplified interface
 with gr.Blocks() as demo:
-    gr.Markdown("## 🛍️ E-commerce smart customer-service system")
-
-    # Show the model status
-    model_status = gr.Markdown(f"Model status: {'loaded' if model else 'not loaded'}")
+    gr.Markdown("## 🛍️ E-commerce smart customer-service system (lite)")
 
     with gr.Row():
         vector_input = gr.Dataframe(
-            headers=["vector"],
+            headers=["vector value"],
             type="array",
-            label="Input vector (384-dim)"
+            label="Input vector (384-dim)",
+            value=[[0.1]*384]  # default value
         )
         output = gr.Textbox(label="Smart answer")
 
@@ -63,26 +130,14 @@ with gr.Blocks() as demo:
         outputs=output
     )
 
-# Download the index at app startup
-INDEX_PATH = hf_hub_download(
-    repo_id="GOGO198/GOGO_rag_index",
-    filename="faiss_index.bin",
-    cache_dir="/data"
-)
-
-METADATA_PATH = hf_hub_download(
-    repo_id="GOGO198/GOGO_rag_index",
-    filename="metadata.csv",
-    cache_dir="/data"
-)
-
-# Load the index
-index = faiss.read_index(INDEX_PATH)
-metadata = pd.read_csv(METADATA_PATH)
-
 # Launch the app
-demo.launch(
-    server_name="0.0.0.0",
-    server_port=7860,
-    share=False
-)
+if __name__ == "__main__":
+    # Preload required resources before launch
+    print("Preloading resources before startup...")
+    load_resources()
+
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
+    )
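
The Dataframe input above expects a raw 384-dimensional vector, which is impractical to type by hand. As a usage sketch (not part of this commit; `embed_query` is a hypothetical helper), a caller could produce that input with the same all-MiniLM-L6-v2 encoder the app loads:

```python
# Hypothetical client-side helper, assuming the same encoder as app.py.
from sentence_transformers import SentenceTransformer
import numpy as np

encoder = SentenceTransformer("all-MiniLM-L6-v2")  # produces 384-dim embeddings

def embed_query(text: str) -> list:
    """Encode a question into the single-row [[v1, ..., v384]] shape the Dataframe takes."""
    vec = encoder.encode([text])[0].astype(np.float32)
    return [vec.tolist()]

rows = embed_query("Where is my order?")  # pass `rows` as the Dataframe value
```

Whether the query vector should be L2-normalized before `index.search` depends on how faiss_index.bin was built (inner-product vs. L2 metric), which this commit does not show.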
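
For context, here is a minimal sketch of how an index file like faiss_index.bin can be produced. The actual build script for GOGO198/GOGO_rag_index is not part of this commit, and the corpus embeddings below are stand-ins:

```python
# Minimal sketch: build and serialize a flat L2 index over 384-dim vectors.
import faiss
import numpy as np

corpus_vectors = np.random.rand(1000, 384).astype(np.float32)  # stand-in embeddings
index = faiss.IndexFlatL2(384)   # exact nearest-neighbor search, no training step
index.add(corpus_vectors)        # row i should correspond to row i of metadata.csv
faiss.write_index(index, "faiss_index.bin")
```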