GOGO198 committed on
Commit
868518d
·
verified ·
1 Parent(s): a19d569

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
4
+ import os
5
+
6
# Load the RAG components. The checkpoint name comes from the MODEL_NAME
# environment variable and falls back to the public NQ checkpoint.
MODEL_NAME = os.getenv("MODEL_NAME", "facebook/rag-sequence-nq")

# index_name="custom" tells the retriever to load a user-supplied passages
# dataset and FAISS index from disk, so their locations have to reach it.
# Forward them from the environment when they are set; when they are not,
# nothing extra is passed and whatever the checkpoint's own config carries
# is used, exactly as before.
_custom_index_kwargs = {
    option: path
    for option, path in (
        ("passages_path", os.getenv("PASSAGES_PATH")),
        ("index_path", os.getenv("INDEX_PATH")),
    )
    if path
}

retriever = RagRetriever.from_pretrained(
    MODEL_NAME,
    index_name="custom",
    **_custom_index_kwargs,
)
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)
11
+
12
def predict(vector):
    """Generate an answer from a pre-computed question embedding.

    The vector is used in place of the question encoder's output: it is
    handed to the retriever to look up passages, scored against the
    retrieved passage embeddings, and the generator then answers from those
    passages.

    Args:
        vector: The query embedding as a (possibly nested) sequence of
            numbers or numeric strings. It is flattened into a single row.
            NOTE(review): its length presumably has to match the dimension
            of the retriever's index -- confirm against the index in use.

    Returns:
        The decoded answer text, or a string starting with "处理错误: "
        describing the failure. Errors are returned rather than raised so
        the UI can display them.
    """
    try:
        # Table cells may arrive as strings; coerce everything to one
        # float32 row, which is also what the torch scoring below expects.
        query = np.asarray(vector, dtype=np.float32).reshape(1, -1)
        if query.size == 0:
            raise ValueError("输入向量为空")
        if not np.isfinite(query).all():
            raise ValueError("输入向量包含非有限值")

        # Imported here because the file's import block does not bring
        # torch into scope.
        import torch

        # The retriever still needs question token ids to assemble the
        # generator's context strings; there is no question text here, so
        # an empty question is tokenized.
        question_ids = tokenizer("", return_tensors="pt")["input_ids"]

        # Retrieve passages with the caller's vector as the query embedding.
        docs = retriever(question_ids.numpy(), query, return_tensors="pt")

        # Score each retrieved passage by its inner product with the query.
        doc_scores = torch.bmm(
            torch.from_numpy(query).unsqueeze(1),
            docs["retrieved_doc_embeds"].float().transpose(1, 2),
        ).squeeze(1)

        # With no input_ids, generate() needs the retrieved context, its
        # attention mask and the document scores passed explicitly.
        outputs = model.generate(
            context_input_ids=docs["context_input_ids"],
            context_attention_mask=docs["context_attention_mask"],
            doc_scores=doc_scores,
            max_length=200,
        )

        return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"处理错误: {str(e)}"
39
+
40
# Build the Gradio interface: a one-column table takes the vector, and the
# answer comes back as plain text.
_vector_table = gr.Dataframe(type="array", headers=["vector"])

iface = gr.Interface(
    title="电商智能客服",
    description="输入商品/订单向量获取智能回答",
    fn=predict,
    inputs=_vector_table,
    outputs="text",
)

# Start the app on all network interfaces, port 7860.
iface.launch(server_port=7860, server_name="0.0.0.0")