ankanghosh committed on
Commit
62d1e75
·
verified ·
1 Parent(s): b7d14dc

Upload application files.

Files changed (3)
  1. rag_engine.py +287 -0
  2. requirements.txt +8 -0
  3. utils.py +95 -0
rag_engine.py ADDED
@@ -0,0 +1,287 @@
+ import os
+ import json
+ import numpy as np
+ import faiss
+ import torch
+ import torch.nn as nn
+ from google.cloud import storage
+ from transformers import AutoTokenizer, AutoModel
+ import openai
+ import unicodedata
+ import streamlit as st
+ from utils import setup_gcp_auth, setup_openai_auth
+
+ # Initialize session state for the model, tokenizer, and device
+ if 'model' not in st.session_state:
+     st.session_state.model = None
+ if 'tokenizer' not in st.session_state:
+     st.session_state.tokenizer = None
+ if 'device' not in st.session_state:
+     st.session_state.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {st.session_state.device}")
+
+ # Load GCP authentication from the utility function and build the storage client
+ try:
+     credentials = setup_gcp_auth()
+     storage_client = storage.Client(credentials=credentials)
+     bucket_name = "indian_spiritual-1"
+     bucket = storage_client.bucket(bucket_name)
+     print("✅ GCP client initialized successfully")
+ except Exception as e:
+     print(f"❌ GCP client initialization error: {str(e)}")
+     raise
+
+ # Set up OpenAI authentication
+ try:
+     setup_openai_auth()
+     print("✅ OpenAI client initialized successfully")
+ except Exception as e:
+     print(f"❌ OpenAI client initialization error: {str(e)}")
+     raise
+
+ # GCS paths
+ metadata_file_gcs = "metadata/metadata.jsonl"
+ embeddings_file_gcs = "processed/embeddings/all_embeddings.npy"
+ faiss_index_file_gcs = "processed/indices/faiss_index.faiss"
+ text_chunks_file_gcs = "processed/chunks/text_chunks.txt"
+
+ # Local paths
+ local_embeddings_file = "all_embeddings.npy"
+ local_faiss_index_file = "faiss_index.faiss"
+ local_text_chunks_file = "text_chunks.txt"
+ local_metadata_file = "metadata.jsonl"
+
+ def load_model():
+     """Load and cache the E5 tokenizer and model in Streamlit session state."""
+     try:
+         if st.session_state.model is None:
+             # Force the model onto CPU - more stable than GPU for this use case
+             os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+             print("Loading tokenizer...")
+             tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")
+
+             print("Loading model...")
+             model = AutoModel.from_pretrained(
+                 "intfloat/e5-small-v2",
+                 torch_dtype=torch.float16,  # Half precision to cut memory use
+                 low_cpu_mem_usage=True,     # Requires the accelerate package
+                 device_map="auto"           # Let transformers place the weights
+             )
+
+             model.eval()
+             torch.set_grad_enabled(False)
+
+             st.session_state.tokenizer = tokenizer
+             st.session_state.model = model
+
+             print("✅ Model loaded successfully")
+
+         return st.session_state.tokenizer, st.session_state.model
+     except Exception as e:
+         print(f"❌ Error loading model: {str(e)}")
+         raise
+
+ def download_file_from_gcs(gcs_path, local_path):
+     """Download a file from GCS to local storage."""
+     blob = bucket.blob(gcs_path)
+     blob.download_to_filename(local_path)
+     print(f"✅ Downloaded {gcs_path} → {local_path}")
+
+ # Download the files needed at startup
+ download_file_from_gcs(faiss_index_file_gcs, local_faiss_index_file)
+ download_file_from_gcs(text_chunks_file_gcs, local_text_chunks_file)
+ download_file_from_gcs(metadata_file_gcs, local_metadata_file)
+
+ # Load the FAISS index
+ faiss_index = faiss.read_index(local_faiss_index_file)
+
+ # Load text chunks: one tab-separated record per line
+ text_chunks = {}  # {ID -> (Title, Author, Text)}
+ with open(local_text_chunks_file, "r", encoding="utf-8") as f:
+     for line in f:
+         parts = line.strip().split("\t")
+         if len(parts) == 4:
+             text_chunks[int(parts[0])] = (parts[1], parts[2], parts[3])
+
+ # Load metadata.jsonl, keyed by title, for author and publisher lookups
+ metadata_dict = {}
+ with open(local_metadata_file, "r", encoding="utf-8") as f:
+     for line in f:
+         item = json.loads(line)
+         metadata_dict[item["Title"]] = item
+
+ print(f"✅ FAISS index and text chunks loaded. {len(text_chunks)} passages available.")
+
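+ # Illustrative text_chunks.txt record (tab-separated; the values below are
+ # placeholders, not real data):
+ #
+ #     12\tSome Title.txt\tSome Author\tPassage text ...
+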
+ def average_pool(last_hidden_states, attention_mask):
+     """Average pooling for sentence embeddings, ignoring padded positions."""
+     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+ # Cache query embeddings so repeated questions skip the model forward pass
+ query_embedding_cache = {}
+
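+ # Shape sketch for average_pool (illustrative): a single query tokenized to 8
+ # tokens with e5-small-v2 gives last_hidden_states of shape (1, 8, 384) and
+ # attention_mask of shape (1, 8); masking out padding, summing over dim=1, and
+ # dividing by the real token count yields one (1, 384) sentence embedding.
+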
+ def get_embedding(text):
+     """Embed text with E5; returns a (1, 384) L2-normalized numpy array."""
+     if text in query_embedding_cache:
+         return query_embedding_cache[text]
+
+     try:
+         tokenizer, model = load_model()
+         # E5 expects a "query: " or "passage: " prefix; short inputs are treated as queries
+         input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"
+
+         inputs = tokenizer(
+             input_text,
+             padding=True,
+             truncation=True,
+             return_tensors="pt",
+             max_length=512,
+             return_attention_mask=True
+         )
+
+         # Move tensors to CPU explicitly before processing
+         inputs = {k: v.to('cpu') for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+             embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
+             embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
+             # Detach and move to numpy on the CPU
+             embeddings = embeddings.detach().cpu().numpy()
+
+         # Explicitly clean up
+         del outputs
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         query_embedding_cache[text] = embeddings
+         return embeddings
+     except Exception as e:
+         print(f"❌ Embedding error: {str(e)}")
+         # Fallback must match the 384-dimensional output of e5-small-v2
+         return np.zeros((1, 384), dtype=np.float32)
+
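+ # Example (illustrative; the queries are hypothetical): because the embeddings
+ # are L2-normalized above, a dot product equals cosine similarity:
+ #
+ #     a = get_embedding("What is dharma?")
+ #     b = get_embedding("the meaning of dharma")
+ #     score = (a @ b.T).item()  # in [-1, 1]; higher means more similar
+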
+ def retrieve_passages(query, top_k=5, similarity_threshold=0.5):
+     """Retrieve the top-k most relevant passages via FAISS, with source metadata."""
+     try:
+         print(f"\n🔍 Retrieving passages for query: {query}")
+         query_embedding = get_embedding(query)
+         # Over-fetch so top_k passages remain after deduplicating titles
+         distances, indices = faiss_index.search(query_embedding, top_k * 2)
+
+         print(f"Found {len(distances[0])} potential matches")
+         retrieved_passages = []
+         retrieved_sources = []
+         cited_titles = set()
+
+         for dist, idx in zip(distances[0], indices[0]):
+             print(f"Distance: {dist:.4f}, Index: {idx}")
+             if idx in text_chunks and dist >= similarity_threshold:
+                 title_with_txt, author, text = text_chunks[idx]
+
+                 # Strip the .txt extension and normalize the title
+                 clean_title = title_with_txt[:-4] if title_with_txt.endswith(".txt") else title_with_txt
+                 clean_title = unicodedata.normalize("NFC", clean_title)
+
+                 # Cite each source at most once
+                 if clean_title in cited_titles:
+                     continue
+
+                 metadata_entry = metadata_dict.get(clean_title, {})
+                 author = metadata_entry.get("Author", "Unknown")
+                 publisher = metadata_entry.get("Publisher", "Unknown")
+
+                 cited_titles.add(clean_title)
+
+                 retrieved_passages.append(text)
+                 retrieved_sources.append((clean_title, author, publisher))
+
+                 if len(retrieved_passages) == top_k:
+                     break
+
+         print(f"Retrieved {len(retrieved_passages)} passages")
+         return retrieved_passages, retrieved_sources
+     except Exception as e:
+         print(f"❌ Error in retrieve_passages: {str(e)}")
+         return [], []
+
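+ # Note: the `dist >= similarity_threshold` check assumes an inner-product FAISS
+ # index (e.g. IndexFlatIP), where scores of L2-normalized embeddings are cosine
+ # similarities and higher is better; with an L2-distance index, smaller values
+ # are better and the comparison would need to be inverted.
+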
+ def answer_with_llm(query, context=None, word_limit=100):
+     """Generate an answer with an OpenAI chat model, citing the given sources."""
+     try:
+         if context:
+             formatted_contexts = []
+             total_chars = 0
+             max_context_chars = 4000
+
+             # Build short, attributed excerpts until the character budget is spent
+             for (title, author, publisher), text in context:
+                 remaining_space = max(0, max_context_chars - total_chars)
+                 excerpt_len = min(150, remaining_space)
+
+                 if excerpt_len > 50:
+                     excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
+                     formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
+                     formatted_contexts.append(formatted_context)
+                     total_chars += len(formatted_context)
+
+                 if total_chars >= max_context_chars:
+                     break
+
+             formatted_context = "\n".join(formatted_contexts)
+         else:
+             formatted_context = "No relevant information available."
+
+         # System message
+         system_message = (
+             "You are an AI specialized in Indian spiritual texts. "
+             "Answer based on context, summarizing ideas rather than quoting verbatim. "
+             "Ensure proper citation and do not include direct excerpts."
+         )
+
+         user_message = f"""
+         Context:
+         {formatted_context}
+
+         Question:
+         {query}
+         """
+
+         response = openai.chat.completions.create(
+             model="gpt-3.5-turbo",
+             messages=[
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": user_message}
+             ],
+             max_tokens=200,
+             temperature=0.7
+         )
+
+         answer = response.choices[0].message.content.strip()
+
+         # Enforce the word limit and end on sentence-final punctuation
+         words = answer.split()
+         if len(words) > word_limit:
+             answer = " ".join(words[:word_limit])
+             if not answer.endswith((".", "!", "?")):
+                 answer += "."
+
+         return answer
+
+     except Exception as e:
+         print(f"❌ LLM API error: {str(e)}")
+         return "I apologize, but I'm unable to answer at the moment."
+
+ def process_query(query, top_k=5, word_limit=100):
+     """Process a query through the full RAG pipeline and format the result."""
+     print(f"\n🔍 Processing query: {query}")
+
+     retrieved_context, retrieved_sources = retrieve_passages(query, top_k=top_k)
+     sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."
+
+     if retrieved_context:
+         context_with_sources = list(zip(retrieved_sources, retrieved_context))
+         llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
+     else:
+         llm_answer_with_rag = "⚠️ No relevant context found."
+
+     return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}
+
+ def format_citations(sources):
+     """Format citations so each one appears on its own line."""
+     return "\n".join([f"📚 {title} by {author}, Published by {publisher}" for title, author, publisher in sources])
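+
+ # Usage sketch (illustrative; the query below is a hypothetical example):
+ #
+ #     result = process_query("What does the Bhagavad Gita say about duty?")
+ #     print(result["answer_with_rag"])
+ #     print(result["citations"])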
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ faiss-cpu==1.10.0
+ torch
+ transformers==4.38.2
+ accelerate
+ openai==1.14.1
+ google-cloud-storage==2.14.0
+ google-auth>=2.28.1
+ streamlit>=1.32.0
utils.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import json
+ from google.oauth2 import service_account
+ import streamlit as st
+ import openai
+
+ def setup_gcp_auth():
+     """Set up GCP authentication from HF Spaces, environment variables, or Streamlit secrets."""
+     try:
+         # Option 1: HF Spaces environment variable
+         if "GCP_CREDENTIALS" in os.environ:
+             gcp_credentials = json.loads(os.getenv("GCP_CREDENTIALS"))
+             print("✅ Using GCP credentials from HF Spaces environment variable")
+             credentials = service_account.Credentials.from_service_account_info(gcp_credentials)
+             return credentials
+
+         # Option 2: local environment variable pointing to a credentials file
+         elif "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
+             credentials_path = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
+             print(f"✅ Using GCP credentials from file at {credentials_path}")
+             credentials = service_account.Credentials.from_service_account_file(credentials_path)
+             return credentials
+
+         # Option 3: Streamlit secrets
+         elif "gcp_credentials" in st.secrets:
+             gcp_credentials = st.secrets["gcp_credentials"]
+
+             # Handle the different formats the secret can take
+             if isinstance(gcp_credentials, dict) or hasattr(gcp_credentials, 'to_dict'):
+                 # Convert AttrDict to a plain dict if needed
+                 if hasattr(gcp_credentials, 'to_dict'):
+                     gcp_credentials = gcp_credentials.to_dict()
+
+                 print("✅ Using GCP credentials from Streamlit secrets (dict format)")
+                 credentials = service_account.Credentials.from_service_account_info(gcp_credentials)
+                 return credentials
+             else:
+                 # Assume the secret is a JSON string
+                 try:
+                     gcp_credentials_dict = json.loads(gcp_credentials)
+                     print("✅ Using GCP credentials from Streamlit secrets (JSON string)")
+                     credentials = service_account.Credentials.from_service_account_info(gcp_credentials_dict)
+                     return credentials
+                 except json.JSONDecodeError:
+                     print("⚠️ GCP credentials in Streamlit secrets are not valid JSON; trying as a file path")
+                     if os.path.exists(gcp_credentials):
+                         credentials = service_account.Credentials.from_service_account_file(gcp_credentials)
+                         return credentials
+                     else:
+                         raise ValueError("GCP credentials format not recognized")
+
+         else:
+             raise ValueError("No GCP credentials found in environment or Streamlit secrets")
+
+     except Exception as e:
+         error_msg = f"❌ Authentication error: {str(e)}"
+         print(error_msg)
+         st.error(error_msg)
+         raise
+
+ def setup_openai_auth():
+     """Set up OpenAI API authentication from environment variables or Streamlit secrets."""
+     try:
+         # Option 1: standard environment variable
+         if "OPENAI_API_KEY" in os.environ:
+             openai.api_key = os.getenv("OPENAI_API_KEY")
+             print("✅ Using OpenAI API key from environment variable")
+             return
+
+         # Option 2: HF Spaces environment variable with a different name
+         elif "OPENAI_KEY" in os.environ:
+             openai.api_key = os.getenv("OPENAI_KEY")
+             print("✅ Using OpenAI API key from HF Spaces environment variable")
+             return
+
+         # Option 3: Streamlit secrets
+         elif "openai_api_key" in st.secrets:
+             openai.api_key = st.secrets["openai_api_key"]
+             print("✅ Using OpenAI API key from Streamlit secrets")
+             return
+
+         else:
+             raise ValueError("No OpenAI API key found in environment or Streamlit secrets")
+
+     except Exception as e:
+         error_msg = f"❌ OpenAI authentication error: {str(e)}"
+         print(error_msg)
+         st.error(error_msg)
+         raise
+
+ def setup_all_auth():
+     """Set up all authentication in one call."""
+     gcp_creds = setup_gcp_auth()
+     setup_openai_auth()
+     return gcp_creds
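+
+ # Illustrative .streamlit/secrets.toml layout these helpers can read
+ # (all values below are placeholders):
+ #
+ #     openai_api_key = "sk-..."
+ #
+ #     [gcp_credentials]
+ #     type = "service_account"
+ #     project_id = "your-project-id"
+ #     private_key = "-----BEGIN PRIVATE KEY-----\n..."
+ #     client_email = "sa-name@your-project-id.iam.gserviceaccount.com"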