pradeepodela committed on
Commit
1d85c92
·
verified ·
1 Parent(s): 0385eb1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +333 -0
app.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+ import glob
6
+ import os
7
+ import numpy as np
8
+ import torch.nn.functional as F
9
+ from haystack import Pipeline
10
+ from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
11
+ from haystack.components.preprocessors import DocumentSplitter
12
+ from haystack.components.writers import DocumentWriter
13
+ from haystack.components.converters import PyPDFToDocument
14
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
15
+ from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
16
+ from haystack.components.joiners import DocumentJoiner
17
+ from haystack.components.rankers import TransformersSimilarityRanker
18
+ from haystack.components.builders import PromptBuilder
19
+ from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator
20
+
# Initialize Streamlit session state
# "messages" holds the chat history as {"role": ..., "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state.messages = []
# The document store and the "pipeline built yet?" flag are created together:
# pipeline_initialized is only reset when the store itself is (re)created.
if "document_store" not in st.session_state:
    st.session_state.document_store = InMemoryDocumentStore()
    st.session_state.pipeline_initialized = False

# CLIP Model initialization
# Prefer GPU when available; every tensor below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Directory scanned for the image-search corpus (created in the __main__ guard).
IMAGE_DIR = "./new_data"
@st.cache_resource
def load_clip_model():
    """Load the CLIP ViT-L/14 checkpoint once per process.

    Returns the (model, preprocess) pair produced by ``clip.load``; the
    ``st.cache_resource`` decorator keeps it alive across Streamlit reruns.
    """
    loaded = clip.load("ViT-L/14", device=device)
    return loaded


model, preprocess = load_clip_model()
@st.cache_data
def load_images():
    """Load every image in IMAGE_DIR as (filename, RGB PIL image) pairs.

    Returns an empty list when the directory does not exist. Results are
    cached by Streamlit, so the directory is only scanned once per session.
    """
    images = []
    if os.path.exists(IMAGE_DIR):
        for image_file in os.listdir(IMAGE_DIR):
            # Bug fix: the original `f.endswith(('png', 'jpg', 'jpeg'))` had no
            # dot, so it matched extension-less names like "foopng", and it was
            # case-sensitive, so it missed ".PNG"/".JPG". Compare the real,
            # lower-cased extension instead.
            ext = os.path.splitext(image_file)[1].lower()
            if ext in (".png", ".jpg", ".jpeg"):
                image_path = os.path.join(IMAGE_DIR, image_file)
                image = Image.open(image_path).convert("RGB")
                images.append((image_file, image))
    return images
@st.cache_data
def encode_images(images):
    """Encode (filename, PIL image) pairs into L2-normalized CLIP features.

    Returns a list of (filename, feature tensor) pairs, one per input image.
    Cached by Streamlit so the corpus is only embedded once per session.
    """
    encoded = []
    for fname, pil_img in images:
        batch = preprocess(pil_img).unsqueeze(0).to(device)
        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            feat = F.normalize(model.encode_image(batch), dim=-1)
        encoded.append((fname, feat))
    return encoded
