ankanghosh committed on
Commit 38deecc · verified · 1 Parent(s): 40ca1f2

Update rag_engine.py

Files changed (1)
  1. rag_engine.py +85 -52
rag_engine.py CHANGED
@@ -12,13 +12,15 @@ import unicodedata
 import streamlit as st
 from utils import setup_gcp_auth, setup_openai_auth
 
-# Initialize session state for model and tokenizer
 if 'model' not in st.session_state:
     st.session_state.model = None
 if 'tokenizer' not in st.session_state:
     st.session_state.tokenizer = None
 if 'device' not in st.session_state:
-    st.session_state.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {st.session_state.device}")
 
 # Load GCP authentication from utility function
@@ -58,58 +60,86 @@ def load_model():
         # Force model to CPU - more stable than GPU for this use case
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
-        print("Loading tokenizer...")
-        tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
-
-        print("Loading model...")
-        model = AutoModel.from_pretrained(
-            "intfloat/e5-small-v2",
-            torch_dtype=torch.float16,  # Use half precision
-            low_cpu_mem_usage=True,
-            device_map="auto"  # Let transformers decide
-        )
-
-        model.eval()
-        torch.set_grad_enabled(False)
-
-        st.session_state.tokenizer = tokenizer
-        st.session_state.model = model
-
-        print("✅ Model loaded successfully")
 
         return st.session_state.tokenizer, st.session_state.model
     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
         raise
 
 def download_file_from_gcs(gcs_path, local_path):
     """Download a file from GCS to local storage."""
-    blob = bucket.blob(gcs_path)
-    blob.download_to_filename(local_path)
-    print(f"✅ Downloaded {gcs_path} → {local_path}")
-
-# Download necessary files
-download_file_from_gcs(faiss_index_file_gcs, local_faiss_index_file)
-download_file_from_gcs(text_chunks_file_gcs, local_text_chunks_file)
-download_file_from_gcs(metadata_file_gcs, local_metadata_file)
-
-# Load FAISS index
-faiss_index = faiss.read_index(local_faiss_index_file)
-
-# Load text chunks
-text_chunks = {}  # {ID -> (Title, Author, Text)}
-with open(local_text_chunks_file, "r", encoding="utf-8") as f:
-    for line in f:
-        parts = line.strip().split("\t")
-        if len(parts) == 4:
-            text_chunks[int(parts[0])] = (parts[1], parts[2], parts[3])
-
-# Load metadata.jsonl for publisher information
-metadata_dict = {}
-with open(local_metadata_file, "r", encoding="utf-8") as f:
-    for line in f:
-        item = json.loads(line)
-        metadata_dict[item["Title"]] = item  # Store for easy lookup
 
 print(f"✅ FAISS index and text chunks loaded. {len(text_chunks)} passages available.")
 
@@ -155,7 +185,8 @@ def get_embedding(text)
         return embeddings
     except Exception as e:
         print(f"❌ Embedding error: {str(e)}")
-        return np.zeros((1, 1024), dtype=np.float32)
 
 def retrieve_passages(query, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
@@ -198,6 +229,7 @@ def retrieve_passages(query, top_k=5, similarity_threshold=0.5):
         return retrieved_passages, retrieved_sources
     except Exception as e:
         print(f"❌ Error in retrieve_passages: {str(e)}")
         return [], []
 
 def answer_with_llm(query, context=None, word_limit=100):
@@ -265,8 +297,13 @@ def answer_with_llm(query, context=None, word_limit=100):
 
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."
 
 def process_query(query, top_k=5, word_limit=100):
     """Process a query through the RAG pipeline with proper formatting."""
     print(f"\n🔍 Processing query: {query}")
@@ -280,8 +317,4 @@ def process_query(query, top_k=5, word_limit=100):
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
 
-    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}
-
-def format_citations(sources):
-    """Format citations to display each one on a new line."""
-    return "\n".join([f"📚 {title} by {author}, Published by {publisher}" for title, author, publisher in sources])
 import streamlit as st
 from utils import setup_gcp_auth, setup_openai_auth
 
+# Initialize session state for model and tokenizer FIRST - before any usage
 if 'model' not in st.session_state:
     st.session_state.model = None
+    print("Initialized st.session_state.model to None")
 if 'tokenizer' not in st.session_state:
     st.session_state.tokenizer = None
+    print("Initialized st.session_state.tokenizer to None")
 if 'device' not in st.session_state:
+    st.session_state.device = torch.device("cpu")  # Force CPU for stability
     print(f"Using device: {st.session_state.device}")
 
 # Load GCP authentication from utility function
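The session-state initialization above matters because Streamlit re-executes the whole script on every user interaction. A minimal sketch of the early-return guard that load_model() presumably relies on during reruns (the guard itself is outside this diff, so treat it as an assumption):

```python
import streamlit as st

def load_model():
    # On a Streamlit rerun, reuse the objects cached in session state
    # instead of reloading the tokenizer and model from disk.
    if st.session_state.model is not None and st.session_state.tokenizer is not None:
        return st.session_state.tokenizer, st.session_state.model
    # ...otherwise fall through to the loading code shown below...
```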
 
         # Force model to CPU - more stable than GPU for this use case
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
+        with st.spinner("Loading tokenizer and model... This may take a minute."):
+            print("Loading tokenizer...")
+            tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
+
+            print("Loading model...")
+            model = AutoModel.from_pretrained(
+                "intfloat/e5-small-v2",
+                torch_dtype=torch.float16,  # Use half precision
+                low_cpu_mem_usage=True,
+                # Remove device_map - it requires accelerate and causes issues
+            )
+
+            model.eval()
+            torch.set_grad_enabled(False)
+
+            st.session_state.tokenizer = tokenizer
+            st.session_state.model = model
+
+            print("✅ Model loaded successfully")
 
         return st.session_state.tokenizer, st.session_state.model
     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
+        st.error(f"Error loading model: {str(e)}")
         raise
 
 def download_file_from_gcs(gcs_path, local_path):
     """Download a file from GCS to local storage."""
+    try:
+        blob = bucket.blob(gcs_path)
+        blob.download_to_filename(local_path)
+        print(f"✅ Downloaded {gcs_path} → {local_path}")
+    except Exception as e:
+        print(f"❌ Error downloading {gcs_path}: {str(e)}")
+        st.error(f"Error downloading {gcs_path}: {str(e)}")
+        raise
+
+# Add error handling around file downloads
+try:
+    # Download necessary files with a spinner to show progress
+    with st.spinner("Downloading necessary files..."):
+        download_file_from_gcs(faiss_index_file_gcs, local_faiss_index_file)
+        download_file_from_gcs(text_chunks_file_gcs, local_text_chunks_file)
+        download_file_from_gcs(metadata_file_gcs, local_metadata_file)
+except Exception as e:
+    st.error(f"Error setting up data files: {str(e)}")
+    raise
+
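Note that bucket is used here but created earlier in rag_engine.py, outside this diff. A minimal sketch of how a google-cloud-storage handle like this is typically obtained; the bucket name and blob paths are placeholders, not values taken from the repository:

```python
from google.cloud import storage

# Hypothetical setup; in rag_engine.py the credentials come from setup_gcp_auth().
client = storage.Client()
bucket = client.bucket("my-example-bucket")

# download_file_from_gcs() essentially wraps these two calls:
blob = bucket.blob("data/index.faiss")
blob.download_to_filename("/tmp/index.faiss")
```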
+# Load FAISS index with error handling
+try:
+    faiss_index = faiss.read_index(local_faiss_index_file)
+except Exception as e:
+    print(f"❌ Error loading FAISS index: {str(e)}")
+    st.error(f"Error loading FAISS index: {str(e)}")
+    raise
+
+# Load text chunks with error handling
+try:
+    text_chunks = {}  # {ID -> (Title, Author, Text)}
+    with open(local_text_chunks_file, "r", encoding="utf-8") as f:
+        for line in f:
+            parts = line.strip().split("\t")
+            if len(parts) == 4:
+                text_chunks[int(parts[0])] = (parts[1], parts[2], parts[3])
+except Exception as e:
+    print(f"❌ Error loading text chunks: {str(e)}")
+    st.error(f"Error loading text chunks: {str(e)}")
+    raise
+
+# Load metadata.jsonl for publisher information with error handling
+try:
+    metadata_dict = {}
+    with open(local_metadata_file, "r", encoding="utf-8") as f:
+        for line in f:
+            item = json.loads(line)
+            metadata_dict[item["Title"]] = item  # Store for easy lookup
+except Exception as e:
+    print(f"❌ Error loading metadata: {str(e)}")
+    st.error(f"Error loading metadata: {str(e)}")
+    raise
 
 print(f"✅ FAISS index and text chunks loaded. {len(text_chunks)} passages available.")
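With faiss_index, text_chunks, and metadata_dict in memory, retrieval reduces to a FAISS search plus dictionary lookups. A minimal sketch of that step, assuming the index holds the same 384-dimensional vectors that get_embedding() produces and that the metadata records carry a "Publisher" field (both assumptions, since neither is shown in this diff):

```python
import numpy as np

def search_index(query_embedding, top_k=5):
    # FAISS expects a float32 matrix of shape (n_queries, dim).
    query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
    distances, indices = faiss_index.search(query, top_k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx == -1:  # FAISS pads with -1 when fewer than top_k neighbors exist
            continue
        title, author, text = text_chunks[int(idx)]
        # "Publisher" is an assumed key; the metadata schema lives outside this diff.
        publisher = metadata_dict.get(title, {}).get("Publisher", "Unknown")
        results.append((float(dist), title, author, publisher, text))
    return results
```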
         return embeddings
     except Exception as e:
         print(f"❌ Embedding error: {str(e)}")
+        st.error(f"Embedding error: {str(e)}")
+        return np.zeros((1, 384), dtype=np.float32)  # Changed from 1024 to 384 for e5-small-v2
 
 def retrieve_passages(query, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
 
         return retrieved_passages, retrieved_sources
     except Exception as e:
         print(f"❌ Error in retrieve_passages: {str(e)}")
+        st.error(f"Error in retrieve_passages: {str(e)}")
         return [], []
 
 def answer_with_llm(query, context=None, word_limit=100):
 
     except Exception as e:
         print(f"❌ LLM API error: {str(e)}")
+        st.error(f"LLM API error: {str(e)}")
         return "I apologize, but I'm unable to answer at the moment."
 
+def format_citations(sources):
+    """Format citations to display each one on a new line."""
+    return "\n".join([f"📚 {title} by {author}, Published by {publisher}" for title, author, publisher in sources])
+
 def process_query(query, top_k=5, word_limit=100):
     """Process a query through the RAG pipeline with proper formatting."""
     print(f"\n🔍 Processing query: {query}")
 
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
 
+    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}