# GOGO_rag / app.py
import gradio as gr
import numpy as np
import os
import torch
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Use a lighter-weight embedding model; pass the HF access token if one is set
model_name = "all-MiniLM-L6-v2"
token = os.getenv("HF_TOKEN")  # read the access token from the environment
model = SentenceTransformer(model_name, use_auth_token=token) if token else None
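# NOTE: this embedding model is only used for the status display in the UI
# below; it is not loaded at all when HF_TOKEN is unset, and predict() never
# calls it.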
def predict(vector):
    # Load the retriever backed by the locally built FAISS index
    # (a custom index normally also needs passages_path pointing to the saved
    # passages dataset; only the index file is referenced here)
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        index_path="rag_index.faiss"
    )
    # Load the tokenizer and generator; attaching the retriever lets
    # generate() fetch the relevant documents internally
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    generator = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq",
        retriever=retriever,
        torch_dtype=torch.float16
    )
    # The RAG tokenizer expects text, so the raw Dataframe input is flattened
    # into a query string (assumption: the cells describe the question)
    query = " ".join(str(cell) for row in vector for cell in row)
    inputs = tokenizer([query], return_tensors="pt")
    # Generate and decode the answer
    outputs = generator.generate(input_ids=inputs["input_ids"])
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
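# NOTE: predict() reloads the retriever, tokenizer, and generator on every
# call; a deployed Space would typically load them once at startup instead.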
# Build a simple Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🛍️ E-commerce Intelligent Customer Service")
    # Show whether the embedding model was loaded
    model_status = gr.Markdown(f"Model status: {'loaded' if model else 'not loaded'}")
    with gr.Row():
        vector_input = gr.Dataframe(
            headers=["vector"],
            type="array",
            label="Input vector (384-dim)"
        )
        output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Generate answer")
    submit_btn.click(
        fn=predict,
        inputs=vector_input,
        outputs=output
    )
# Launch the app
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)
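# ------------------------------------------------------------------
# Sketch (commented out, not executed): one way the "rag_index.faiss" file
# referenced in predict() could be built offline with sentence-transformers
# and faiss. The passage list and file name are illustrative assumptions,
# not part of this app.
#
#   import faiss
#   from sentence_transformers import SentenceTransformer
#
#   passages = ["passage one ...", "passage two ..."]   # your knowledge base
#   encoder = SentenceTransformer("all-MiniLM-L6-v2")
#   embeddings = encoder.encode(passages, convert_to_numpy=True)
#   index = faiss.IndexFlatIP(embeddings.shape[1])      # 384-dim, inner product
#   index.add(embeddings)
#   faiss.write_index(index, "rag_index.faiss")
# ------------------------------------------------------------------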