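"""PDF question answering with the nvidia/Mistral-NeMo-Minitron-8B-instruct model.

The script loads a PDF with LangChain's PyPDFLoader, answers a user question
with a "stuff" question-answering chain over the first few pages, asks the
model for a possible follow-up question, and exposes the whole flow through a
Gradio interface.

Typical dependencies (assumed, not pinned here): torch, transformers,
accelerate, gradio, langchain, langchain-community, pypdf.
"""
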
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain
from langchain_core.prompts import PromptTemplate

# Load the Mistral model
model_path = "nvidia/Mistral-NeMo-Minitron-8B-instruct"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
print(f"Using device: {device}")

# Initialize the tokenizer
mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)

# Initialize the model
mistral_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=dtype,
    device_map=device,
    low_cpu_mem_usage=True
)
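
# Note: an 8B-parameter model in bfloat16 occupies roughly 16 GB of memory, so a
# GPU with sufficient VRAM is assumed; the CPU fallback will be very slow.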

# Create a text-generation pipeline
# (the model is already placed on `device` above, so no device_map is passed here;
# do_sample is enabled so that temperature and top_p actually take effect)
text_generation_pipeline = pipeline(
    "text-generation",
    model=mistral_model,
    tokenizer=mistral_tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.3,
    top_p=0.95
)

# Wrap the pipeline so LangChain can use it as an LLM
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

def initialize(file_path, question):
    try:
        prompt_template = """Answer the question as accurately as possible based on the provided context. If the context does not contain the answer, reply "The answer is not provided in the context."\n\n
        Context: \n {context}\n
        Question: \n {question} \n
        Answer:
        """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        if os.path.exists(file_path):
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            # Limit the context to avoid exceeding the token limit
            max_pages = 5  # Adjust based on model capacity and document length
            context = "\n".join(str(page.page_content) for page in pages[:max_pages])
            try:
                # Build the question-answering chain around the Mistral pipeline
                stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt)
                # Get the answer using the limited set of pages
                stuff_answer = stuff_chain(
                    {"input_documents": pages[:max_pages], "question": question, "context": context},
                    return_only_outputs=True
                )
                main_answer = stuff_answer['output_text']
                # Generate a follow-up question
                follow_up_prompt = f"Based on this answer: {main_answer}\nGenerate a relevant follow-up question:"
                follow_up_inputs = mistral_tokenizer.encode(follow_up_prompt, return_tensors='pt').to(device)
                with torch.no_grad():
                    follow_up_outputs = mistral_model.generate(
                        follow_up_inputs,
                        max_new_tokens=256,
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True
                    )
                follow_up = mistral_tokenizer.decode(follow_up_outputs[0], skip_special_tokens=True)
                # Extract the question: the decoded text still contains the prompt,
                # so keep only what follows the cue
                if "follow-up question:" in follow_up.lower():
                    follow_up = follow_up.split("follow-up question:", 1)[1].strip()
                combined_output = f"Answer: {main_answer}\n\nPossible follow-up question: {follow_up}"
                return combined_output
            except Exception as e:
                if "exceeds the maximum token count" in str(e):
                    return "Error: the document is too large to process. Please try a smaller document."
                else:
                    raise e
        else:
            return "Error: the document could not be processed. Please make sure the PDF file exists and is valid."
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Define the Gradio handler
def pdf_qa(file, question):
    if file is None:
        return "Please upload a PDF file first."
    return initialize(file.name, question)

# Create the Gradio interface
demo = gr.Interface(
    fn=pdf_qa,
    inputs=[
        gr.File(label="Upload a PDF file", file_types=[".pdf"]),
        gr.Textbox(label="Ask about the document", placeholder="What is this document mainly about?")
    ],
    outputs=gr.Textbox(label="Mistral's answer"),
    title="Mistral-based PDF Question Answering",
    description="Upload a PDF file and ask a question. The Mistral model will analyse the content and return an answer along with a possible follow-up question."
)

if __name__ == "__main__":
    demo.launch()
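
# To run locally: `python app.py`; Gradio serves the interface at http://127.0.0.1:7860 by default.
# When this file runs as a Hugging Face Space, it is typically executed as the app's entry point,
# so the same launch() call starts the hosted demo.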