File size: 1,656 Bytes
868518d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import numpy as np
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
import os

# 从环境变量加载模型
MODEL_NAME = os.getenv("MODEL_NAME", "facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(MODEL_NAME, index_name="custom")
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

def predict(vector):
    """处理向量输入并返回答案"""
    try:
        # 将向量转换为适合检索的格式
        vector = np.array(vector).reshape(1, -1)
        
        # 使用 RAG 进行检索和生成
        input_dict = tokenizer.prepare_seq2seq_batch(
            "", 
            return_tensors="pt"
        )
        input_dict["input_ids"] = None  # 使用向量而非文本
        input_dict["external_vector"] = vector  # 传递自定义向量
        
        # 生成答案
        outputs = model.generate(
            input_ids=input_dict["input_ids"],
            external_vector=input_dict["external_vector"],
            max_length=200
        )
        
        # 解码结果
        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return answer
        
    except Exception as e:
        return f"处理错误: {str(e)}"

# 创建 Gradio 接口
iface = gr.Interface(
    fn=predict,
    inputs=gr.Dataframe(headers=["vector"], type="array"),  # 接收向量输入
    outputs="text",
    title="电商智能客服",
    description="输入商品/订单向量获取智能回答"
)

# 启动应用
iface.launch(server_name="0.0.0.0", server_port=7860)