Prathamesh1420 committed
Commit 61e6b08 · verified · 1 Parent(s): 9d2b33c

Update app.py

Files changed (1): app.py +113 -197
app.py CHANGED
@@ -1,217 +1,133 @@
  import streamlit as st
- from langchain.chains import RetrievalQA
- from langchain.vectorstores import Milvus
- from langchain.embeddings import HuggingFaceEmbeddings
- from transformers import AutoTokenizer
- from langchain_groq import ChatGroq
  import os
- from docling.document_converter import DocumentConverter, PdfFormatOption
- from docling.datamodel.base_models import InputFormat
- from docling.datamodel.pipeline_options import PdfPipelineOptions
- from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
- from docling_core.types.doc.document import TableItem
- from langchain_core.documents import Document
- import itertools
- from docling_core.types.doc.labels import DocItemLabel
- import google.generativeai as genai
  from PIL import Image
- import base64
- import io

- # Initialize components (similar to your notebook)
- @st.cache_resource
- def initialize_components():
-     # Initialize embeddings
-     embeddings_model_path = "ibm-granite/granite-embedding-30m-english"
-     embeddings_model = HuggingFaceEmbeddings(model_name=embeddings_model_path)
-     embeddings_tokenizer = AutoTokenizer.from_pretrained(embeddings_model_path)
-
-     # Initialize language model
-     GROQ_API_KEY = "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"
-     model = ChatGroq(model_name="llama3-70b-8192", api_key=GROQ_API_KEY)

-     # Initialize vision model
-     GOOGLE_API_KEY = "AIzaSyBTt66oOvxpLeYn41sR-KkjSYPK2vOAqkU"
-     genai.configure(api_key=GOOGLE_API_KEY)
-     vision_model = genai.GenerativeModel(model_name="gemini-1.5-flash")

-     return embeddings_model, embeddings_tokenizer, model, vision_model
-
- def process_pdf(file_path, embeddings_tokenizer, vision_model):
-     # PDF processing (similar to your notebook)
-     pdf_pipeline_options = PdfPipelineOptions(
-         do_ocr=True,
-         generate_picture_images=True
      )

-     format_options = {
-         InputFormat.PDF: PdfFormatOption(pipeline_options=pdf_pipeline_options),
-     }
-
-     converter = DocumentConverter(format_options=format_options)
-     sources = [file_path]
-     conversions = {
-         source: converter.convert(source=source).document for source in sources
-     }

-     # Process text chunks
-     doc_id = 0
-     texts = []

-     for source, docling_document in conversions.items():
-         chunker = HybridChunker(tokenizer=embeddings_tokenizer)
-
-         for chunk in chunker.chunk(docling_document):
-             items = chunk.meta.doc_items
-
-             if len(items) == 1 and isinstance(items[0], TableItem):
-                 continue

-             refs = "".join(item.get_ref().cref for item in items)
-             text = chunk.text
-
-             document = Document(
-                 page_content=text,
-                 metadata={
-                     "doc_id": (doc_id := doc_id + 1),
-                     "source": source,
-                     "ref": refs,
-                 }
-             )
-             texts.append(document)
-
-     # Process tables (if any)
-     tables = []
-     for source, docling_document in conversions.items():
-         for table in docling_document.tables:
-             if table.label == DocItemLabel.TABLE:
-                 ref = table.get_ref().cref
-                 text = table.export_to_markdown()

-                 document = Document(
-                     page_content=text,
-                     metadata={
-                         "doc_id": (doc_id := doc_id + 1),
-                         "source": source,
-                         "ref": ref,
-                     },
-                 )
-                 tables.append(document)
-
-     # Process images (if any)
-     pictures = []
-     start_doc_id = len(texts) + len(tables) + 1
-
-     for source, docling_document in conversions.items():
-         if hasattr(docling_document, 'pictures') and docling_document.pictures:
-             for picture in docling_document.pictures:
-                 try:
-                     ref = picture.get_ref().cref
-                     image = picture.get_image(docling_document)
-
-                     if image:
-                         response = vision_model.generate_content([
-                             "Extract all text and describe key visual elements in this image. "
-                             "Include any numbers, labels, or important details.",
-                             image
-                         ])
-
-                         document = Document(
-                             page_content=response.text,
-                             metadata={
-                                 "doc_id": doc_id,
-                                 "source": source,
-                                 "ref": ref,
-                             }
-                         )
-                         pictures.append(document)
-                         doc_id += 1
-                 except Exception as e:
-                     print(f"Error processing image: {str(e)}")
-
-     return texts + tables + pictures

- def create_vector_store(docs, embeddings_model):
-     # Create vector store (using Milvus as in your notebook)
-     # Note: You'll need to have Milvus running
-     vector_store = Milvus.from_documents(
-         docs,
-         embeddings_model,
-         connection_args={"host": "127.0.0.1", "port": "19530"},
-         collection_name="pdf_manual"
-     )
-     return vector_store

- def main():
-     st.title("PDF Manual Chatbot")
-
-     # Initialize components
-     embeddings_model, embeddings_tokenizer, model, vision_model = initialize_components()
-
-     # File upload
-     uploaded_file = st.file_uploader("Upload a PDF manual", type="pdf")
-
-     if uploaded_file is not None:
-         # Save the uploaded file
-         file_path = os.path.join("temp", uploaded_file.name)
-         os.makedirs("temp", exist_ok=True)
-         with open(file_path, "wb") as f:
-             f.write(uploaded_file.getbuffer())
-
-         # Process the PDF
-         with st.spinner("Processing PDF..."):
-             docs = process_pdf(file_path, embeddings_tokenizer, vision_model)
-             vector_store = create_vector_store(docs, embeddings_model)
-
-         st.success("PDF processed successfully!")
-
-         # Initialize chat history
-         if "messages" not in st.session_state:
-             st.session_state.messages = []
-
-         # Display chat messages from history on app rerun
-         for message in st.session_state.messages:
-             with st.chat_message(message["role"]):
-                 st.markdown(message["content"])
-
-         # Accept user input
-         if prompt := st.chat_input("Ask a question about the manual"):
-             # Add user message to chat history
-             st.session_state.messages.append({"role": "user", "content": prompt})
-
-             # Display user message in chat message container
-             with st.chat_message("user"):
-                 st.markdown(prompt)

-             # Create QA chain
-             qa_chain = RetrievalQA.from_chain_type(
-                 llm=model,
-                 chain_type="stuff",
-                 retriever=vector_store.as_retriever(),
-                 return_source_documents=True
              )

-             # Get response
-             with st.spinner("Thinking..."):
-                 result = qa_chain({"query": prompt})
-                 response = result["result"]
-                 source_docs = result["source_documents"]

-             # Display assistant response in chat message container
-             with st.chat_message("assistant"):
-                 st.markdown(response)
-
-                 # Show sources if available
-                 if source_docs:
-                     with st.expander("Source Documents"):
-                         for i, doc in enumerate(source_docs):
-                             st.write(f"Source {i+1}:")
-                             st.write(doc.page_content)
-                             st.write(f"Metadata: {doc.metadata}")
-                             st.write("---")

-             # Add assistant response to chat history
-             st.session_state.messages.append({"role": "assistant", "content": response})

- if __name__ == "__main__":
-     main()
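
Note: the new app.py imports process_pdf, load_models, and query_pipeline from a utils package that is not part of this commit. As a reading aid, here is a minimal sketch of utils/models.py, assuming load_models simply repackages the initialize_components() logic removed above; the function name, keyword arguments, and return order are taken from the call site in the new app.py, while the body is an assumption:

import google.generativeai as genai
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from transformers import AutoTokenizer

def load_models(embedding_model, llm_model, google_api_key, groq_api_key):
    # Embedding model plus its tokenizer (the tokenizer drives chunk sizing later)
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    tokenizer = AutoTokenizer.from_pretrained(embedding_model)

    # Vision model used to describe images extracted from the PDF
    # (gemini-1.5-flash is assumed, as in the removed code)
    genai.configure(api_key=google_api_key)
    vision = genai.GenerativeModel(model_name="gemini-1.5-flash")

    # Chat LLM served by Groq
    llm = ChatGroq(model_name=llm_model, api_key=groq_api_key)

    # Order must match the unpacking in the new app.py:
    # embeddings_model, embeddings_tokenizer, vision_model, llm_model = load_models(...)
    return embeddings, tokenizer, vision, llm

One visible effect of the refactor is that the API keys arrive as sidebar inputs and function parameters instead of hardcoded constants; the sketch takes them as parameters for the same reason.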
  import streamlit as st
  import os
  from PIL import Image
+ import google.generativeai as genai
+ from utils.document_processing import process_pdf
+ from utils.models import load_models
+ from utils.rag import query_pipeline

+ # Configure the app
+ st.set_page_config(
+     page_title="PDF RAG Pipeline",
+     page_icon="📄",
+     layout="wide"
+ )
+
+ # Initialize session state
+ if 'models_loaded' not in st.session_state:
+     st.session_state.models_loaded = False
+ if 'processed_docs' not in st.session_state:
+     st.session_state.processed_docs = None
+
+ # Sidebar for configuration
+ with st.sidebar:
+     st.title("Configuration")

+     # API keys
+     groq_api_key = st.text_input("Groq API Key", type="password")
+     google_api_key = st.text_input("Google API Key", type="password")

+     # Model selection
+     embedding_model = st.selectbox(
+         "Embedding Model",
+         ["ibm-granite/granite-embedding-30m-english"],
+         index=0
      )

+     llm_model = st.selectbox(
+         "LLM Model",
+         ["llama3-70b-8192"],
+         index=0
+     )

+     # File upload
+     uploaded_file = st.file_uploader(
+         "Upload a PDF file",
+         type=["pdf"],
+         accept_multiple_files=False
+     )

+     if st.button("Initialize Models"):
+         with st.spinner("Loading models..."):
+             try:
+                 # Load models
+                 embeddings_model, embeddings_tokenizer, vision_model, llm_model = load_models(
+                     embedding_model=embedding_model,
+                     llm_model=llm_model,
+                     google_api_key=google_api_key,
+                     groq_api_key=groq_api_key
+                 )

+                 st.session_state.embeddings_model = embeddings_model
+                 st.session_state.embeddings_tokenizer = embeddings_tokenizer
+                 st.session_state.vision_model = vision_model
+                 st.session_state.llm_model = llm_model
+                 st.session_state.models_loaded = True

+                 st.success("Models loaded successfully!")
+             except Exception as e:
+                 st.error(f"Error loading models: {str(e)}")

+ # Main app interface
+ st.title("PDF RAG Pipeline")
+ st.write("Upload a PDF and ask questions about its content")

+ if uploaded_file and st.session_state.models_loaded:
+     with st.spinner("Processing PDF..."):
+         try:
+             # Save uploaded file temporarily
+             file_path = f"./temp_{uploaded_file.name}"
+             with open(file_path, "wb") as f:
+                 f.write(uploaded_file.getbuffer())

+             # Process the PDF
+             texts, tables, pictures = process_pdf(
+                 file_path,
+                 st.session_state.embeddings_tokenizer,
+                 st.session_state.vision_model
              )

+             st.session_state.processed_docs = {
+                 "texts": texts,
+                 "tables": tables,
+                 "pictures": pictures
+             }

+             st.success("PDF processed successfully!")
+
+             # Display document stats
+             col1, col2, col3 = st.columns(3)
+             col1.metric("Text Chunks", len(texts))
+             col2.metric("Tables", len(tables))
+             col3.metric("Images", len(pictures))

+             # Remove temp file
+             os.remove(file_path)
+
+         except Exception as e:
+             st.error(f"Error processing PDF: {str(e)}")

+ # Question answering section
+ if st.session_state.processed_docs:
+     st.divider()
+     st.subheader("Ask a Question")
+
+     question = st.text_input("Enter your question about the document:")
+
+     if question and st.button("Get Answer"):
+         with st.spinner("Generating answer..."):
+             try:
+                 answer = query_pipeline(
+                     question=question,
+                     texts=st.session_state.processed_docs["texts"],
+                     tables=st.session_state.processed_docs["tables"],
+                     pictures=st.session_state.processed_docs["pictures"],
+                     embeddings_model=st.session_state.embeddings_model,
+                     llm_model=st.session_state.llm_model
+                 )
+
+                 st.subheader("Answer")
+                 st.write(answer)
+
+             except Exception as e:
+                 st.error(f"Error generating answer: {str(e)}")
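
Similarly, utils/rag.py is not shown in this commit. Note that the new process_pdf returns texts, tables, and pictures as three separate lists, where the removed in-file version returned one concatenated list. A minimal sketch of query_pipeline, assuming it reuses the Milvus + RetrievalQA approach of the removed create_vector_store() and main(); only the signature is confirmed by the call site above, the body is an assumption:

from langchain.chains import RetrievalQA
from langchain.vectorstores import Milvus

def query_pipeline(question, texts, tables, pictures, embeddings_model, llm_model):
    # Index all three document types together, matching the old
    # `texts + tables + pictures` return value before the refactor
    docs = texts + tables + pictures
    vector_store = Milvus.from_documents(
        docs,
        embeddings_model,
        connection_args={"host": "127.0.0.1", "port": "19530"},  # assumes a local Milvus server, as before
        collection_name="pdf_manual",
    )

    # "stuff" chain: retrieved chunks are stuffed into a single prompt
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm_model,
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )
    result = qa_chain({"query": question})
    return result["result"]

Rebuilding the index on every question would be wasteful, so the real module likely caches the vector store (for example in st.session_state) after the first call. As before, the app is started with `streamlit run app.py`.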