RanjeetB committed on
Commit 162383c · verified · 1 Parent(s): 8be5120

Update app.py

Files changed (1):
  app.py  +305 -356

app.py CHANGED
@@ -1,376 +1,325 @@
  import streamlit as st
- import boto3
- import json
- import chromadb
  from datasets import load_dataset
- import uuid
- import time

- # Simple function to connect to AWS Bedrock
- def connect_to_bedrock():
-     client = boto3.client('bedrock-runtime', region_name='us-east-1')
-     return client

- # Simple function to load Wikipedia documents
- def load_wikipedia_docs(num_docs=100):
-     st.write(f"📚 Loading {num_docs} Wikipedia documents...")
-
-     # Load Wikipedia dataset from Hugging Face
-     dataset = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train")
-
-     # Take only the first num_docs documents
-     documents = []
-     for i in range(min(num_docs, len(dataset))):
-         doc = dataset[i]
-         documents.append({
-             'text': doc['text'],
-             'title': doc.get('title', f'Document {i+1}'),
-             'id': str(i)
-         })
-
-     return documents

- # Simple function to split text into chunks
- def split_into_chunks(documents, chunk_size=500):
-     st.write("✂️ Splitting documents into 500-character chunks...")
-
-     chunks = []
-     chunk_id = 0
-
-     for doc in documents:
-         text = doc['text']
-         title = doc['title']
-
-         # Split text into chunks of 500 characters
-         for i in range(0, len(text), chunk_size):
-             chunk_text = text[i:i + chunk_size]
-             if len(chunk_text.strip()) > 50:  # Only keep meaningful chunks
-                 chunks.append({
-                     'id': str(chunk_id),
-                     'text': chunk_text,
-                     'title': title,
-                     'doc_id': doc['id']
-                 })
-                 chunk_id += 1
-
-     return chunks

- # Get embeddings from Bedrock Titan model
- def get_embeddings(bedrock_client, text):
-     body = json.dumps({
-         "inputText": text
-     })
-
-     response = bedrock_client.invoke_model(
-         modelId="amazon.titan-embed-text-v1",
-         body=body
-     )
-
-     result = json.loads(response['body'].read())
-     return result['embedding']

- # Store chunks in ChromaDB
- def store_in_chromadb(bedrock_client, chunks):
-     st.write("💾 Storing chunks in ChromaDB with embeddings...")
-
-     # Create ChromaDB client
-     chroma_client = chromadb.Client()
-
-     # Create or get collection
      try:
-         collection = chroma_client.get_collection("wikipedia_chunks")
-         chroma_client.delete_collection("wikipedia_chunks")
-     except:
-         pass
-
-     collection = chroma_client.create_collection("wikipedia_chunks")
-
-     # Prepare data for ChromaDB
-     ids = []
-     texts = []
-     metadatas = []
-     embeddings = []
-
-     progress_bar = st.progress(0)
-
-     for i, chunk in enumerate(chunks):
-         # Get embedding for each chunk
-         embedding = get_embeddings(bedrock_client, chunk['text'])
-
-         ids.append(chunk['id'])
-         texts.append(chunk['text'])
-         metadatas.append({
-             'title': chunk['title'],
-             'doc_id': chunk['doc_id']
-         })
-         embeddings.append(embedding)
-
-         # Update progress
-         progress_bar.progress((i + 1) / len(chunks))

-         # Add to ChromaDB in batches of 100
-         if len(ids) == 100 or i == len(chunks) - 1:
-             collection.add(
-                 ids=ids,
-                 documents=texts,
-                 metadatas=metadatas,
-                 embeddings=embeddings
              )
-             ids, texts, metadatas, embeddings = [], [], [], []
-
-     return collection

- # Simple retrieval without re-ranking
- def simple_retrieval(collection, bedrock_client, query, top_k=10):
-     # Get query embedding
-     query_embedding = get_embeddings(bedrock_client, query)
-
-     # Search in ChromaDB
-     results = collection.query(
-         query_embeddings=[query_embedding],
-         n_results=top_k
-     )
-
-     # Format results
-     retrieved_docs = []
-     for i in range(len(results['documents'][0])):
-         retrieved_docs.append({
-             'text': results['documents'][0][i],
-             'title': results['metadatas'][0][i]['title'],
-             'distance': results['distances'][0][i]
-         })
-
-     return retrieved_docs

- # Re-ranking using Claude 3.5
- def rerank_with_claude(bedrock_client, query, documents, top_k=5):
-     # Create prompt for re-ranking
-     docs_text = ""
-     for i, doc in enumerate(documents):
-         docs_text += f"[{i+1}] {doc['text'][:200]}...\n\n"
-
-     prompt = f"""
-     Given the query: "{query}"
-
-     Please rank the following documents by relevance to the query.
-     Return only the numbers (1, 2, 3, etc.) of the most relevant documents in order, separated by commas.
-     Return exactly {top_k} numbers.
-
-     Documents:
-     {docs_text}
-
-     Most relevant document numbers (in order):
-     """
-
-     body = json.dumps({
-         "anthropic_version": "bedrock-2023-05-31",
-         "max_tokens": 100,
-         "messages": [{"role": "user", "content": prompt}]
-     })
-
-     response = bedrock_client.invoke_model(
-         modelId="anthropic.claude-3-haiku-20240307-v1:0",
-         body=body
-     )
-
-     result = json.loads(response['body'].read())
-     ranking_text = result['content'][0]['text'].strip()
-
      try:
-         # Parse the ranking
-         rankings = [int(x.strip()) - 1 for x in ranking_text.split(',')]  # Convert to 0-based index
-
-         # Reorder documents based on ranking
-         reranked_docs = []
-         for rank in rankings[:top_k]:
-             if 0 <= rank < len(documents):
-                 reranked_docs.append(documents[rank])
-
-         return reranked_docs
-     except:
-         # If parsing fails, return original order
-         return documents[:top_k]

- # Generate answer using retrieved documents
- def generate_answer(bedrock_client, query, documents):
-     # Combine documents into context
-     context = "\n\n".join([f"Source: {doc['title']}\n{doc['text']}" for doc in documents])
-
-     prompt = f"""
-     Based on the following information, please answer the question.
-
-     Question: {query}
-
-     Information:
-     {context}
-
-     Please provide a clear and comprehensive answer based on the information above.
-     """
-
-     body = json.dumps({
-         "anthropic_version": "bedrock-2023-05-31",
-         "max_tokens": 500,
-         "messages": [{"role": "user", "content": prompt}]
-     })
-
-     response = bedrock_client.invoke_model(
-         modelId="anthropic.claude-3-haiku-20240307-v1:0",
-         body=body
-     )
-
-     result = json.loads(response['body'].read())
-     return result['content'][0]['text']

- # Main app
- def main():
-     st.title("🔍 Wikipedia Retrieval")
-     st.write("Compare search results with and without re-ranking!")
-
-     # Initialize session state
-     if 'collection' not in st.session_state:
-         st.session_state.collection = None
-     if 'setup_done' not in st.session_state:
-         st.session_state.setup_done = False
-
-     # Setup section
-     if not st.session_state.setup_done:
-         st.subheader("🛠️ Setup")
-
-         if st.button("🚀 Load Wikipedia Data and Setup ChromaDB"):
              try:
-                 with st.spinner("Setting up... This may take a few minutes..."):
-                     # Connect to Bedrock
-                     bedrock_client = connect_to_bedrock()
-
-                     # Load Wikipedia documents
-                     documents = load_wikipedia_docs(100)
-                     st.success(f"✅ Loaded {len(documents)} documents")
-
-                     # Split into chunks
-                     chunks = split_into_chunks(documents, 500)
-                     st.success(f"✅ Created {len(chunks)} chunks")
-
-                     # Store in ChromaDB
-                     collection = store_in_chromadb(bedrock_client, chunks)
-                     st.session_state.collection = collection
-                     st.session_state.setup_done = True
-
-                     st.success("🎉 Setup complete! You can now test queries below.")
-                     st.balloons()
-
              except Exception as e:
-                 st.error(f" Setup failed: {str(e)}")
-
-     else:
-         st.success("✅ Setup completed! ChromaDB is ready with Wikipedia data.")
-
-     # Query testing section
-     st.subheader("🔍 Test Queries")
-
-     # Predefined queries
-     sample_queries = [
-         "What are the main causes of climate change?",
-         "How does quantum computing work?",
-         "What were the social impacts of the industrial revolution?"
-     ]
-
-     # Query selection
-     query_option = st.radio("Choose a query:",
-                             ["Custom Query"] + sample_queries)
-
-     if query_option == "Custom Query":
-         query = st.text_input("Enter your custom query:")
-     else:
-         query = query_option
-         st.write(f"Selected query: **{query}**")
-
-     if query:
-         if st.button("🔍 Compare Retrieval Methods"):
-             try:
-                 bedrock_client = connect_to_bedrock()
-
-                 st.write("---")
-
-                 # Method 1: Simple Retrieval
-                 st.subheader("📋 Method 1: Simple Retrieval (Baseline)")
-                 with st.spinner("Performing simple retrieval..."):
-                     simple_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
-                     simple_top5 = simple_results[:5]
-
-                     st.write("**Top 5 Results:**")
-                     for i, doc in enumerate(simple_top5, 1):
-                         with st.expander(f"{i}. {doc['title']} (Distance: {doc['distance']:.3f})"):
-                             st.write(doc['text'][:300] + "...")
-
-                     # Generate answer with simple retrieval
-                     simple_answer = generate_answer(bedrock_client, query, simple_top5)
-                     st.write("**Answer using Simple Retrieval:**")
-                     st.info(simple_answer)
-
-                 st.write("---")
-
-                 # Method 2: Retrieval with Re-ranking
-                 st.subheader("🎯 Method 2: Retrieval with Re-ranking")
-                 with st.spinner("Performing retrieval with re-ranking..."):
-                     # First get more results
-                     initial_results = simple_retrieval(st.session_state.collection, bedrock_client, query, 10)
-
-                     # Then re-rank them
-                     reranked_results = rerank_with_claude(bedrock_client, query, initial_results, 5)
-
-                     st.write("**Top 5 Re-ranked Results:**")
-                     for i, doc in enumerate(reranked_results, 1):
-                         with st.expander(f"{i}. {doc['title']} (Re-ranked)"):
-                             st.write(doc['text'][:300] + "...")
-
-                     # Generate answer with re-ranked results
-                     reranked_answer = generate_answer(bedrock_client, query, reranked_results)
-                     st.write("**Answer using Re-ranked Retrieval:**")
-                     st.success(reranked_answer)
-
-                 st.write("---")
-                 st.subheader("📊 Comparison Summary")
-                 st.write("**Simple Retrieval:** Uses only vector similarity to find relevant documents.")
-                 st.write("**Re-ranked Retrieval:** Uses Claude 3.5 to intelligently reorder results for better relevance.")
-
-             except Exception as e:
-                 st.error(f"❌ Error during retrieval: {str(e)}")
-
-     # Reset button
-     if st.button("🔄 Reset Setup"):
-         st.session_state.collection = None
-         st.session_state.setup_done = False
-         st.rerun()

 
345
- # Installation guide
346
- def show_installation_guide():
347
- with st.expander("📖 Installation Guide"):
348
- st.markdown("""
349
- **Step 1: Install Required Libraries**
350
- ```bash
351
- pip install streamlit boto3 chromadb datasets
352
- ```
353
-
354
- **Step 2: Set up AWS**
355
- ```bash
356
- aws configure
357
- ```
358
- Enter your AWS access keys when prompted.
359
-
360
- **Step 3: Run the App**
361
- ```bash
362
- streamlit run reranking_app.py
363
- ```
364
-
365
- **What this app does:**
366
- 1. Loads 100 Wikipedia documents
367
- 2. Splits them into 500-character chunks
368
- 3. Creates embeddings using Bedrock Titan
369
- 4. Stores in local ChromaDB
370
- 5. Compares simple vs re-ranked retrieval
371
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
- # Run the app
374
  if __name__ == "__main__":
375
- show_installation_guide()
376
- main()
 
+
+ from huggingface_hub import InferenceClient
  import streamlit as st
+ import logging
+ import os
+ from dotenv import load_dotenv
  from datasets import load_dataset
+ from langchain_core.documents import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import BedrockEmbeddings
+ from langchain_qdrant import Qdrant
+ from langchain_aws import ChatBedrock
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_core.output_parsers import StrOutputParser
+ from qdrant_client import QdrantClient
+ from qdrant_client.models import Distance, VectorParams
+ import re
+ import json
+ from urllib.error import URLError

+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)

+ def load_environment():
+     """Load and validate environment variables."""
+     try:
+         load_dotenv()
+         required_vars = ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION', 'QDRANT_URL', 'QDRANT_API_KEY']
+         missing_vars = [var for var in required_vars if not os.getenv(var)]
+         if missing_vars:
+             logger.error(f"Missing environment variables: {missing_vars}")
+             st.error(f"Missing environment variables: {missing_vars}")
+             raise ValueError(f"Missing environment variables: {missing_vars}")
+         logger.info("Environment variables loaded successfully")
+     except Exception as e:
+         logger.error(f"Error loading environment variables: {e}")
+         st.error(f"Error loading environment variables: {e}")
+         raise
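For reference, `load_environment` expects the five variables it validates to come from the environment or a `.env` file next to the app. A minimal `.env` sketch with placeholder values (the region and URL shapes are assumptions, not part of this commit):

```bash
# .env — placeholders only, never commit real credentials
AWS_ACCESS_KEY_ID=AKIAXXXXXXXXXXXXXXXX
AWS_SECRET_ACCESS_KEY=your-secret-key
AWS_REGION=us-east-1
QDRANT_URL=https://your-cluster.cloud.qdrant.io:6333
QDRANT_API_KEY=your-qdrant-api-key
```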

+ @st.cache_resource
+ def load_wikipedia_documents():
+     """Load 100 Wikipedia documents from Cohere's HF dataset."""
+     try:
+         dataset = load_dataset(
+             "Cohere/wikipedia-22-12-simple-embeddings",
+             split="train[:100]"  # Load only 100 entries
+         )
+         documents = [Document(page_content=item["text"]) for item in dataset]
+         logger.info(f"Loaded {len(documents)} Wikipedia documents")
+         if not documents:
+             logger.error("No documents loaded from dataset")
+             st.error("No documents loaded from dataset")
+             return []
+         return documents
+     except Exception as e:
+         logger.error(f"Error loading dataset: {e}")
+         st.error(f"Failed to load dataset: {e}")
+         return []

+ @st.cache_resource
+ def split_documents(_documents):
+     """Split documents into chunks."""
+     try:
+         if not _documents:
+             logger.error("No documents provided for splitting")
+             st.error("No documents provided for splitting")
+             return []
+         splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+         chunks = splitter.split_documents(_documents)
+         logger.info(f"Split into {len(chunks)} chunks")
+         if not chunks:
+             logger.error("No chunks created from documents")
+             st.error("No chunks created from documents")
+             return []
+         return chunks
+     except Exception as e:
+         logger.error(f"Error splitting documents: {e}")
+         st.error(f"Failed to split documents: {e}")
+         return []
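A standalone sketch of how the 500/50 splitter behaves (the sample document is invented for illustration); consecutive chunks share up to 50 characters of overlap, so text cut mid-sentence still retrieves cleanly:

```python
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
doc = Document(page_content="word " * 400)  # ~2000 characters of filler text
chunks = splitter.split_documents([doc])
print(len(chunks))  # a handful of chunks
print(all(len(c.page_content) <= 500 for c in chunks))  # True: chunks stay within chunk_size
```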

+ @st.cache_resource
+ def initialize_embeddings():
+     """Initialize AWS Bedrock embeddings."""
      try:
+         embeddings = BedrockEmbeddings(
+             model_id="amazon.titan-embed-text-v1",
+             region_name=os.getenv("AWS_REGION")
+         )
+         logger.info("Initialized Bedrock embeddings")
+         return embeddings
+     except Exception as e:
+         logger.error(f"Error initializing embeddings: {e}")
+         st.error(f"Failed to initialize embeddings: {e}")
+         return None
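A quick smoke test for the embedder (requires live AWS credentials; the 1536-dimension figure is what Titan v1 is generally documented to return, so treat it as an assumption):

```python
embeddings = initialize_embeddings()
if embeddings is not None:
    vector = embeddings.embed_query("What is quantum computing?")
    print(len(vector))  # expected: 1536 for amazon.titan-embed-text-v1
```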
+
+ def store_in_qdrant(_chunks, _embeddings):
+     """Store document chunks in a hosted Qdrant instance after deleting all collections."""
+     try:
+         # Initialize Qdrant client
+         client = QdrantClient(
+             url=os.getenv("QDRANT_URL"),
+             api_key=os.getenv("QDRANT_API_KEY"),
+             timeout=30
+         )

+         # Test Qdrant connection
+         try:
+             client.get_collections()
+             logger.info("Successfully connected to Qdrant at %s", os.getenv("QDRANT_URL"))
+         except Exception as e:
+             logger.error("Failed to connect to Qdrant: %s", e)
+             st.error(f"Failed to connect to Qdrant: {e}")
+             return None
+
+         # Delete all existing collections
+         try:
+             collections = client.get_collections().collections
+             for collection in collections:
+                 client.delete_collection(collection.name)
+                 logger.info(f"Deleted Qdrant collection: {collection.name}")
+             logger.info("All Qdrant collections deleted")
+         except Exception as e:
+             logger.warning(f"Error deleting collections: {e}")
+             st.warning(f"Error deleting collections: {e}")
+
+         # Validate input chunks
+         if not _chunks:
+             logger.error("No chunks provided for Qdrant storage")
+             st.error("No chunks provided for Qdrant storage")
+             return None
+
+         # Create and populate new collection
+         collection_name = "wikipedia_chunks"
+         try:
+             vector_store = Qdrant.from_documents(
+                 documents=_chunks,
+                 embedding=_embeddings,
+                 url=os.getenv("QDRANT_URL"),
+                 api_key=os.getenv("QDRANT_API_KEY"),
+                 collection_name=collection_name,
+                 force_recreate=True  # Ensure fresh collection
              )
+             logger.info(f"Created Qdrant collection {collection_name} with {len(_chunks)} chunks")
+         except Exception as e:
+             logger.error(f"Error creating Qdrant collection: {e}")
+             st.error(f"Failed to create Qdrant collection: {e}")
+             return None

+         # Verify storage
+         try:
+             collection_info = client.get_collection(collection_name)
+             stored_points = collection_info.points_count
+             logger.info(f"Stored {stored_points} points in Qdrant collection {collection_name}")
+             if stored_points == 0:
+                 logger.error("No documents stored in Qdrant collection")
+                 st.error("No documents stored in Qdrant collection")
+                 return None
+             if stored_points != len(_chunks):
+                 logger.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
+                 st.warning(f"Expected {len(_chunks)} chunks, but stored {stored_points} in Qdrant")
+             return vector_store
+         except Exception as e:
+             logger.error(f"Error verifying Qdrant storage: {e}")
+             st.error(f"Failed to verify Qdrant storage: {e}")
+             return None

+     except Exception as e:
+         logger.error(f"Error in Qdrant storage process: {e}")
+         st.error(f"Failed to store documents in Qdrant: {e}")
+         return None
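The function above targets a hosted cluster; for local experiments the same `from_documents` pattern can run fully in-process. A sketch under the assumption that `location=":memory:"` and `FakeEmbeddings` are acceptable stand-ins (neither is used by this commit):

```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_core.documents import Document
from langchain_qdrant import Qdrant

docs = [Document(page_content="The Industrial Revolution began in Britain.")]
store = Qdrant.from_documents(
    docs,
    FakeEmbeddings(size=1536),  # random vectors, stands in for Bedrock Titan
    location=":memory:",        # in-process Qdrant, no server or API key needed
    collection_name="wikipedia_chunks",
)
print(store.similarity_search("industrial revolution", k=1)[0].page_content)
```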
+
+ @st.cache_resource
+ def initialize_llm():
+     """Initialize AWS Bedrock Claude 3.5 Sonnet model."""
      try:
+         llm = ChatBedrock(
+             model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
+             region_name=os.getenv("AWS_REGION"),
+             model_kwargs={"max_tokens": 1000}
+         )
+         logger.info("Initialized Claude 3.5 Sonnet")
+         return llm
+     except Exception as e:
+         logger.error(f"Error initializing LLM: {e}")
+         st.error(f"Failed to initialize LLM: {e}")
+         return None

+ def extract_score_from_text(text):
+     """Extract the first float number between 0 and 1 from the text using regex."""
+     try:
+         matches = re.findall(r'\b0(?:\.\d+)?\b|\b1(?:\.0+)?\b', text)
+         if not matches:
+             logger.warning("No score found in text")
+             return None
+         score = float(matches[0])
+         if 0.0 <= score <= 1.0:
+             return score
+         logger.warning(f"Score {score} out of expected range 0-1")
+         return None
+     except ValueError as e:
+         logger.warning(f"Cannot convert match to float: {e}")
+         return None
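To make the regex concrete, here is what the extractor returns for a few hypothetical model outputs (the strings are invented; the results follow from the pattern):

```python
for text in ["0.8", "Relevance: 0.75", "1", "1.0", "score is 10"]:
    print(repr(text), "->", extract_score_from_text(text))
# '0.8' -> 0.8
# 'Relevance: 0.75' -> 0.75
# '1' -> 1.0
# '1.0' -> 1.0
# 'score is 10' -> None  ('10' never matches a standalone 0-1 number)
```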

+ def claude_rerank(docs, query, llm, top_n=5):
+     """Rerank documents based on relevance using the LLM."""
+     try:
+         rerank_prompt = ChatPromptTemplate.from_template(
+             """
+             Given the query: "{query}" and the document chunk: "{chunk}", please rate
+             the relevance on a scale from 0 to 1 (0=not relevant, 1=highly relevant).
+
+             Respond with a number only, like: 0.8
+             """
+         )
+         scored_docs = []
+         for idx, doc in enumerate(docs):
+             prompt = rerank_prompt.format(query=query, chunk=doc.page_content)
+             response = llm.invoke(prompt)
+             text = response.content.strip()
+             logger.info(f"Doc {idx} rerank raw output: {text}")
+             score = extract_score_from_text(text)
+             if score is None:
+                 logger.warning(f"Failed to extract valid score for doc {idx}. Assigning 0.")
+                 score = 0.0
+             scored_docs.append((doc, score))
+         scored_docs.sort(key=lambda x: x[1], reverse=True)
+         logger.info(f"Reranked top {top_n} docs based on scores")
+         return [doc for doc, _ in scored_docs[:top_n]]
+     except Exception as e:
+         logger.error(f"Error in reranking: {e}")
+         st.error(f"Error in reranking: {e}")
+         return docs[:top_n]  # Fallback to original docs
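Note this is pointwise reranking: one `llm.invoke` call per candidate, so with the 20-document retriever below each query pays for 20 extra model calls. The loop can be exercised offline with a canned chat model (a test stub, assuming `FakeListChatModel` from `langchain_core`; not part of this commit):

```python
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel

stub_llm = FakeListChatModel(responses=["0.2", "0.9", "0.5"])  # replies in order
docs = [Document(page_content=f"chunk {i}") for i in range(3)]
top = claude_rerank(docs, "test query", stub_llm, top_n=2)
print([d.page_content for d in top])  # ['chunk 1', 'chunk 2'], sorted by score descending
```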
+
+ def create_rag_chain(vector_store, llm, use_rerank=False):
+     """Create a RAG chain with or without reranking."""
+     try:
+         prompt_template = ChatPromptTemplate.from_template(
+             """You are a helpful assistant. Use the following context to answer the question concisely.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:"""
+         )
+         retriever = vector_store.as_retriever(search_kwargs={"k": 20 if use_rerank else 5})
+
+         def rerank_context(inputs):
              try:
+                 docs = retriever.invoke(inputs["question"])
+                 if not docs:
+                     logger.warning("No documents retrieved for query")
+                     return {"context": "", "question": inputs["question"]}
+                 if use_rerank:
+                     docs = claude_rerank(docs, inputs["question"], llm)
+                 return {"context": "\n\n".join(doc.page_content for doc in docs), "question": inputs["question"]}
              except Exception as e:
+                 logger.error(f"Error in rerank_context: {e}")
+                 return {"context": "", "question": inputs["question"]}

+         chain = rerank_context | prompt_template | llm | StrOutputParser()
+         logger.info(f"Initialized {'re-ranked' if use_rerank else 'baseline'} RAG chain")
+         return chain
+     except Exception as e:
+         logger.error(f"Error creating RAG chain: {e}")
+         st.error(f"Failed to create RAG chain: {e}")
+         return None
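The pipeline composes because LCEL coerces a plain function on the left of `|` into a `RunnableLambda`, which is why `rerank_context | prompt_template | llm | StrOutputParser()` is a valid chain. A minimal offline sketch of the same shape (stub chat model again assumed):

```python
from langchain_core.language_models import FakeListChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_template("Context:\n{context}\n\nQuestion: {question}\n\nAnswer:")

def fetch(inputs):
    # Stands in for rerank_context: supplies a fixed context
    return {"context": "Paris is the capital of France.", "question": inputs["question"]}

chain = fetch | prompt | FakeListChatModel(responses=["Paris."]) | StrOutputParser()
print(chain.invoke({"question": "What is the capital of France?"}))  # Paris.
```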
+
+ def main():
+     st.title("Wikipedia Q&A with RAG (Qdrant + AWS Bedrock)")
+     st.write("Enter a question to get answers using baseline and reranked retrieval methods.")
+
+     # Load environment variables
+     try:
+         load_environment()
+     except ValueError:
+         return
+
+     # Initialize components
+     documents = load_wikipedia_documents()
+     if not documents:
+         st.error("Cannot proceed without documents")
+         return
+     chunks = split_documents(documents)
+     if not chunks:
+         st.error("Cannot proceed without document chunks")
+         return
+     embeddings = initialize_embeddings()
+     if embeddings is None:
+         st.error("Cannot proceed without embeddings")
+         return
+     vector_store = store_in_qdrant(chunks, embeddings)
+     if vector_store is None:
+         st.error("Cannot proceed without vector store")
+         return
+     llm = initialize_llm()
+     if llm is None:
+         st.error("Cannot proceed without LLM")
+         return
+
+     baseline_chain = create_rag_chain(vector_store, llm, use_rerank=False)
+     if baseline_chain is None:
+         st.error("Cannot proceed without baseline chain")
+         return
+     rerank_chain = create_rag_chain(vector_store, llm, use_rerank=True)
+     if rerank_chain is None:
+         st.error("Cannot proceed without rerank chain")
+         return
+
+     # Streamlit input
+     query = st.text_input("Enter your question:", placeholder="e.g., What are the main causes of climate change?")
+     if query:
+         with st.spinner("Processing your query..."):
+             try:
+                 baseline_response = baseline_chain.invoke({"question": query})
+                 rerank_response = rerank_chain.invoke({"question": query})
+
+                 st.subheader("Results")
+                 st.write("**Query:**", query)
+                 st.write("**Baseline Answer:**")
+                 st.write(baseline_response)
+                 st.write("**Reranked Answer:**")
+                 st.write(rerank_response)
+             except Exception as e:
+                 logger.error(f"Error processing query: {e}")
+                 st.error(f"Error processing query: {e}")

  if __name__ == "__main__":
+     main()