ankanghosh committed
Commit 6aa479a · verified · 1 Parent(s): 1ab74f5

Update rag_engine.py

Files changed (1)
  1. rag_engine.py +66 -63

rag_engine.py CHANGED
@@ -11,23 +11,22 @@ import textwrap
 import unicodedata
 import streamlit as st
 from utils import setup_gcp_auth, setup_openai_auth
-import gc # Added for explicit garbage collection
+import gc
 
 # Force model to CPU for stability
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
-# Create a function to initialize session state
-def initialize_session_state():
-    if 'model_initialized' not in st.session_state:
-        st.session_state.model_initialized = False
-        st.session_state.model = None
-        st.session_state.tokenizer = None
-        st.session_state.device = torch.device("cpu")
-        st.session_state.data_loaded = False
-        print("Initialized session state variables")
+# GCS Paths
+metadata_file_gcs = "metadata/metadata.jsonl"
+embeddings_file_gcs = "processed/embeddings/all_embeddings.npy"
+faiss_index_file_gcs = "processed/indices/faiss_index.faiss"
+text_chunks_file_gcs = "processed/chunks/text_chunks.txt"
 
-# Call the initialization function right away
-initialize_session_state()
+# Local Paths
+local_embeddings_file = "all_embeddings.npy"
+local_faiss_index_file = "faiss_index.faiss"
+local_text_chunks_file = "text_chunks.txt"
+local_metadata_file = "metadata.jsonl"
 
 # Load GCP authentication from utility function
 def setup_gcp_client():
@@ -52,59 +51,49 @@ def setup_openai_client():
         print(f"❌ OpenAI client initialization error: {str(e)}")
         return False
 
-# GCS Paths
-metadata_file_gcs = "metadata/metadata.jsonl"
-embeddings_file_gcs = "processed/embeddings/all_embeddings.npy"
-faiss_index_file_gcs = "processed/indices/faiss_index.faiss"
-text_chunks_file_gcs = "processed/chunks/text_chunks.txt"
-
-# Local Paths
-local_embeddings_file = "all_embeddings.npy"
-local_faiss_index_file = "faiss_index.faiss"
-local_text_chunks_file = "text_chunks.txt"
-local_metadata_file = "metadata.jsonl"
-
 def load_model():
+    """Load the embedding model and store in session state"""
     try:
-        # Check if model is already loaded
-        if st.session_state.model is not None and st.session_state.tokenizer is not None:
-            print("Model already loaded, reusing existing instance")
+        # Check if model already loaded
+        if 'model' in st.session_state and st.session_state.model is not None:
+            print("Model already loaded in session state")
             return st.session_state.tokenizer, st.session_state.model
-
-        # Force model to CPU - more stable than GPU for this use case
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+        print("Loading new model instance...")
 
-        print("Loading tokenizer...")
-        tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
+        # Force model to CPU
+        device = torch.device("cpu")
 
-        print("Loading model...")
+        # Load tokenizer and model
+        tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
         model = AutoModel.from_pretrained(
             "intfloat/e5-small-v2",
-            torch_dtype=torch.float16 # Use half precision
+            torch_dtype=torch.float16
         )
 
-        # Move model to CPU explicitly
-        model = model.to('cpu')
+        # Move to CPU and set to eval mode
+        model = model.to(device)
         model.eval()
+
+        # Disable gradient computation
         torch.set_grad_enabled(False)
 
         # Store in session state
        st.session_state.tokenizer = tokenizer
        st.session_state.model = model
-        st.session_state.model_initialized = True
 
        print("✅ Model loaded successfully")
-
        return tokenizer, model
+
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
-        # Return None values instead of raising to avoid crashing
+        # Return None values - don't raise exception
        return None, None
 
 def download_file_from_gcs(bucket, gcs_path, local_path):
     """Download a file from GCS to local storage."""
     try:
-        # Check if file already exists locally
+        # Check if file already exists
         if os.path.exists(local_path):
             print(f"File already exists locally: {local_path}")
             return True
@@ -118,12 +107,13 @@ def download_file_from_gcs(bucket, gcs_path, local_path):
         return False
 
 def load_data_files():
+    """Load FAISS index, text chunks, and metadata"""
     # Check if already loaded in session state
-    if hasattr(st.session_state, 'faiss_index') and st.session_state.faiss_index is not None:
+    if 'faiss_index' in st.session_state and st.session_state.faiss_index is not None:
         print("Using cached data files from session state")
         return st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict
 
-    # Initialize GCP and OpenAI clients
+    # Initialize clients
     bucket = setup_gcp_client()
     openai_initialized = setup_openai_client()
 
@@ -160,24 +150,23 @@ def load_data_files():
         print(f"❌ Error loading text chunks: {str(e)}")
         return None, None, None
 
-    # Load metadata.jsonl for publisher information
+    # Load metadata
     try:
         metadata_dict = {}
         with open(local_metadata_file, "r", encoding="utf-8") as f:
             for line in f:
                 item = json.loads(line)
-                metadata_dict[item["Title"]] = item # Store for easy lookup
+                metadata_dict[item["Title"]] = item
     except Exception as e:
         print(f"❌ Error loading metadata: {str(e)}")
         return None, None, None
 
-    print(f"✅ FAISS index and text chunks loaded. {len(text_chunks)} passages available.")
+    print(f"✅ Data loaded successfully: {len(text_chunks)} passages available")
 
     # Store in session state
     st.session_state.faiss_index = faiss_index
     st.session_state.text_chunks = text_chunks
     st.session_state.metadata_dict = metadata_dict
-    st.session_state.data_loaded = True
 
     return faiss_index, text_chunks, metadata_dict
 
@@ -186,25 +175,31 @@ def average_pool(last_hidden_states, attention_mask):
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
 
+# Cache for query embeddings
 query_embedding_cache = {}
 
 def get_embedding(text):
+    """Generate embeddings for a text query"""
+    # Check cache first
     if text in query_embedding_cache:
         return query_embedding_cache[text]
 
     try:
-        # Ensure model initialization
-        if not hasattr(st.session_state, 'model') or st.session_state.model is None:
+        # Get model
+        if 'model' not in st.session_state or st.session_state.model is None:
             tokenizer, model = load_model()
-            if model is None:
-                return np.zeros((1, 384), dtype=np.float32)
         else:
             tokenizer, model = st.session_state.tokenizer, st.session_state.model
 
+        # Handle model load failure
+        if model is None:
+            print("Model is None, returning zero embedding")
+            return np.zeros((1, 384), dtype=np.float32)
+
         # Prepare text
         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
 
-        # Explicitly specify truncation parameters
+        # Tokenize
         inputs = tokenizer(
             input_text,
             padding=True,
@@ -214,20 +209,18 @@ def get_embedding(text):
             return_attention_mask=True
         )
 
-        # Move to CPU explicitly
-        inputs = {k: v.to('cpu') for k, v in inputs.items()}
-
+        # Generate embeddings
         with torch.no_grad():
             outputs = model(**inputs)
             embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
             embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
             embeddings = embeddings.detach().cpu().numpy()
 
-        # Explicitly clean up
+        # Clean up
         del outputs, inputs
         gc.collect()
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
+        # Cache and return
         query_embedding_cache[text] = embeddings
         return embeddings
     except Exception as e:
@@ -238,7 +231,11 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
     """Retrieve top-k most relevant passages using FAISS with metadata."""
     try:
         print(f"\n🔍 Retrieving passages for query: {query}")
+
+        # Get query embedding
         query_embedding = get_embedding(query)
+
+        # Search in FAISS index
         distances, indices = faiss_index.search(query_embedding, top_k * 2)
 
         print(f"Found {len(distances[0])} potential matches")
@@ -246,29 +243,31 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         retrieved_sources = []
         cited_titles = set()
 
+        # Process results
         for dist, idx in zip(distances[0], indices[0]):
             print(f"Distance: {dist:.4f}, Index: {idx}")
             if idx in text_chunks and dist >= similarity_threshold:
                 title_with_txt, author, text = text_chunks[idx]
 
-                # Normalize title and remove .txt
+                # Clean title
                 clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                 clean_title = unicodedata.normalize("NFC", clean_title)
 
-                # Ensure unique citations
+                # Skip duplicates
                 if clean_title in cited_titles:
                     continue
 
-                # Get metadata safely
+                # Get metadata
                 metadata_entry = metadata_dict.get(clean_title, {})
                 author = metadata_entry.get("Author", "Unknown")
                 publisher = metadata_entry.get("Publisher", "Unknown")
 
+                # Add to results
                 cited_titles.add(clean_title)
-
                 retrieved_passages.append(text)
                 retrieved_sources.append((clean_title, author, publisher))
 
+                # Stop if we have enough
                 if len(retrieved_passages) == top_k:
                     break
 
@@ -279,10 +278,9 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
         return [], []
 
 def answer_with_llm(query, context=None, word_limit=100):
-    """
-    Generate an answer using OpenAI GPT model with formatted citations.
-    """
+    """Generate an answer using OpenAI GPT model with formatted citations."""
     try:
+        # Format context
         if context:
             formatted_contexts = []
             total_chars = 0
@@ -312,6 +310,7 @@ def answer_with_llm(query, context=None, word_limit=100):
             "Ensure proper citation and do not include direct excerpts."
         )
 
+        # User message
         user_message = f"""
         Context:
         {formatted_context}
@@ -319,6 +318,7 @@
         {query}
         """
 
+        # Call OpenAI API
         response = openai.chat.completions.create(
             model="gpt-3.5-turbo",
             messages=[
@@ -371,6 +371,7 @@ def process_query(query, top_k=5, word_limit=100):
             "citations": "No citations available."
         }
 
+    # Get relevant passages
     retrieved_context, retrieved_sources = retrieve_passages(
         query,
         faiss_index,
@@ -379,8 +380,10 @@
         top_k=top_k
     )
 
+    # Format citations
     sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."
 
+    # Generate answer
     if retrieved_context:
         context_with_sources = list(zip(retrieved_sources, retrieved_context))
         llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
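A minimal usage sketch (not part of this commit): assuming rag_engine.py is imported from a Streamlit app, the refactored pipeline might be driven as below. The app wiring is an illustrative assumption; only process_query, its signature (query, top_k=5, word_limit=100), and the "citations" field come from this file.

# app.py - illustrative sketch only, not part of this commit
import streamlit as st
from rag_engine import process_query

st.title("RAG demo")
query = st.text_input("Ask a question")

if query:
    # process_query retrieves passages via the FAISS index and queries the LLM;
    # the embedding model and data files are cached in st.session_state by rag_engine.
    result = process_query(query, top_k=5, word_limit=100)
    st.write(result)  # returned dict includes a "citations" field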
 