shallou committed on
Commit ba33fd4
·
verified ·
1 Parent(s): 0882176

Create app.py

Files changed (1)
app.py +278 -0
app.py ADDED
@@ -0,0 +1,278 @@
+ """
+ Streamlit application for PDF-based Retrieval-Augmented Generation (RAG) using Ollama + LangChain.
+
+ This application allows users to upload a PDF, process it,
+ and then ask questions about the content using a selected language model.
+ """
+
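+ # Pipeline overview (as implemented below): the uploaded PDF is parsed with
+ # UnstructuredPDFLoader, split into overlapping chunks, embedded with the
+ # "nomic-embed-text" Ollama model, and stored in a Chroma collection; questions
+ # are answered by a MultiQueryRetriever + ChatOllama chain over that collection.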
+ import streamlit as st
+ import logging
+ import os
+ import tempfile
+ import shutil
+ import pdfplumber
+ import ollama
+
+ from langchain_community.document_loaders import UnstructuredPDFLoader
+ from langchain_community.embeddings import OllamaEmbeddings
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
+ from langchain.prompts import ChatPromptTemplate, PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_community.chat_models import ChatOllama
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain.retrievers.multi_query import MultiQueryRetriever
+ from typing import List, Tuple, Dict, Any, Optional
+
+ # Streamlit page configuration
+ st.set_page_config(
+     page_title="Ollama PDF RAG Streamlit UI",
+     page_icon="🎈",
+     layout="wide",
+     initial_sidebar_state="collapsed",
+ )
+
+ # Logging configuration
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @st.cache_resource(show_spinner=True)
+ def extract_model_names(
+     models_info: Dict[str, List[Dict[str, Any]]],
+ ) -> Tuple[str, ...]:
+     """
+     Extract model names from the provided models information.
+
+     Args:
+         models_info (Dict[str, List[Dict[str, Any]]]): Dictionary containing information about available models.
+
+     Returns:
+         Tuple[str, ...]: A tuple of model names.
+     """
+     logger.info("Extracting model names from models_info")
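+     # NOTE: this assumes the dict-shaped response of older ollama Python clients,
+     # i.e. ollama.list() returning {"models": [{"name": ...}, ...]}. Newer client
+     # releases return typed objects whose field is "model" rather than "name", so
+     # this line may need adjusting depending on the installed ollama version.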
+     model_names = tuple(model["name"] for model in models_info["models"])
+     logger.info(f"Extracted model names: {model_names}")
+     return model_names
+
+
+ def create_vector_db(file_upload) -> Chroma:
+     """
+     Create a vector database from an uploaded PDF file.
+
+     Args:
+         file_upload (st.UploadedFile): Streamlit file upload object containing the PDF.
+
+     Returns:
+         Chroma: A vector store containing the processed document chunks.
+     """
+     logger.info(f"Creating vector DB from file upload: {file_upload.name}")
+     temp_dir = tempfile.mkdtemp()
+
+     path = os.path.join(temp_dir, file_upload.name)
+     with open(path, "wb") as f:
+         f.write(file_upload.getvalue())
+         logger.info(f"File saved to temporary path: {path}")
+     loader = UnstructuredPDFLoader(path)
+     data = loader.load()
+
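+     # RecursiveCharacterTextSplitter counts characters, not tokens; the overlap keeps
+     # a little shared context between neighbouring chunks so answers that span a
+     # chunk boundary remain retrievable.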
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=7500, chunk_overlap=100)
+     chunks = text_splitter.split_documents(data)
+     logger.info("Document split into chunks")
+
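+     # No persist_directory is passed to Chroma below, so the collection lives only
+     # in memory for the current session; pass persist_directory=... if the index
+     # should survive restarts.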
+     embeddings = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
+     vector_db = Chroma.from_documents(
+         documents=chunks, embedding=embeddings, collection_name="myRAG"
+     )
+     logger.info("Vector DB created")
+
+     shutil.rmtree(temp_dir)
+     logger.info(f"Temporary directory {temp_dir} removed")
+     return vector_db
+
+
+ def process_question(question: str, vector_db: Chroma, selected_model: str) -> str:
+     """
+     Process a user question using the vector database and selected language model.
+
+     Args:
+         question (str): The user's question.
+         vector_db (Chroma): The vector database containing document embeddings.
+         selected_model (str): The name of the selected language model.
+
+     Returns:
+         str: The generated response to the user's question.
+     """
+     logger.info(f"Processing question: {question} using model: {selected_model}")
+     llm = ChatOllama(model=selected_model, temperature=0)
+     QUERY_PROMPT = PromptTemplate(
+         input_variables=["question"],
+         template="""You are an AI language model assistant. Your task is to generate 3
+         different versions of the given user question to retrieve relevant documents from
+         a vector database. By generating multiple perspectives on the user question, your
+         goal is to help the user overcome some of the limitations of the distance-based
+         similarity search. Provide these alternative questions separated by newlines.
+         Original question: {question}""",
+     )
+
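+     # MultiQueryRetriever asks the LLM (via QUERY_PROMPT) for several rephrasings of
+     # the question, runs each against the vector store, and returns the unique union
+     # of the retrieved documents.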
+     retriever = MultiQueryRetriever.from_llm(
+         vector_db.as_retriever(), llm, prompt=QUERY_PROMPT
+     )
+
+     template = """Answer the question based ONLY on the following context:
+     {context}
+     Question: {question}
+     If you don't know the answer, just say that you don't know, don't try to make up an answer.
+     Only provide the answer from the {context}, nothing else.
+     Add snippets of the context you used to answer the question.
+     """
+
+     prompt = ChatPromptTemplate.from_template(template)
+
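+     # LCEL chain: the incoming question is sent through the retriever to fill
+     # {context} and also passed unchanged (RunnablePassthrough) to fill {question};
+     # the formatted prompt goes to the model and the reply is parsed into a plain string.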
+     chain = (
+         {"context": retriever, "question": RunnablePassthrough()}
+         | prompt
+         | llm
+         | StrOutputParser()
+     )
+
+     response = chain.invoke(question)
+     logger.info("Question processed and response generated")
+     return response
+
+
+ @st.cache_data
+ def extract_all_pages_as_images(file_upload) -> List[Any]:
+     """
+     Extract all pages from a PDF file as images.
+
+     Args:
+         file_upload (st.UploadedFile): Streamlit file upload object containing the PDF.
+
+     Returns:
+         List[Any]: A list of image objects representing each page of the PDF.
+     """
+     logger.info(f"Extracting all pages as images from file: {file_upload.name}")
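+     # pdfplumber renders each page via page.to_image(); .original is the underlying
+     # PIL image, which st.image() can display directly.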
+     pdf_pages = []
+     with pdfplumber.open(file_upload) as pdf:
+         pdf_pages = [page.to_image().original for page in pdf.pages]
+     logger.info("PDF pages extracted as images")
+     return pdf_pages
+
+
+ def delete_vector_db(vector_db: Optional[Chroma]) -> None:
+     """
+     Delete the vector database and clear related session state.
+
+     Args:
+         vector_db (Optional[Chroma]): The vector database to be deleted.
+     """
+     logger.info("Deleting vector DB")
+     if vector_db is not None:
+         vector_db.delete_collection()
+         st.session_state.pop("pdf_pages", None)
+         st.session_state.pop("file_upload", None)
+         st.session_state.pop("vector_db", None)
+         st.success("Collection and temporary files deleted successfully.")
+         logger.info("Vector DB and related session state cleared")
+         st.rerun()
+     else:
+         st.error("No vector database found to delete.")
+         logger.warning("Attempted to delete vector DB, but none was found")
+
+
+ def main() -> None:
+     """
+     Main function to run the Streamlit application.
+
+     This function sets up the user interface, handles file uploads,
+     processes user queries, and displays results.
+     """
+     st.subheader("🧠 Ollama PDF RAG playground", divider="gray", anchor=False)
+
+     models_info = ollama.list()
+     available_models = extract_model_names(models_info)
+
+     col1, col2 = st.columns([1.5, 2])
+
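+     # Streamlit reruns this script top to bottom on every interaction, so the chat
+     # history and the vector DB are kept in st.session_state to survive reruns.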
+     if "messages" not in st.session_state:
+         st.session_state["messages"] = []
+
+     if "vector_db" not in st.session_state:
+         st.session_state["vector_db"] = None
+
+     if available_models:
+         selected_model = col2.selectbox(
+             "Pick a model available locally on your system ↓", available_models
+         )
+
+     file_upload = col1.file_uploader(
+         "Upload a PDF file ↓", type="pdf", accept_multiple_files=False
+     )
+
+     if file_upload:
+         st.session_state["file_upload"] = file_upload
+         if st.session_state["vector_db"] is None:
+             st.session_state["vector_db"] = create_vector_db(file_upload)
+         pdf_pages = extract_all_pages_as_images(file_upload)
+         st.session_state["pdf_pages"] = pdf_pages
+
+         zoom_level = col1.slider(
+             "Zoom Level", min_value=100, max_value=1000, value=700, step=50
+         )
+
+         with col1:
+             with st.container(height=410, border=True):
+                 for page_image in pdf_pages:
+                     st.image(page_image, width=zoom_level)
+
+     delete_collection = col1.button("⚠️ Delete collection", type="secondary")
+
+     if delete_collection:
+         delete_vector_db(st.session_state["vector_db"])
+
+     with col2:
+         message_container = st.container(height=500, border=True)
+
+         for message in st.session_state["messages"]:
+             avatar = "🤖" if message["role"] == "assistant" else "😎"
+             with message_container.chat_message(message["role"], avatar=avatar):
+                 st.markdown(message["content"])
+
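+         # st.chat_input returns None until the user submits a message, so the walrus
+         # assignment below only runs the chat-handling block when there is new input.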
+         if prompt := st.chat_input("Enter a prompt here..."):
+             try:
+                 st.session_state["messages"].append({"role": "user", "content": prompt})
+                 message_container.chat_message("user", avatar="😎").markdown(prompt)
+
+                 with message_container.chat_message("assistant", avatar="🤖"):
+                     with st.spinner(":green[processing...]"):
+                         if st.session_state["vector_db"] is not None:
+                             response = process_question(
+                                 prompt, st.session_state["vector_db"], selected_model
+                             )
+                             st.markdown(response)
+                         else:
+                             st.warning("Please upload a PDF file first.")
+
+                 if st.session_state["vector_db"] is not None:
+                     st.session_state["messages"].append(
+                         {"role": "assistant", "content": response}
+                     )
+
+             except Exception as e:
+                 st.error(e, icon="⛔️")
+                 logger.error(f"Error processing prompt: {e}")
+         else:
+             if st.session_state["vector_db"] is None:
+                 st.warning("Upload a PDF file to begin chat...")
+
+
+ if __name__ == "__main__":
+     main()