Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
-
import gradio as gr
|
2 |
import numpy as np
|
3 |
import os
|
4 |
import torch
|
5 |
import pandas as pd
|
6 |
-
from sentence_transformers import SentenceTransformer
|
7 |
-
from huggingface_hub import hf_hub_download
|
8 |
import faiss
|
|
|
|
|
9 |
import time
|
10 |
-
import pathlib
|
11 |
|
12 |
-
#
|
13 |
CACHE_DIR = "/home/user/cache"
|
14 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
15 |
|
@@ -17,105 +16,97 @@ os.makedirs(CACHE_DIR, exist_ok=True)
|
|
17 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
|
18 |
torch.set_num_threads(1)
|
19 |
|
20 |
-
#
|
21 |
-
model = None
|
22 |
index = None
|
23 |
metadata = None
|
24 |
-
|
25 |
-
|
26 |
|
27 |
def load_resources():
|
28 |
-
"""
|
29 |
-
global
|
30 |
|
31 |
-
#
|
32 |
-
if
|
33 |
-
print("
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
if index is None:
|
39 |
print("正在下载FAISS索引...")
|
40 |
INDEX_PATH = hf_hub_download(
|
41 |
repo_id="GOGO198/GOGO_rag_index",
|
42 |
filename="faiss_index.bin",
|
43 |
cache_dir=CACHE_DIR,
|
44 |
-
|
45 |
)
|
46 |
index = faiss.read_index(INDEX_PATH)
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
if metadata is None:
|
50 |
print("正在下载元数据...")
|
51 |
METADATA_PATH = hf_hub_download(
|
52 |
repo_id="GOGO198/GOGO_rag_index",
|
53 |
filename="metadata.csv",
|
54 |
cache_dir=CACHE_DIR,
|
55 |
-
|
56 |
)
|
57 |
metadata = pd.read_csv(METADATA_PATH)
|
58 |
print("元数据加载完成")
|
59 |
|
60 |
def predict(vector):
|
61 |
-
"""
|
62 |
try:
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
# # 转换为numpy数组
|
67 |
-
# vector = np.array(vector, dtype=np.float32).reshape(1, -1)
|
68 |
|
69 |
-
#
|
70 |
-
# docs = retriever.retrieve(vector)
|
71 |
-
|
72 |
-
# # 提取前3个相关文档
|
73 |
-
# context = "\n".join([doc["text"] for doc in docs[:3]])
|
74 |
-
|
75 |
-
# # 生成答案 (使用更轻量级的生成模型)
|
76 |
-
# inputs = tokenizer(
|
77 |
-
# f"基于以下信息回答问题: {context}\n问题: 用户查询向量",
|
78 |
-
# return_tensors="pt"
|
79 |
-
# )
|
80 |
-
|
81 |
-
# # 使用轻量级生成模型
|
82 |
-
# from transformers import AutoModelForCausalLM
|
83 |
-
# generator = AutoModelForCausalLM.from_pretrained("gpt2")
|
84 |
-
# outputs = generator.generate(
|
85 |
-
# inputs["input_ids"],
|
86 |
-
# max_length=200,
|
87 |
-
# num_return_sequences=1
|
88 |
-
# )
|
89 |
-
|
90 |
-
# answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
91 |
-
|
92 |
-
# print(f"处理时间: {time.time() - start_time:.2f}秒")
|
93 |
-
# return answer
|
94 |
-
|
95 |
-
# 如果遇到资源瓶颈,使用纯检索方案1
|
96 |
vector = np.array(vector, dtype=np.float32).reshape(1, -1)
|
|
|
|
|
97 |
|
98 |
# FAISS 搜索
|
99 |
D, I = index.search(vector, k=3)
|
100 |
|
101 |
# 获取最相关结果
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
except Exception as e:
|
105 |
return f"处理错误: {str(e)}"
|
106 |
|
107 |
# 创建简化接口
|
108 |
with gr.Blocks() as demo:
|
109 |
-
gr.Markdown("##
|
|
|
110 |
|
111 |
with gr.Row():
|
112 |
vector_input = gr.Dataframe(
|
113 |
headers=["向量值"],
|
114 |
type="array",
|
115 |
-
label="输入向量 (
|
116 |
-
value=[[0.1]*
|
117 |
)
|
118 |
-
output = gr.Textbox(label="智能回答")
|
119 |
|
120 |
submit_btn = gr.Button("生成回答")
|
121 |
submit_btn.click(
|
@@ -127,14 +118,14 @@ with gr.Blocks() as demo:
|
|
127 |
# 启动应用
|
128 |
if __name__ == "__main__":
|
129 |
# 确保缓存目录存在
|
|
|
130 |
pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
|
131 |
|
132 |
-
#
|
133 |
print("启动前预加载资源...")
|
134 |
load_resources()
|
135 |
|
136 |
demo.launch(
|
137 |
server_name="0.0.0.0",
|
138 |
-
server_port=7860
|
139 |
-
|
140 |
-
)
|
|
|
1 |
+
import gradio as gr
|
2 |
import numpy as np
|
3 |
import os
|
4 |
import torch
|
5 |
import pandas as pd
|
|
|
|
|
6 |
import faiss
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
from transformers import CLIPProcessor, CLIPModel
|
9 |
import time
|
|
|
10 |
|
11 |
+
# Writable cache directory for downloaded models / index / metadata
# (a safe location inside the HF Space container).
CACHE_DIR = "/home/user/cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Resource tuning for the small Space: cap CUDA allocator split size to
# reduce fragmentation, and keep PyTorch single-threaded on CPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
torch.set_num_threads(1)

# Lazily-populated global resources; load_resources() fills these on demand.
index = None
metadata = None
clip_model = None
clip_processor = None
|
24 |
|
25 |
def load_resources():
    """Load all resources needed for retrieval (768-dim setup).

    Populates the module-level globals ``index``, ``metadata``,
    ``clip_model`` and ``clip_processor`` on first call; resources that
    are already loaded are skipped on subsequent calls.

    Raises:
        ValueError: if the downloaded FAISS index is not 768-dimensional.
    """
    global index, metadata, clip_model, clip_processor

    # Shared download settings; token is None for public repos, which is fine.
    repo_id = "GOGO198/GOGO_rag_index"
    token = os.getenv("HF_TOKEN")

    # CLIP model/processor. NOTE(review): nothing visible in this file uses
    # them after loading — predict() only queries FAISS — so this eager load
    # looks like dead weight; confirm before removing.
    if clip_model is None:
        print("正在加载CLIP模型...")
        # Load both parts before publishing to the globals, so a processor
        # failure cannot leave a half-initialized model-without-processor state.
        loaded_model = CLIPModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            cache_dir=CACHE_DIR,
        )
        loaded_processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-large-patch14",
            cache_dir=CACHE_DIR,
        )
        clip_model, clip_processor = loaded_model, loaded_processor
        print("CLIP模型加载完成")

    # FAISS index (expected to hold 768-dim vectors).
    if index is None:
        print("正在下载FAISS索引...")
        index_path = hf_hub_download(
            repo_id=repo_id,
            filename="faiss_index.bin",
            cache_dir=CACHE_DIR,
            token=token,
        )
        # Validate BEFORE publishing to the global: the original assigned
        # `index` first, so a dimension failure left a bad index cached and
        # later calls skipped this check entirely.
        loaded_index = faiss.read_index(index_path)
        if loaded_index.d != 768:
            raise ValueError(f"索引维度错误:预期768维,实际{loaded_index.d}维")
        index = loaded_index
        print("FAISS索引加载完成 | 维度: 768")

    # Row metadata aligned with the index entries (one row per vector).
    if metadata is None:
        print("正在下载元数据...")
        metadata_path = hf_hub_download(
            repo_id=repo_id,
            filename="metadata.csv",
            cache_dir=CACHE_DIR,
            token=token,
        )
        metadata = pd.read_csv(metadata_path)
        print("元数据加载完成")
|
69 |
|
70 |
def predict(vector):
    """Handle a 768-dim query vector and return a retrieval answer.

    Args:
        vector: sequence of floats; flattened to shape (1, 768).

    Returns:
        str: the formatted top-3 retrieval results, or an error message
        (dimension mismatch / any unexpected failure). This function never
        raises — errors are returned as text, matching the Gradio callback
        contract.
    """
    try:
        start_time = time.time()
        load_resources()  # lazy-load index/metadata on first request

        # Flatten to a single query row and validate dimensionality.
        vector = np.array(vector, dtype=np.float32).reshape(1, -1)
        if vector.shape[1] != 768:
            return f"维度错误:预期768维,收到{vector.shape[1]}维"

        # FAISS search for the 3 nearest neighbours.
        D, I = index.search(vector, k=3)

        # Format results. FAISS pads missing neighbours with id -1 when the
        # index holds fewer than k vectors; the original `range(3)` loop let
        # metadata.iloc[-1] silently pick the LAST metadata row for those —
        # skip them instead.
        results = []
        for rank, (dist, idx) in enumerate(zip(D[0], I[0]), start=1):
            if idx < 0:
                continue
            row = metadata.iloc[int(idx)]
            results.append(
                f"相关结果 {rank}: {row['source']} | 相似度: {1/(1+dist):.2f}"
            )

        response = "\n".join(results)
        print(f"处理时间: {time.time() - start_time:.2f}秒")
        return response

    except Exception as e:
        return f"处理错误: {str(e)}"
|
96 |
|
97 |
# 创建简化接口
|
98 |
with gr.Blocks() as demo:
|
99 |
+
gr.Markdown("## 🛍 电商智能客服系统 (768维专用)")
|
100 |
+
gr.Markdown("**使用CLIP-vit-large-patch14模型 | 向量维度: 768**")
|
101 |
|
102 |
with gr.Row():
|
103 |
vector_input = gr.Dataframe(
|
104 |
headers=["向量值"],
|
105 |
type="array",
|
106 |
+
label="输入向量 (768维)",
|
107 |
+
value=[[0.1]*768] # 768维默认值
|
108 |
)
|
109 |
+
output = gr.Textbox(label="智能回答", lines=5)
|
110 |
|
111 |
submit_btn = gr.Button("生成回答")
|
112 |
submit_btn.click(
|
|
|
118 |
# Application entry point.
if __name__ == "__main__":
    # Make sure the cache directory exists before anything downloads into it
    # (equivalent to the pathlib.Path(...).mkdir call; os is already imported).
    os.makedirs(CACHE_DIR, exist_ok=True)

    # Warm up: fetch model, index and metadata before serving traffic.
    print("启动前预加载资源...")
    load_resources()

    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|