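"""PDF question-answering demo.

Loads an NVIDIA Mistral-NeMo-Minitron 8B instruct model with Hugging Face
Transformers, wraps it in a LangChain pipeline, and serves a Gradio interface
that answers questions about an uploaded PDF and suggests a follow-up question.
"""
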
import os
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain
from langchain_core.prompts import PromptTemplate

# Load the Mistral model
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16
print(f"使用設備: {device}")

# Initialize the tokenizer
mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)

# Initialize the model
mistral_model = AutoModelForCausalLM.from_pretrained(
    model_path, 
    torch_dtype=dtype, 
    device_map=device,
    low_cpu_mem_usage=True
)

# Create the text-generation pipeline (the model is already placed on `device`,
# so no device_map is passed here again)
text_generation_pipeline = pipeline(
    "text-generation",
    model=mistral_model,
    tokenizer=mistral_tokenizer,
    max_new_tokens=512,  # cap only the generated tokens; max_length would also count the (long) prompt
    do_sample=True,      # needed for temperature/top_p to take effect
    temperature=0.3,
    top_p=0.95
)

# Wrap the pipeline as a LangChain LLM
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

def initialize(file_path, question):
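    """Answer a question about a PDF and suggest one follow-up question.

    Args:
        file_path: Path to the PDF file on disk.
        question: The user's question about the document.

    Returns:
        A string containing the answer plus a suggested follow-up question,
        or an error message if the document cannot be processed.
    """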
    try:
        prompt_template = """根據提供的上下文盡可能準確地回答問題。如果上下文中沒有包含答案,請說「上下文中沒有提供答案」\n\n
                          上下文: \n {context}?\n
                          問題: \n {question} \n
                          回答:
                        """
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
        
        if os.path.exists(file_path):
            pdf_loader = PyPDFLoader(file_path)
            pages = pdf_loader.load_and_split()
            
            # Limit the context so the prompt stays within the model's token limit
            max_pages = 5  # adjust based on model capacity and document length
            context = "\n".join(str(page.page_content) for page in pages[:max_pages])
            
            try:
                # Build the question-answering chain with Mistral
                stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt)
                
                # Get the answer using only the limited set of pages
                stuff_answer = stuff_chain(
                    {"input_documents": pages[:max_pages], "question": question, "context": context}, 
                    return_only_outputs=True
                )
                
                main_answer = stuff_answer['output_text']
                
                # Generate a follow-up question
                follow_up_prompt = f"根據這個回答: {main_answer}\n生成一個相關的後續問題:"
                follow_up_inputs = mistral_tokenizer.encode(follow_up_prompt, return_tensors='pt').to(device)
                
                with torch.no_grad():
                    follow_up_outputs = mistral_model.generate(
                        follow_up_inputs,
                        max_new_tokens=256,  # cap only the generated tokens (max_length would also count the prompt)
                        temperature=0.7,
                        top_p=0.9,
                        do_sample=True,
                        pad_token_id=mistral_tokenizer.eos_token_id  # avoid the missing-pad-token warning
                    )
                
                follow_up = mistral_tokenizer.decode(follow_up_outputs[0], skip_special_tokens=True)
                
                # Extract just the question (the decoded text still contains the prompt)
                if "後續問題:" in follow_up:
                    follow_up = follow_up.split("後續問題:", 1)[1].strip()
                
                combined_output = f"回答: {main_answer}\n\n可能的後續問題: {follow_up}"
                return combined_output
                
            except Exception as e:
                if "exceeds the maximum token count" in str(e):
                    return "錯誤: 文檔太大無法處理。請嘗試使用較小的文檔。"
                else:
                    raise e
        else:
            return "錯誤: 無法處理文檔。請確保PDF文件存在且有效。"
    except Exception as e:
        return f"發生錯誤: {str(e)}"

# Gradio callback: validate the upload and run the QA pipeline
def pdf_qa(file, question):
    if file is None:
        return "請先上傳PDF文件。"
    return initialize(file.name, question)

# Create the Gradio interface
demo = gr.Interface(
    fn=pdf_qa,
    inputs=[
        gr.File(label="上傳PDF文件", file_types=[".pdf"]),
        gr.Textbox(label="詢問文檔內容", placeholder="這個文檔主要講了什麼?")
    ],
    outputs=gr.Textbox(label="Mistral 回答"),
    title="基於Mistral的PDF問答系統",
    description="上傳PDF文件並提出問題,Mistral模型將分析內容並提供回答和可能的後續問題。"
)

if __name__ == "__main__":
    demo.launch()
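
# Usage sketch (optional): for a quick check without the web UI, initialize()
# can be called directly; the file path below is hypothetical.
#   print(initialize("example.pdf", "這個文檔主要講了什麼?"))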