def search_images_by_text(text_query, top_k=5):
    """Rank the indexed images by CLIP cosine similarity to a text query.

    Returns the ``top_k`` best (filename, similarity) pairs, highest first.
    NOTE(review): relies on the module-level ``image_features`` list, which is
    only built in the Image Search branch below — confirm callers run there.
    """
    tokens = clip.tokenize([text_query]).to(device)
    with torch.no_grad():
        query_feat = F.normalize(model.encode_text(tokens), dim=-1)

    scored = [
        (fname, torch.cosine_similarity(query_feat, feat).item())
        for fname, feat in image_features
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]
def search_images_by_image(query_image, top_k=5):
    """Rank the indexed images by CLIP cosine similarity to a query image.

    Returns the ``top_k`` best (filename, similarity) pairs, highest first.
    NOTE(review): relies on the module-level ``image_features`` list, which is
    only built in the Image Search branch below — confirm callers run there.
    """
    batch = preprocess(query_image).unsqueeze(0).to(device)
    with torch.no_grad():
        query_feat = F.normalize(model.encode_image(batch), dim=-1)

    scored = [
        (fname, torch.cosine_similarity(query_feat, feat).item())
        for fname, feat in image_features
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]
# Custom CSS
# Injects page-wide styles: .title / .subtitle for headings, .result-container
# cards around each image result, and .score-badge for similarity scores.
st.markdown("""
<style>
.title {
    font-size: 40px;
    color: #FF4B4B;
    font-weight: bold;
    text-align: center;
}
.subtitle {
    font-size: 24px;
    color: #FF914D;
    font-weight: bold;
    margin-top: 30px;
}
.result-container {
    border: 1px solid #ddd;
    padding: 10px;
    border-radius: 10px;
    text-align: center;
    margin-bottom: 10px;
}
.score-badge {
    color: white;
    background-color: #007BFF;
    padding: 5px;
    border-radius: 5px;
    font-weight: bold;
}
</style>
""", unsafe_allow_html=True)
# Main App
st.markdown('<h1 class="title">Multi-Model Search & QA System</h1>', unsafe_allow_html=True)

