GOGO198 committed on
Commit
7c338f3
·
verified ·
1 Parent(s): 7401ffc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -18
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import numpy as np
3
  import os
4
  from sentence_transformers import SentenceTransformer
 
5
 
6
  # Use a lighter-weight model - add the token parameter
7
  model_name = "all-MiniLM-L6-v2"
@@ -9,24 +10,25 @@ token = os.getenv("HF_TOKEN") # Read the token from an environment variable
9
  model = SentenceTransformer(model_name, use_auth_token=token) if token else None
10
 
11
  def predict(vector):
12
- """Process the vector input and return an answer string (or an error-message string)."""
13
- try:
14
- # Check whether the model loaded successfully
15
- if model is None:
16
- return "错误:模型未加载,请检查 HF_TOKEN 设置"
17
-
18
- # Convert to a numpy array (reshaped to a single row)
19
- vector = np.array(vector).reshape(1, -1)
20
-
21
- # Validate the vector dimension
22
- if vector.shape[1] == 384: # MiniLM vector dimension
23
- # Actual vector-processing logic (currently returns a fixed placeholder answer)
24
- return "这是基于向量生成的回答示例。实际使用时应连接您的 RAG 模型"
25
- else:
26
- return f"错误:接收的向量维度是 {vector.shape[1]},但预期是 384"
27
-
28
- except Exception as e:
29
- return f"处理错误: {str(e)}"
 
30
 
31
  # Create a simpler interface
32
  with gr.Blocks() as demo:
 
2
  import numpy as np
3
  import os
4
  from sentence_transformers import SentenceTransformer
5
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
 
7
  # Use a lighter-weight model - add the token parameter
8
  model_name = "all-MiniLM-L6-v2"
 
10
  model = SentenceTransformer(model_name, use_auth_token=token) if token else None
11
 
12
  def predict(vector):
13
+ # Load the local index (NOTE(review): the retriever is rebuilt on every call to predict)
14
+ retriever = RagRetriever.from_pretrained(
15
+ "facebook/rag-sequence-nq",
16
+ index_name="custom",
17
+ index_paths=["rag_index.faiss"]
18
+ )
19
+
20
+ # Retrieve relevant documents (NOTE(review): docs is not referenced again in this function - confirm intent)
21
+ docs = retriever.retrieve(vector)
22
+
23
+ # Generate the answer (NOTE(review): tokenizer and model are reloaded per call; this local model shadows the module-level model)
24
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
25
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
26
+ inputs = tokenizer.prepare_seq2seq_batch(
27
+ [vector],
28
+ return_tensors="pt"
29
+ )
30
+ outputs = model.generate(input_ids=inputs["input_ids"])
31
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
32
 
33
  # Create a simpler interface
34
  with gr.Blocks() as demo: