ankanghosh committed
Commit 3b2ec72 · verified · 1 Parent(s): ac3798b

Update rag_engine.py

Files changed (1)
  1. rag_engine.py  +17 -37
rag_engine.py CHANGED
@@ -66,7 +66,8 @@ local_metadata_file = "metadata.jsonl"
 
 def load_model():
     try:
-        if st.session_state.model is None:
+        # Initialize model if it doesn't exist
+        if 'model' not in st.session_state or st.session_state.model is None:
             # Force model to CPU - more stable than GPU for this use case
             os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
@@ -79,20 +80,22 @@ def load_model():
                 torch_dtype=torch.float16 # Use half precision
             )
 
-            # Move model to the designated device
-            model = model.to(st.session_state.device)
+            # Move model to CPU explicitly
+            model = model.to('cpu')
             model.eval()
             torch.set_grad_enabled(False)
 
+            # Store in session state
             st.session_state.tokenizer = tokenizer
             st.session_state.model = model
 
             print("✅ Model loaded successfully")
-
+
         return st.session_state.tokenizer, st.session_state.model
     except Exception as e:
         print(f"❌ Error loading model: {str(e)}")
-        raise
+        # Return None values instead of raising to avoid crashing
+        return None, None
 
 def download_file_from_gcs(bucket, gcs_path, local_path):
     """Download a file from GCS to local storage."""
@@ -172,41 +175,18 @@ query_embedding_cache = {}
 def get_embedding(text):
     if text in query_embedding_cache:
         return query_embedding_cache[text]
-
     try:
-        tokenizer, model = load_model()
-        input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
-
-        # Explicitly specify truncation parameters to avoid warnings
-        inputs = tokenizer(
-            input_text,
-            padding=True,
-            truncation=True,
-            return_tensors="pt",
-            max_length=512,
-            return_attention_mask=True
-        )
-
-        # Move to CPU explicitly before processing
-        inputs = {k: v.to('cpu') for k, v in inputs.items()}
+        # Ensure model initialization
+        if 'model' not in st.session_state or st.session_state.model is None:
+            tokenizer, model = load_model()
+            if model is None:
+                return np.zeros((1, 384), dtype=np.float32) # Fallback
+        else:
+            tokenizer, model = st.session_state.tokenizer, st.session_state.model
 
-        with torch.no_grad():
-            outputs = model(**inputs)
-            embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
-            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
-            # Ensure we detach and move to numpy on CPU
-            embeddings = embeddings.detach().cpu().numpy()
-
-        # Explicitly clean up
-        del outputs
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
+        input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
 
-        query_embedding_cache[text] = embeddings
-        return embeddings
-    except Exception as e:
-        print(f"❌ Embedding error: {str(e)}")
-        st.error(f"Embedding error: {str(e)}")
-        return np.zeros((1, 384), dtype=np.float32) # Changed from 1024 to 384 for e5-small-v2
+        # Rest of your code...
 
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
 