|
import gradio as gr |
|
import numpy as np |
|
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration |
|
import os |
|
|
|
|
|
# Model id is overridable via the environment for deployment flexibility.
MODEL_NAME = os.getenv("MODEL_NAME", "facebook/rag-sequence-nq")

# NOTE(review): index_name="custom" normally also requires `passages_path`
# and `index_path` pointing at a prebuilt dataset/FAISS index — confirm this
# loads without them in the deployed environment.
retriever = RagRetriever.from_pretrained(MODEL_NAME, index_name="custom")

tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)

# The retriever is wired into the generator so generate() can fetch passages.
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)
|
|
|
def predict(vector):
    """Generate an answer from a product/order embedding vector.

    Args:
        vector: array-like (e.g. nested list from the Gradio Dataframe
            widget); flattened into a single (1, D) row vector.

    Returns:
        str: the generated answer, or a Chinese-language error message
        when anything in the pipeline fails.
    """
    try:
        # Normalize whatever shape the Dataframe widget sends into one
        # row vector for the model.
        vector = np.array(vector).reshape(1, -1)

        # The original code tokenized an empty string via the deprecated
        # `prepare_seq2seq_batch` and then discarded the result by setting
        # input_ids to None — that dead call is removed here; we pass the
        # values to generate() directly.
        #
        # NOTE(review): `external_vector` is NOT a standard argument of
        # `RagSequenceForGeneration.generate`; this relies on a customized
        # model accepting it — confirm against the actual model class.
        # `input_ids=None` skips text conditioning entirely.
        outputs = model.generate(
            input_ids=None,
            external_vector=vector,
            max_length=200
        )

        # Decode the first (only) sequence in the batch.
        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return answer

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing
        # the Gradio app.
        return f"处理错误: {str(e)}"
|
|
|
|
|
# Gradio UI: a single Dataframe input (one "vector" column, delivered as a
# nested list) wired to predict(), returning plain text.
iface = gr.Interface(

    fn=predict,

    inputs=gr.Dataframe(headers=["vector"], type="array"),

    outputs="text",

    title="电商智能客服",

    description="输入商品/订单向量获取智能回答"

)


# Bind to all interfaces on port 7860 (the conventional Gradio port) so the
# app is reachable from outside a container.
iface.launch(server_name="0.0.0.0", server_port=7860)