# File size: 1,836 Bytes
# Blame commits: 868518d 080637e 7c338f3 868518d dbbaf71 868518d 7c338f3 4dd7f38 7c338f3 868518d 080637e dbbaf71 080637e dbbaf71 080637e 868518d dbbaf71 080637e dbbaf71 080637e
import functools
import os

import gradio as gr
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer
# Lightweight sentence-embedding model. It is only loaded when an HF_TOKEN is
# present in the environment; otherwise `model` is left as None.
# NOTE(review): in the code visible here, the global `model` is only read by the
# UI status line — confirm whether anything else needs it.
model_name = "all-MiniLM-L6-v2"
token = os.getenv("HF_TOKEN")  # Hugging Face access token, or None when unset
if token:
    model = SentenceTransformer(model_name, use_auth_token=token)
else:
    model = None
@functools.lru_cache(maxsize=1)
def _load_rag(
    model_name="facebook/rag-sequence-nq",
    index_path="rag_index.faiss",
    passages_path="rag_passages",
):
    """Load the RAG retriever, tokenizer and generator once and reuse them.

    A ``"custom"`` RagRetriever index is configured through ``passages_path``
    (a passages dataset saved to disk) plus ``index_path`` (the FAISS index
    file); the config option is ``index_path`` (singular).

    NOTE(review): only ``rag_index.faiss`` is named in the original code; the
    ``rag_passages`` directory name is an assumption — confirm it matches the
    artifacts deployed with the app.

    Returns:
        ``(retriever, tokenizer, generator)``.
    """
    retriever = RagRetriever.from_pretrained(
        model_name,
        index_name="custom",
        passages_path=passages_path,
        index_path=index_path,
    )
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # Keep default float32 weights: nothing here moves the model to a GPU, and
    # half precision is generally not supported for CPU inference.
    generator = RagSequenceForGeneration.from_pretrained(model_name)
    return retriever, tokenizer, generator


def _to_query_matrix(vector):
    """Convert the Gradio Dataframe payload into a ``(1, dim)`` float32 matrix.

    A single-column Dataframe yields one value per row (``[[v1], [v2], ...]``);
    flattening accepts that as well as a single row or a flat list.

    Raises:
        ValueError: if the input is empty, non-numeric, or has missing values.
    """
    query = np.asarray(vector, dtype=np.float32).reshape(1, -1)
    if query.size == 0:
        raise ValueError("Input vector is empty.")
    if not np.isfinite(query).all():
        raise ValueError("Input vector contains missing or non-finite values.")
    return query


def predict(vector):
    """Retrieve passages for a precomputed query vector and generate an answer.

    Args:
        vector: Query embedding as delivered by the ``gr.Dataframe`` input
            (nested lists of numbers). Assumed to be in the same embedding
            space and dimension as the custom FAISS index — TODO confirm.

    Returns:
        The decoded text of the top generated sequence.

    Raises:
        ValueError: if ``vector`` is empty or not numeric.
    """
    # Validate before loading any model so bad input fails fast and cheaply.
    query = _to_query_matrix(vector)
    retriever, tokenizer, generator = _load_rag()
    n_docs = retriever.config.n_docs

    # retrieve() is given a float32 (batch, dim) array plus an explicit n_docs
    # and returns (doc_embeddings, doc_ids, doc_dicts).
    doc_embeds, _doc_ids, docs = retriever.retrieve(query, n_docs)

    # Inner-product relevance of each retrieved passage to the query: (1, n_docs).
    doc_scores = np.einsum("bd,bnd->bn", query, doc_embeds.astype(np.float32))

    # The tokenizer cannot encode a raw vector, so generation is conditioned on
    # the retrieved passages (with an empty question string) rather than on
    # tokenized question text.
    context_input_ids, context_attention_mask = retriever.postprocess_docs(
        docs, [""], None, n_docs, return_tensors="pt"
    )
    outputs = generator.generate(
        context_input_ids=context_input_ids,
        context_attention_mask=context_attention_mask,
        doc_scores=torch.from_numpy(doc_scores),
        n_docs=n_docs,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Build a simple UI: a vector input table and an answer box side by side,
# with a submit button that calls predict().
with gr.Blocks() as demo:
    gr.Markdown("## 🛍️ 电商智能客服系统")
    # Show whether the sentence-embedding model was loaded at startup.
    model_status = gr.Markdown(f"模型状态: {'已加载' if model else '未加载'}")
    with gr.Row():
        vector_input = gr.Dataframe(
            headers=["vector"],
            datatype="number",  # numeric cells instead of the default strings
            type="array",
            label="输入向量 (384维)",
        )
        output = gr.Textbox(label="智能回答")
    submit_btn = gr.Button("生成回答")
    submit_btn.click(
        fn=predict,
        inputs=vector_input,
        outputs=output,
    )

# Launch the app only when run as a script, so importing this module
# (e.g. by tooling or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )