import os

import gradio as gr
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# Sentence-embedding model that produces the 384-dimensional vectors the UI expects.
model_name = "all-MiniLM-L6-v2"

# HF_TOKEN is optional for this public checkpoint; if it is not set, the
# embedding model is left unloaded and the UI reports it as unavailable.
token = os.getenv("HF_TOKEN")
model = SentenceTransformer(model_name, use_auth_token=token) if token else None

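
# A minimal helper sketch for producing the 384-dimensional vector the
# Dataframe input expects. `embed_query` is a hypothetical name, not part of
# the original app; it simply wraps SentenceTransformer.encode, which returns
# a 384-dim float32 vector for all-MiniLM-L6-v2.
def embed_query(text: str) -> np.ndarray:
    if model is None:
        raise RuntimeError("Embedding model not loaded (set HF_TOKEN).")
    # encode() returns a 1-D numpy array of shape (384,) for this model.
    return model.encode(text)
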
def predict(vector):
    # Build the retriever over the custom FAISS index. With index_name="custom",
    # RagRetriever expects a single index_path plus a passages_path pointing at
    # the saved passages dataset (paths assumed here; see build_rag_index below).
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        index_path="rag_index.faiss",
        passages_path="rag_passages",  # assumed location of the passages dataset
    )

    # Optional manual retrieval: retrieve() takes a float32 batch of question
    # embeddings (whose dimensionality must match the index) plus an explicit
    # n_docs. Generation below retrieves internally via the attached retriever.
    question_embeddings = np.asarray(vector, dtype=np.float32).reshape(1, -1)
    docs = retriever.retrieve(question_embeddings, n_docs=5)

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    # Local name rag_model avoids shadowing the module-level embedding model.
    rag_model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-nq",
        retriever=retriever,
        torch_dtype=torch.float16,
    )

    # The Dataframe component passes rows as a list of lists; join them into a
    # single query string before tokenizing, mirroring the original flow.
    query = " ".join(str(cell) for row in vector for cell in row)
    inputs = tokenizer([query], return_tensors="pt")
    outputs = rag_model.generate(input_ids=inputs["input_ids"])
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
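

# Offline sketch of how the files the retriever expects could be produced.
# This is an assumption, modeled on the standard Hugging Face custom-dataset
# RAG recipe; the default paths match the (assumed) ones used in predict().
def build_rag_index(titles, texts,
                    passages_path="rag_passages",
                    index_path="rag_index.faiss"):
    """Embed passages with the DPR context encoder, save the passages dataset,
    then build and save a FAISS index over the 768-dim embeddings."""
    from datasets import Dataset
    from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base")

    # DPR encodes title/text pairs into 768-dim passage embeddings.
    inputs = ctx_tokenizer(titles, texts, truncation=True, padding=True,
                           return_tensors="pt")
    with torch.no_grad():
        embeddings = ctx_encoder(**inputs).pooler_output.numpy()

    passages = Dataset.from_dict(
        {"title": titles, "text": texts, "embeddings": embeddings.tolist()})
    passages.save_to_disk(passages_path)
    passages.add_faiss_index(column="embeddings")
    passages.save_faiss_index("embeddings", index_path)
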

with gr.Blocks() as demo:
    gr.Markdown("## 🛍️ E-commerce Smart Customer Service")

    model_status = gr.Markdown(
        f"Model status: {'loaded' if model else 'not loaded'}")

    with gr.Row():
        vector_input = gr.Dataframe(
            headers=["vector"],
            type="array",
            label="Input vector (384 dimensions)",
        )
        output = gr.Textbox(label="Generated answer")

    submit_btn = gr.Button("Generate answer")
    submit_btn.click(
        fn=predict,
        inputs=vector_input,
        outputs=output,
    )


demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
)
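
# Quick smoke test outside the UI (assumption: the Dataframe passes rows as a
# list of lists, so mimic that shape when calling predict() directly), e.g.:
#   print(predict([[0.01, -0.02, 0.03]]))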