import streamlit as st import random from langchain_community.llms import HuggingFaceHub from langchain_community.embeddings import SentenceTransformerEmbeddings from langchain_community.vectorstores import FAISS from datasets import load_dataset from opencc import OpenCC # 使用 進擊的巨人 数据集 # 原数据集是是繁体中文,为了调试方便,将其转换成简体中文之后使用 if "dataset_loaded" not in st.session_state: st.session_state.dataset_loaded = False st.session_state.data_list = [] st.session_state.answer_list = [] if not st.session_state.dataset_loaded: try: with st.spinner("正在读取数据库..."): converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换 dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese") for example in dataset["train"]: converted_answer = converter.convert(example["Answer"]) converted_question = converter.convert(example["Question"]) st.session_state.answer_list.append(converted_answer) st.session_state.data_list.append({"Question": converted_question, "Answer": converted_answer}) st.success("数据库读取完成!") print("数据库读取完成!") except Exception as e: st.error(f"读取数据集失败:{e}") st.stop() st.session_state.dataset_loaded = True # 构建向量数据库 (如果需要,仅构建一次) if "vector_created" not in st.session_state: st.session_state.vector_created = False if not st.session_state.vector_created: try: with st.spinner("正在构建向量数据库..."): # all-mpnet-base-v2 是一个由 Sentence Transformers 库提供的预训练模型, # 专门用于生成高质量的句子嵌入(sentence embeddings)。 # all-mpnet-base-v2 在多个自然语言处理任务上表现出色,包括语义相似度计算、 # 文本检索、聚类等。它能够有效地捕捉句子的语义信息,并生成具有代表性的向量表示。 st.session_state.embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2") st.session_state.db = FAISS.from_texts(st.session_state.answer_list, st.session_state.embeddings) st.success("向量数据库构建完成!") print("向量数据库构建完成!") except Exception as e: st.error(f"向量数据库构建失败:{e}") st.stop() st.session_state.vector_created = True # 问答函数 if "repo_id" not in st.session_state: st.session_state.repo_id = '' if "temperature" not in st.session_state: st.session_state.temperature = '' if "max_length" not in st.session_state: st.session_state.max_length = '' def answer_question(repo_id, temperature, max_length, question): # 初始化 Gemma 模型 if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length: try: with st.spinner("正在初始化 Gemma 模型..."): llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length}) st.success("Gemma 模型初始化完成!") print("Gemma 模型初始化完成!") st.session_state.repo_id = repo_id st.session_state.temperature = temperature st.session_state.max_length = max_length except Exception as e: st.error(f"Gemma 模型加载失败:{e}") st.stop() # 获取答案 try: with st.spinner("正在筛选本地数据集..."): question_embedding = st.session_state.embeddings.embed_query(question) question_embedding_str = " ".join(map(str, question_embedding)) # print('question_embedding: ' + question_embedding_str) docs_and_scores = st.session_state.db.similarity_search_with_score(question_embedding_str) context = "\n".join([doc.page_content for doc, _ in docs_and_scores]) print('context: ' + context) prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}" print('prompt: ' + prompt) st.success("本地数据集筛选完成!") print("本地数据集筛选完成!") with st.spinner("正在生成答案..."): answer = llm.invoke(prompt) # 去掉 prompt 的内容 answer = answer.replace(prompt, "").strip() st.success("答案已经生成!") print("答案已经生成!") return {"prompt": prompt, "answer": answer} except Exception as e: st.error(f"问答过程出错:{e}") return {"prompt": "", "answer": "An error occurred during the answering process."} # Streamlit 界面 st.title("進擊的巨人 知识库问答系统") col1, col2 = st.columns(2) with col1: gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2) with col2: temperature = st.number_input("temperature", value=1.0) max_length = st.number_input("max_length", value=1024) st.divider() col3, col4 = st.columns(2) with col3: if st.button("使用原数据集中的随机问题"): dataset_size = len(st.session_state.data_list) random_index = random.randint(0, dataset_size - 1) # 读取随机问题 random_question = st.session_state.data_list[random_index]["Question"] random_question = converter.convert(random_question) origin_answer = st.session_state.data_list[random_index]["Answer"] origin_answer = converter.convert(origin_answer) print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question) print('origin_answer: ' + origin_answer) st.write("随机问题:") st.write(random_question) st.write("原始答案:") st.write(origin_answer) result = answer_question(gemma, float(temperature), int(max_length), random_question) print('prompt: ' + result["prompt"]) print('answer: ' + result["answer"]) st.write("生成答案:") st.write(result["answer"]) with col4: question = st.text_area("请输入问题", "Gemma 有哪些特点?") if st.button("提交输入的问题"): if not question: st.warning("请输入问题!") else: result = answer_question(gemma, float(temperature), int(max_length), question) print('prompt: ' + result["prompt"]) print('answer: ' + result["answer"]) st.write("生成答案:") st.write(result["answer"])