# Sidebar for app selection and setup
with st.sidebar:
    st.header("Application Settings")
    app_mode = st.radio("Select Application Mode:", ["Document Q&A", "Image Search"])

    if app_mode == "Document Q&A":
        st.header("Document Setup")
        uploaded_file = st.file_uploader("Upload PDF Document", type=['pdf'])

        # Build the indexing + retrieval pipelines once per uploaded document.
        if uploaded_file and not st.session_state.pipeline_initialized:
            # Persist the upload to disk: PyPDFToDocument reads from a path.
            with open("temp.pdf", "wb") as f:
                f.write(uploaded_file.getvalue())

            # Initialize components
            document_embedder = SentenceTransformersDocumentEmbedder(model="BAAI/bge-small-en-v1.5")

            # Indexing pipeline: PDF -> sentence chunks -> embeddings -> store
            indexing_pipeline = Pipeline()
            indexing_pipeline.add_component("converter", PyPDFToDocument())
            indexing_pipeline.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2))
            indexing_pipeline.add_component("embedder", document_embedder)
            indexing_pipeline.add_component("writer", DocumentWriter(st.session_state.document_store))

            indexing_pipeline.connect("converter", "splitter")
            indexing_pipeline.connect("splitter", "embedder")
            indexing_pipeline.connect("embedder", "writer")

            # Second set of retrieval components for the source-attribution
            # pipeline (Haystack components cannot be shared across pipelines).
            text_embedder2 = SentenceTransformersTextEmbedder(model="BAAI/bge-small-en-v1.5")
            embedding_retriever2 = InMemoryEmbeddingRetriever(st.session_state.document_store)
            bm25_retriever2 = InMemoryBM25Retriever(st.session_state.document_store)
            document_joiner2 = DocumentJoiner()
            ranker2 = TransformersSimilarityRanker(model="BAAI/bge-reranker-base")

            with st.spinner("Processing document..."):
                try:
                    indexing_pipeline.run({"converter": {"sources": ["temp.pdf"]}})
                    st.success(f"Processed {st.session_state.document_store.count_documents()} document chunks")
                    st.session_state.pipeline_initialized = True

                    # Initialize retrieval components
                    text_embedder = SentenceTransformersTextEmbedder(model="BAAI/bge-small-en-v1.5")
                    embedding_retriever = InMemoryEmbeddingRetriever(st.session_state.document_store)
                    bm25_retriever = InMemoryBM25Retriever(st.session_state.document_store)
                    document_joiner = DocumentJoiner()
                    ranker = TransformersSimilarityRanker(model="BAAI/bge-reranker-base")

                    # Prompt text preserved verbatim from the original app.
                    template = """
                    act as a senior customer care executive and help users sorting out their queries. Be polite and friendly. Answer the user's questions based on the below context only dont try to make up any answer make sure that create a good version of all the documents that u recived and make the answer complining to the question make user the you sound exactly same as the documents delow.:
                    CONTEXT:
                    {% for document in documents %}
                    {{ document.content }}
                    {% endfor %}
                    Make sure to provide all the details. If the answer is not in the provided context just say, 'answer is not available in the context'. Don't provide the wrong answer.
                    If the person asks any external recommendation just say 'sorry i can't help you with that'.

                    Question: {{question}}

                    explain in detail
                    """

                    prompt_builder = PromptBuilder(template=template)

                    # SECURITY FIX: the original assigned a hard-coded Google API
                    # key here, leaking a live credential in source control. The
                    # key must now be supplied via the GOOGLE_API_KEY environment
                    # variable (or Streamlit secrets) before launching the app;
                    # raising here routes a clear message through the except
                    # handler below instead of silently embedding a secret.
                    if "GOOGLE_API_KEY" not in os.environ:
                        raise RuntimeError(
                            "GOOGLE_API_KEY is not set. Export it (or add it to "
                            "Streamlit secrets) before starting the app."
                        )
                    generator = GoogleAIGeminiGenerator(model="gemini-pro")

                    # Create retrieval pipeline (hybrid retrieval -> rerank -> LLM)
                    st.session_state.retrieval_pipeline = Pipeline()
                    st.session_state.retrieval_pipeline.add_component("text_embedder", text_embedder)
                    st.session_state.retrieval_pipeline.add_component("embedding_retriever", embedding_retriever)
                    st.session_state.retrieval_pipeline.add_component("bm25_retriever", bm25_retriever)
                    st.session_state.retrieval_pipeline.add_component("document_joiner", document_joiner)
                    st.session_state.retrieval_pipeline.add_component("ranker", ranker)
                    st.session_state.retrieval_pipeline.add_component("prompt_builder", prompt_builder)
                    st.session_state.retrieval_pipeline.add_component("llm", generator)

                    # Connect pipeline components
                    st.session_state.retrieval_pipeline.connect("text_embedder", "embedding_retriever")
                    st.session_state.retrieval_pipeline.connect("bm25_retriever", "document_joiner")
                    st.session_state.retrieval_pipeline.connect("embedding_retriever", "document_joiner")
                    st.session_state.retrieval_pipeline.connect("document_joiner", "ranker")
                    st.session_state.retrieval_pipeline.connect("ranker", "prompt_builder.documents")
                    st.session_state.retrieval_pipeline.connect("prompt_builder", "llm")

                    # Ranker pipeline: same hybrid retrieval, but stops at the
                    # ranker so the chat handler can report source documents.
                    st.session_state.hybrid_retrieval2 = Pipeline()
                    st.session_state.hybrid_retrieval2.add_component("text_embedder", text_embedder2)
                    st.session_state.hybrid_retrieval2.add_component("embedding_retriever", embedding_retriever2)
                    st.session_state.hybrid_retrieval2.add_component("bm25_retriever", bm25_retriever2)
                    st.session_state.hybrid_retrieval2.add_component("document_joiner", document_joiner2)
                    st.session_state.hybrid_retrieval2.add_component("ranker", ranker2)

                    st.session_state.hybrid_retrieval2.connect("text_embedder", "embedding_retriever")
                    st.session_state.hybrid_retrieval2.connect("bm25_retriever", "document_joiner")
                    st.session_state.hybrid_retrieval2.connect("embedding_retriever", "document_joiner")
                    st.session_state.hybrid_retrieval2.connect("document_joiner", "ranker")

                except Exception as e:
                    st.error(f"Error processing document: {str(e)}")
                finally:
                    # Always clean up the temp copy of the upload.
                    if os.path.exists("temp.pdf"):
                        os.remove("temp.pdf")
