CCCDev committed
Commit 555dc75 · verified · 1 Parent(s): 0a07848

Create app.py

Files changed (1): app.py (+194, -0)
app.py ADDED
@@ -0,0 +1,194 @@
import gradio as gr

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory

from pathlib import Path
import chromadb
from unidecode import unidecode

from transformers import pipeline
import re

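# Runtime dependencies implied by the imports (unpinned; exact versions are up
# to the deployer): gradio, langchain, langchain-community, chromadb, pypdf,
# unidecode, transformers, sentence-transformers, torch
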
# Adjustments for the models used below.
# NOTE: all-MiniLM-L6-v2 is an embedding model and cannot generate text, so it
# is used only for the vector store. LLM_MODEL is an illustrative generative
# checkpoint (an assumption, not fixed by this app); any small seq2seq or
# causal LM can be substituted.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "google/flan-t5-base"
DB_CHUNK_SIZE = 512
CHUNK_OVERLAP = 24
TEMPERATURE = 0.1
MAX_TOKENS = 512
TOP_K = 20
pdf_url = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Privacy-Policy%20(1).pdf"  # Replace with your static PDF URL or path

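# PyPDFLoader accepts a URL or a local path, so a file bundled with the Space
# also works (illustrative): pdf_url = "Privacy-Policy (1).pdf"
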
# Load PDF document and create doc splits
def load_doc(pdf_url, chunk_size, chunk_overlap):
    loader = PyPDFLoader(pdf_url)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

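# Sizing note (rule of thumb, not from this commit): chunk_size=512 with a
# 24-character overlap keeps chunks small enough for the MiniLM embedder,
# while the overlap preserves sentences that straddle chunk boundaries.
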
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

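# Note: chromadb.EphemeralClient() keeps the index in memory only, which suits
# a stateless Space. To persist embeddings across restarts, a persistent
# client could be used instead (illustrative path):
#   new_client = chromadb.PersistentClient(path="./chroma_db")
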
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF pipeline...")

    # Run the model locally through a transformers pipeline instead of a
    # HuggingFaceEndpoint. flan-t5 is a seq2seq model, hence the
    # "text2text-generation" task (use "text-generation" for causal LMs).
    hf_pipeline = pipeline(
        "text2text-generation",
        model=llm_model,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain

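# Alternative (sketch, not in this commit): with an HF API token configured,
# the local pipeline above could be swapped for a hosted endpoint, e.g.
#   llm = HuggingFaceEndpoint(repo_id=llm_model, temperature=temperature,
#                             max_new_tokens=max_tokens, top_k=top_k)
# (this needs HuggingFaceEndpoint re-imported from langchain_community.llms).
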
# Generate collection name for vector database
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

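# Worked example: for the default pdf_url above, Path(...).stem is
# "Privacy-Policy%20(1)", which the regex collapses to "Privacy-Policy-20-1-";
# the trailing "-" then becomes "Z", giving "Privacy-Policy-20-1Z". This keeps
# the name within Chroma's rule that collection names be 3-63 characters long
# and start and end with an alphanumeric character.
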
# Initialize database
def initialize_database(pdf_url, chunk_size, chunk_overlap, progress=gr.Progress()):
    collection_name = create_collection_name(pdf_url)
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(pdf_url, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"

def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    qa_chain = initialize_llmchain(LLM_MODEL, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # The default retriever returns the top-4 chunks; surface the first three
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    # Page numbers in the metadata are 0-indexed; convert for display
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            response_source1, response_source1_page,
            response_source2, response_source2_page,
            response_source3, response_source3_page)

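# The tuple returned by conversation() must line up one-to-one with the
# `outputs` lists wired to msg.submit and submit_btn.click inside demo().
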
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown(
            """<center><h2>PDF-based chatbot</h2></center>
            <h3>Ask any questions about your PDF documents</h3>""")
        gr.Markdown(
            """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents.
            The user interface explicitly surfaces the initialization status to help understand the RAG workflow.
            This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity.<br>
            <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face, so model inference can take some time to generate a reply.
            """)

        with gr.Tab("Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=False):
                with gr.Row():
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    source1_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    source2_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                    source3_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

        # Automatic preprocessing
        db_progress = gr.Textbox(label="Vector database initialization", value="Initializing...")
        db_btn = gr.Button("Generate vector database", visible=False)
        qachain_btn = gr.Button("Initialize Question Answering chain", visible=False)
        llm_progress = gr.Textbox(value="None", label="QA chain initialization")

        def auto_initialize():
            vector_db, collection_name, db_status = initialize_database(pdf_url, DB_CHUNK_SIZE, CHUNK_OVERLAP)
            qa_chain, llm_status = initialize_LLM(TEMPERATURE, MAX_TOKENS, TOP_K, vector_db)
            # Return exactly one value per output component wired below
            return vector_db, collection_name, db_status, qa_chain, llm_status

        demo.load(auto_initialize, [], [vector_db, collection_name, db_progress, qa_chain, llm_progress])

        # Chatbot events
        msg.submit(conversation,
                   inputs=[qa_chain, msg, chatbot],
                   outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
                            doc_source2, source2_page, doc_source3, source3_page],
                   queue=False)
        submit_btn.click(conversation,
                         inputs=[qa_chain, msg, chatbot],
                         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
                                  doc_source2, source2_page, doc_source3, source3_page],
                         queue=False)
    return demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()