File size: 4,534 Bytes
fb6f6d9
630d3f4
fdfcf53
4020981
a06a315
cb8213b
630d3f4
7c119cb
908d31d
cb8213b
86b4310
cb8213b
86b4310
cb8213b
 
 
 
 
908d31d
fb6f6d9
908d31d
 
 
 
fb6f6d9
 
 
7c119cb
908d31d
b2369fc
908d31d
a06a315
908d31d
 
630d3f4
a06a315
 
 
5b5abf5
908d31d
fb6f6d9
908d31d
 
 
 
 
5b5abf5
908d31d
 
309abbd
908d31d
 
4bf350f
630d3f4
 
908d31d
 
 
 
630d3f4
908d31d
fb6f6d9
 
908d31d
 
 
 
7c119cb
908d31d
 
 
 
 
 
fb6f6d9
86b4310
 
908d31d
 
86b4310
908d31d
 
 
 
86b4310
908d31d
86b4310
908d31d
 
fb6f6d9
908d31d
 
 
 
 
 
 
 
 
cb8213b
908d31d
 
86b4310
908d31d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
import random
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset
from transformers import pipeline

# Load the Attack on Titan (進擊的巨人) dataset.
@st.cache_resource(show_spinner=False)
def _load_corpus():
    """Load the translator, the dataset and the translated answers once.

    Streamlit re-executes this script on every widget interaction; caching
    keeps the pipeline construction, the dataset load and the per-row
    translation from being repeated on each rerun.

    Returns:
        A ``(converter, dataset, answer_list)`` tuple, where ``answer_list``
        holds the converter output for the "Answer" field of every row in
        the "train" split.
    """
    # NOTE(review): transformers documents translation tasks in the form
    # "translation_XX_to_YY"; confirm this task string resolves as intended.
    converter = pipeline("translation_zh_tw_zh_cn")
    dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
    answer_list = [
        converter(example["Answer"])[0]["translation_text"]
        for example in dataset["train"]
    ]
    return converter, dataset, answer_list


try:
    converter, dataset, answer_list = _load_corpus()
except Exception as e:
    # Top-level boundary: report the failure in the UI and halt this run.
    st.error(f"读取数据集失败:{e}")
    st.stop()

# Build the vector database (only once per server process).
@st.cache_resource(show_spinner=False)
def _build_vector_store(_texts):
    """Create the embedding model and a FAISS index over ``_texts``.

    The leading underscore on ``_texts`` tells Streamlit not to hash the
    argument, so the cached result is reused across script reruns instead of
    re-embedding the whole corpus on every interaction.

    Returns:
        An ``(embeddings, db)`` tuple: the embedding model and the FAISS
        store built from the given texts.
    """
    embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    return embeddings, FAISS.from_texts(_texts, embeddings)


try:
    with st.spinner("正在读取数据库..."):
        embeddings, db = _build_vector_store(answer_list)
        st.success("数据库读取完成!")
except Exception as e:
    # Top-level boundary: report the failure in the UI and halt this run.
    st.error(f"向量数据库构建失败:{e}")
    st.stop()

# Question-answering function
def answer_question(repo_id: str, temperature: float, max_length: int, question: str) -> dict:
    """Answer ``question`` using retrieved context and a hosted LLM.

    Creates a ``HuggingFaceHub`` LLM for ``repo_id``, retrieves the entries
    most similar to the question from the module-level FAISS store ``db``,
    builds a prompt from them and invokes the model.

    Args:
        repo_id: Hugging Face Hub repository id of the model to query.
        temperature: Passed to the model as ``model_kwargs["temperature"]``.
        max_length: Passed to the model as ``model_kwargs["max_length"]``.
        question: The user's question.

    Returns:
        A dict with keys ``"prompt"`` (the prompt sent to the model) and
        ``"answer"`` (the model output with the prompt text removed). If
        retrieval or generation fails, ``"prompt"`` is empty and ``"answer"``
        is a fixed error message. If the model cannot be initialised the
        Streamlit script run is stopped instead of returning.
    """
    # Initialise the Gemma model.
    try:
        with st.spinner("正在初始化 Gemma 模型..."):
            llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
            st.success("Gemma 模型初始化完成!")
    except Exception as e:
        st.error(f"Gemma 模型加载失败:{e}")
        st.stop()

    # Retrieve context and generate the answer.
    try:
        with st.spinner("正在筛选本地数据集..."):
            # Search with the question text itself. The store embeds the
            # query internally; passing a stringified embedding vector here
            # would make it embed a string of numbers and return matches
            # unrelated to the question.
            docs_and_scores = db.similarity_search_with_score(question)

            context = "\n".join(doc.page_content for doc, _ in docs_and_scores)
            print('context: ' + context)

            prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
            print('prompt: ' + prompt)

            st.success("本地数据集筛选完成!")

        with st.spinner("正在生成答案..."):
            answer = llm.invoke(prompt)
            # Remove the prompt text from the output in case the model echoes it.
            answer = answer.replace(prompt, "").strip()
            st.success("答案已经生成!")
        return {"prompt": prompt, "answer": answer}
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return {"prompt": "", "answer": "An error occurred during the answering process."}

# Streamlit UI
st.title("進擊的巨人 知识库问答系统")

# Model selection and generation parameters.
col1, col2 = st.columns(2)
with col1:
    # The third positional argument is the index of the default option.
    gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
with col2:
    temperature = st.number_input("temperature", value=1.0)
    max_length = st.number_input("max_length", value=1024)

st.divider()

col3, col4 = st.columns(2)
with col3:
    # Left column: answer a random question taken from the dataset and show
    # the dataset's own answer next to the generated one.
    if st.button("使用原数据集中的随机问题"):
        dataset_size = len(dataset["train"])
        random_index = random.randrange(dataset_size)
        # Fetch the row once and convert both of its fields.
        sample = dataset["train"][random_index]
        random_question = converter(sample["Question"])[0]["translation_text"]
        origin_answer = converter(sample["Answer"])[0]["translation_text"]
        print(f'[{random_index}/{dataset_size}]random_question: {random_question}')
        print('origin_answer: ' + origin_answer)

        st.write("随机问题:")
        st.write(random_question)
        st.write("原始答案:")
        st.write(origin_answer)
        result = answer_question(gemma, float(temperature), int(max_length), random_question)
        print('prompt: ' + result["prompt"])
        print('answer: ' + result["answer"])
        st.write("生成答案:")
        st.write(result["answer"])

with col4:
    # Right column: answer a question typed by the user.
    question = st.text_area("请输入问题", "Gemma 有哪些特点?")
    if st.button("提交输入的问题"):
        # Treat whitespace-only input as empty so no blank prompt is sent.
        cleaned_question = question.strip()
        if not cleaned_question:
            st.warning("请输入问题!")
        else:
            result = answer_question(gemma, float(temperature), int(max_length), cleaned_question)
            print('prompt: ' + result["prompt"])
            print('answer: ' + result["answer"])
            st.write("生成答案:")
            st.write(result["answer"])