# Main content area
if app_mode == "Document Q&A":
    st.markdown('<h2 class="subtitle">Document Q&A System</h2>', unsafe_allow_html=True)

    # Replay the conversation so far.
    for past in st.session_state.messages:
        with st.chat_message(past["role"]):
            st.markdown(past["content"])

    # Handle a new question from the chat box.
    if prompt := st.chat_input("Ask a question about your document"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Guard clause: nothing to answer from until a document is indexed.
        if not st.session_state.pipeline_initialized:
            with st.chat_message("assistant"):
                notice = "Please upload a document first to start the conversation."
                st.warning(notice)
                st.session_state.messages.append({"role": "assistant", "content": notice})
        else:
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        # Full RAG pipeline: hybrid retrieval -> rerank -> LLM.
                        answer_run = st.session_state.retrieval_pipeline.run({
                            "text_embedder": {"text": prompt},
                            "bm25_retriever": {"query": prompt},
                            "ranker": {"query": prompt},
                            "prompt_builder": {"question": prompt},
                        })
                        # Retrieval-only pipeline used for source attribution.
                        source_run = st.session_state.hybrid_retrieval2.run({
                            "text_embedder": {"text": prompt},
                            "bm25_retriever": {"query": prompt},
                            "ranker": {"query": prompt},
                        })

                        # Interleave file paths and page numbers, adding each
                        # file only once (membership is checked against the
                        # combined list, mirroring the original behavior).
                        sources = []
                        for doc in source_run['ranker']['documents']:
                            if doc.meta['file_path'] not in sources:
                                sources.append(doc.meta['file_path'])
                                sources.append(doc.meta['page_number'])

                        reply = answer_run['llm']['replies'][0]
                        reply = f"{reply} \n\nsource: {sources} "
                        st.markdown(reply)
                        st.session_state.messages.append({"role": "assistant", "content": reply})

                    except Exception as e:
                        error_message = f"Error generating response: {str(e)}"
                        st.error(error_message)
                        st.session_state.messages.append({"role": "assistant", "content": error_message})
283
+ else: # Image Search mode
284
+ st.markdown('<h2 class="subtitle">Image Search System</h2>', unsafe_allow_html=True)
285
+
286
+ # Load and encode images
287
+ images = load_images()
288
+ image_features = encode_images(images)
289
+
290
+ search_type = st.radio("Select Search Type:", ["Text-to-Image", "Image-to-Image"])
291
+
292
+ if search_type == "Text-to-Image":
293
+ query = st.text_input("Enter a text description to find similar images:")
294
+
295
+ if query:
296
+ results = search_images_by_text(query)
297
+ st.write(f"Top results for query: **{query}**")
298
+
299
+ cols = st.columns(3)
300
+ for idx, (image_file, score) in enumerate(results):
301
+ with cols[idx % 3]:
302
+ st.markdown(f'<div class="result-container">', unsafe_allow_html=True)
303
+ image_path = os.path.join(IMAGE_DIR, image_file)
304
+ image = Image.open(image_path)
305
+ st.image(image, caption=image_file)
306
+ st.markdown(f'<span class="score-badge">Score: {score:.4f}</span>', unsafe_allow_html=True)
307
+ st.markdown('</div>', unsafe_allow_html=True)
308
+
309
+ else: # Image-to-Image search
310
+ uploaded_image = st.file_uploader("Upload an image to find similar images:", type=["png", "jpg", "jpeg"])
311
+
312
+ if uploaded_image is not None:
313
+ query_image = Image.open(uploaded_image).convert("RGB")
314
+ st.image(query_image, caption="Query Image", use_column_width=True)
315
+
316
+ # Search and display results
317
+ results = search_images_by_image(query_image)
318
+ st.write("Top results for the uploaded image:")
319
+
320
+ cols = st.columns(3)
321
+ for idx, (image_file, score) in enumerate(results):
322
+ with cols[idx % 3]:
323
+ st.markdown(f'<div class="result-container">', unsafe_allow_html=True)
324
+ image_path = os.path.join(IMAGE_DIR, image_file)
325
+ image = Image.open(image_path)
326
+ st.image(image, caption=image_file)
327
+ st.markdown(f'<span class="score-badge">Score: {score:.4f}</span>', unsafe_allow_html=True)
328
+ st.markdown('</div>', unsafe_allow_html=True)
329
+
if __name__ == "__main__":
    # Ensure the image directory exists. exist_ok=True removes the race
    # between the original's exists() check and makedirs() call.
    # NOTE(review): this runs after the UI code above has already called
    # load_images(); consider creating the directory at import time instead.
    os.makedirs(IMAGE_DIR, exist_ok=True)