import os
import json
import numpy as np
import faiss
import torch
import torch.nn as nn
from google.cloud import storage
from transformers import AutoTokenizer, AutoModel
import openai
import textwrap
import unicodedata
import streamlit as st
from utils import setup_gcp_auth, setup_openai_auth

# Force model to CPU for stability
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Create a function to initialize session state
def initialize_session_state():
    if 'model_initialized' not in st.session_state:
        st.session_state.model_initialized = False
        st.session_state.model = None
        st.session_state.tokenizer = None
        st.session_state.device = torch.device("cpu")
        print("Initialized session state variables")

# Call the initialization function right away
initialize_session_state()

# Load GCP authentication from utility function
def setup_gcp_client():
    try:
        credentials = setup_gcp_auth()
        storage_client = storage.Client(credentials=credentials)
        bucket_name = "indian_spiritual-1"
        bucket = storage_client.bucket(bucket_name)
        print("✅ GCP client initialized successfully")
        return bucket
    except Exception as e:
        print(f"❌ GCP client initialization error: {str(e)}")
        st.error(f"GCP client initialization error: {str(e)}")
        return None

# Setup OpenAI authentication
def setup_openai_client():
    try:
        setup_openai_auth()
        print("✅ OpenAI client initialized successfully")
        return True
    except Exception as e:
        print(f"❌ OpenAI client initialization error: {str(e)}")
        st.error(f"OpenAI client initialization error: {str(e)}")
        return False

# GCS Paths
metadata_file_gcs = "metadata/metadata.jsonl"
embeddings_file_gcs = "processed/embeddings/all_embeddings.npy"
faiss_index_file_gcs = "processed/indices/faiss_index.faiss"
text_chunks_file_gcs = "processed/chunks/text_chunks.txt"

# Local Paths
local_embeddings_file = "all_embeddings.npy"
local_faiss_index_file = "faiss_index.faiss"
local_text_chunks_file = "text_chunks.txt"
local_metadata_file = "metadata.jsonl"

def load_model():
    try:
        # Initialize model if it doesn't exist
        if 'model' not in st.session_state or st.session_state.model is None:
            # Force model to CPU - more stable than GPU for this use case
            os.environ["CUDA_VISIBLE_DEVICES"] = ""

            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")

            print("Loading model...")
            model = AutoModel.from_pretrained(
                "intfloat/e5-small-v2",
                torch_dtype=torch.float16  # Use half precision
            )

            # Move model to CPU explicitly
            model = model.to('cpu')
            model.eval()
            torch.set_grad_enabled(False)

            # Store in session state
            st.session_state.tokenizer = tokenizer
            st.session_state.model = model
            print("✅ Model loaded successfully")

        return st.session_state.tokenizer, st.session_state.model
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        # Return None values instead of raising to avoid crashing
        return None, None

def download_file_from_gcs(bucket, gcs_path, local_path):
    """Download a file from GCS to local storage."""
    try:
        blob = bucket.blob(gcs_path)
        blob.download_to_filename(local_path)
        print(f"✅ Downloaded {gcs_path} → {local_path}")
        return True
    except Exception as e:
        print(f"❌ Error downloading {gcs_path}: {str(e)}")
        st.error(f"Error downloading {gcs_path}: {str(e)}")
        return False

def load_data_files():
    # Initialize GCP and OpenAI clients
    bucket = setup_gcp_client()
    openai_initialized = setup_openai_client()

    if not bucket or not openai_initialized:
        st.error("Failed to initialize required services")
        return None, None, None

    # Download the necessary files from GCS
    success = True
    success &= download_file_from_gcs(bucket, faiss_index_file_gcs, local_faiss_index_file)
    success &= download_file_from_gcs(bucket, text_chunks_file_gcs, local_text_chunks_file)
    success &= download_file_from_gcs(bucket, metadata_file_gcs, local_metadata_file)

    if not success:
        st.error("Failed to download required files")
        return None, None, None

    # Load FAISS index
    try:
        faiss_index = faiss.read_index(local_faiss_index_file)
    except Exception as e:
        print(f"❌ Error loading FAISS index: {str(e)}")
        st.error(f"Error loading FAISS index: {str(e)}")
        return None, None, None

    # Load text chunks; each line is "ID<TAB>Title<TAB>Author<TAB>Text"
    try:
        text_chunks = {}  # {ID -> (Title, Author, Text)}
        with open(local_text_chunks_file, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) == 4:
                    text_chunks[int(parts[0])] = (parts[1], parts[2], parts[3])
    except Exception as e:
        print(f"❌ Error loading text chunks: {str(e)}")
        st.error(f"Error loading text chunks: {str(e)}")
        return None, None, None

    # Load metadata.jsonl for publisher information
    try:
        metadata_dict = {}
        with open(local_metadata_file, "r", encoding="utf-8") as f:
            for line in f:
                item = json.loads(line)
                metadata_dict[item["Title"]] = item  # Store for easy lookup
    except Exception as e:
        print(f"❌ Error loading metadata: {str(e)}")
        st.error(f"Error loading metadata: {str(e)}")
        return None, None, None

    print(f"✅ FAISS index and text chunks loaded. {len(text_chunks)} passages available.")
    return faiss_index, text_chunks, metadata_dict

def average_pool(last_hidden_states, attention_mask):
    """Average pooling for sentence embeddings."""
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

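# Illustrative sketch (kept as comments so the app does not execute it): average_pool
# zeroes out padded positions before averaging, so padding tokens do not dilute the
# embedding. The toy shapes below are assumptions for illustration only.
#
#   hidden = torch.ones(1, 4, 384)              # (batch, seq_len, hidden_dim)
#   mask = torch.tensor([[1, 1, 1, 0]])          # last position is padding
#   pooled = average_pool(hidden, mask)          # mean over the 3 real tokens
#   assert pooled.shape == (1, 384)
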
query_embedding_cache = {}

def get_embedding(text):
    if text in query_embedding_cache:
        return query_embedding_cache[text]

    try:
        # Ensure model initialization
        if 'model' not in st.session_state or st.session_state.model is None:
            tokenizer, model = load_model()
            if model is None:
                return np.zeros((1, 384), dtype=np.float32)  # Fallback: e5-small-v2 embedding size
        else:
            tokenizer, model = st.session_state.tokenizer, st.session_state.model

        # E5 models expect a "query: " or "passage: " prefix on the input text
        input_text = f"query: {text}" if len(text) < 512 else f"passage: {text}"

        # Explicitly specify truncation parameters to avoid warnings
        inputs = tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
            return_attention_mask=True
        )

        # Move to CPU explicitly before processing
        inputs = {k: v.to('cpu') for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            embeddings = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
            embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
            # Ensure we detach and move to numpy on CPU
            embeddings = embeddings.detach().cpu().numpy()

        # Explicitly clean up
        del outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        query_embedding_cache[text] = embeddings
        return embeddings
    except Exception as e:
        print(f"❌ Embedding error: {str(e)}")
        st.error(f"Embedding error: {str(e)}")
        return np.zeros((1, 384), dtype=np.float32)  # e5-small-v2 produces 384-dim embeddings

def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
    """Retrieve top-k most relevant passages using FAISS with metadata."""
    try:
        print(f"\n🔍 Retrieving passages for query: {query}")
        query_embedding = get_embedding(query)
        distances, indices = faiss_index.search(query_embedding, top_k * 2)
        print(f"Found {len(distances[0])} potential matches")

        retrieved_passages = []
        retrieved_sources = []
        cited_titles = set()

        for dist, idx in zip(distances[0], indices[0]):
            print(f"Distance: {dist:.4f}, Index: {idx}")
            if idx in text_chunks and dist >= similarity_threshold:
                title_with_txt, author, text = text_chunks[idx]

                # Normalize title and remove .txt
                clean_title = title_with_txt.replace(".txt", "") if title_with_txt.endswith(".txt") else title_with_txt
                clean_title = unicodedata.normalize("NFC", clean_title)

                # Ensure unique citations
                if clean_title in cited_titles:
                    continue

                metadata_entry = metadata_dict.get(clean_title, {})
                author = metadata_entry.get("Author", "Unknown")
                publisher = metadata_entry.get("Publisher", "Unknown")

                cited_titles.add(clean_title)
                retrieved_passages.append(text)
                retrieved_sources.append((clean_title, author, publisher))

                if len(retrieved_passages) == top_k:
                    break

        print(f"Retrieved {len(retrieved_passages)} passages")
        return retrieved_passages, retrieved_sources
    except Exception as e:
        print(f"❌ Error in retrieve_passages: {str(e)}")
        st.error(f"Error in retrieve_passages: {str(e)}")
        return [], []

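# Note: comparing `dist >= similarity_threshold` assumes the FAISS index returns
# inner-product scores over L2-normalized vectors (i.e. cosine similarity), where
# higher is better. A minimal offline sketch of building such an index (assumed,
# not part of this module) might look like:
#
#   embeddings = np.load(local_embeddings_file).astype(np.float32)  # (N, 384)
#   faiss.normalize_L2(embeddings)                    # normalize rows in place
#   index = faiss.IndexFlatIP(embeddings.shape[1])    # exact inner-product index
#   index.add(embeddings)
#   faiss.write_index(index, local_faiss_index_file)
#
# With an L2-distance index the comparison would have to be inverted (lower is better).
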
def answer_with_llm(query, context=None, word_limit=100):
    """Generate an answer using an OpenAI GPT model with formatted citations."""
    try:
        if context:
            formatted_contexts = []
            total_chars = 0
            max_context_chars = 4000

            for (title, author, publisher), text in context:
                remaining_space = max(0, max_context_chars - total_chars)
                excerpt_len = min(150, remaining_space)
                if excerpt_len > 50:
                    excerpt = text[:excerpt_len].strip() + "..." if len(text) > excerpt_len else text
                    formatted_context = f"[{title} by {author}, Published by {publisher}] {excerpt}"
                    formatted_contexts.append(formatted_context)
                    total_chars += len(formatted_context)
                if total_chars >= max_context_chars:
                    break

            formatted_context = "\n".join(formatted_contexts)
        else:
            formatted_context = "No relevant information available."

        # System message
        system_message = (
            "You are an AI specialized in Indian spiritual texts. "
            "Answer based on context, summarizing ideas rather than quoting verbatim. "
            "Ensure proper citation and do not include direct excerpts."
        )

        user_message = f"""
Context:
{formatted_context}

Question:
{query}
"""

        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ],
            max_tokens=200,
            temperature=0.7
        )

        answer = response.choices[0].message.content.strip()

        # Enforce word limit
        words = answer.split()
        if len(words) > word_limit:
            answer = " ".join(words[:word_limit])
            if not answer.endswith((".", "!", "?")):
                answer += "."

        return answer
    except Exception as e:
        print(f"❌ LLM API error: {str(e)}")
        st.error(f"LLM API error: {str(e)}")
        return "I apologize, but I'm unable to answer at the moment."

def format_citations(sources):
    """Format citations to display each one on a new line with a full stop if needed."""
    formatted_citations = []
    for title, author, publisher in sources:
        # Add a terminal full stop unless the publisher already ends with punctuation
        if publisher.endswith(('.', '!', '?')):
            formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}")
        else:
            formatted_citations.append(f"📚 {title} by {author}, Published by {publisher}.")
    return "\n".join(formatted_citations)

def process_query(query, top_k=5, word_limit=100):
    """Process a query through the RAG pipeline with proper formatting."""
    print(f"\n🔍 Processing query: {query}")

    # Load data files if not already loaded
    if not hasattr(st.session_state, 'data_loaded') or not st.session_state.data_loaded:
        st.session_state.faiss_index, st.session_state.text_chunks, st.session_state.metadata_dict = load_data_files()
        st.session_state.data_loaded = True

    # Check if data loaded successfully
    if not st.session_state.faiss_index or not st.session_state.text_chunks or not st.session_state.metadata_dict:
        return {
            "query": query,
            "answer_with_rag": "⚠️ System error: Data files not loaded properly.",
            "citations": "No citations available."
        }

    retrieved_context, retrieved_sources = retrieve_passages(
        query,
        st.session_state.faiss_index,
        st.session_state.text_chunks,
        st.session_state.metadata_dict,
        top_k=top_k
    )

    sources = format_citations(retrieved_sources) if retrieved_sources else "No citation available."

    if retrieved_context:
        context_with_sources = list(zip(retrieved_sources, retrieved_context))
        llm_answer_with_rag = answer_with_llm(query, context_with_sources, word_limit=word_limit)
    else:
        llm_answer_with_rag = "⚠️ No relevant context found."

    return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}