# financial-qa-agent / Embeddings.py
import os
import glob
import json
import pickle
from tqdm import tqdm
import numpy as np
# Try imports with friendly errors
try:
import faiss
except Exception as e:
raise ImportError("faiss is required. Install cpu version: `pip install faiss-cpu` or install via conda for GPU (faiss-gpu).") from e
try:
from sentence_transformers import SentenceTransformer
except Exception as e:
raise ImportError("sentence-transformers is required. `pip install sentence-transformers`") from e
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from dotenv import load_dotenv
from Data_Cleaning import GetDataCleaning
from Logger import GetLogger
class GetEmbeddings:
"""
Embedding pipeline for cleaned text files.
Generates embeddings using SentenceTransformers, builds a FAISS index,
and allows searching queries against the vector database.
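
    Typical usage (sketch; assumes config.json and cleaned text files exist):
        ge = GetEmbeddings()
        ge.run()                      # build or load the FAISS index + metadata
        ge.load_summarizer()          # load the LLM used for summarization
        answer = ge.answer_query("What was revenue growth last year?")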
"""
def __init__(self, config_path="config.json", logger=None):
with open(config_path, "r") as f:
self.config = json.load(f)
cfg_paths = self.config["paths"]
cfg_emb = self.config["embedding"]
self.root = cfg_paths["root"]
self.cleaned_suffix = "_cleaned_txt"
self.chunk_words = cfg_emb["chunk_words"]
self.batch_size = cfg_emb["batch_size"]
self.faiss_index_path = cfg_paths["faiss_index"]
self.metadata_path = cfg_paths["metadata"]
self.embedding_model = cfg_emb["model"]
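        # Assumed config.json shape (sketch; values below are illustrative, only
        # the keys read in this class are required):
        # {
        #     "paths": {"root": "...", "faiss_index": "...", "metadata": "..."},
        #     "embedding": {"model": "...", "chunk_words": 200, "batch_size": 32, "use_gpu": true}
        # }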
if not logger:
obj = GetLogger()
logger = obj.get_logger()
self.logger = logger
self.logger.info("Initializing Embedding Pipeline...")
# Device
self.device = "cuda" if self.check_cuda() and cfg_emb["use_gpu"] else "cpu"
load_dotenv()
self.hf_token = os.getenv("HF_TOKEN")
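        # HF_TOKEN is read from the environment / a .env file; it is required for
        # gated checkpoints (e.g. google/gemma-2b) and may be None for public models.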
def check_cuda(self):
"""Return True if CUDA is available and usable."""
try:
if torch.cuda.is_available():
_ = torch.cuda.current_device()
self.logger.info(f"βœ… CUDA available. Device: {torch.cuda.get_device_name(0)}")
return True
self.logger.info("⚠️ CUDA not available. Using CPU.")
return False
except Exception as e:
self.logger.error(f"Error checking CUDA, defaulting to CPU. Error: {e}")
return False
def list_cleaned_files(self):
"""Return sorted list of cleaned text files under root/*{cleaned_suffix}/*.txt"""
pattern = os.path.join(self.root, f"*{self.cleaned_suffix}", "*.txt")
files = glob.glob(pattern)
files.sort()
return files
def read_text_file(self, path):
"""Read a text file and return string content."""
with open(path, "r", encoding="utf-8") as f:
return f.read()
def chunk_text_words(self, text):
"""
Simple word-based chunking.
Returns list of text chunks.
"""
words = text.split()
if not words:
return []
return [" ".join(words[i:i + self.chunk_words]) for i in range(0, len(words), self.chunk_words)]
def save_index_and_metadata(self):
"""Save FAISS index and metadata to disk."""
        os.makedirs(os.path.dirname(self.faiss_index_path) or ".", exist_ok=True)
        os.makedirs(os.path.dirname(self.metadata_path) or ".", exist_ok=True)
faiss.write_index(self.index, self.faiss_index_path)
with open(self.metadata_path, "wb") as f:
pickle.dump(self.metadata, f)
self.logger.info(f"πŸ’Ύ Saved FAISS index to {self.faiss_index_path}")
self.logger.info(f"πŸ’Ύ Saved metadata to {self.metadata_path}")
def load_index_and_metadata(self):
"""Load FAISS index and metadata if they exist."""
if os.path.exists(self.faiss_index_path) and os.path.exists(self.metadata_path):
try:
self.index = faiss.read_index(self.faiss_index_path)
with open(self.metadata_path, "rb") as f:
self.metadata = pickle.load(f)
self.logger.info(f"βœ… Loaded existing FAISS index + metadata from disk.")
return True
except Exception as e:
self.logger.warning(f"⚠️ Failed to load FAISS index/metadata, will rebuild. Error: {e}")
return False
return False
def load_encoder(self):
"""Loading Encoder"""
self.encoder = SentenceTransformer(self.embedding_model, device=self.device)
self.logger.info(f"Loaded embedding model '{self.embedding_model}' on {self.device}")
return self.encoder
def building_embeddings_index(self, files):
"""Build embeddings for all text chunks and return FAISS index + metadata."""
all_embeddings, metadata = [], []
next_id = 0
# Iterate files and chunks
for fp in tqdm(files, desc="Files", unit="file"):
text = self.read_text_file(fp)
if not text.strip():
continue
# metadata: infer company and file from path
# e.g., financial_reports/Infosys_cleaned_txt/Infosys_2023_AR.txt
rel = os.path.relpath(fp, self.root)
folder = rel.split(os.sep)[0]
filename = os.path.basename(fp)
chunks = self.chunk_text_words(text)
if not chunks:
continue
for i in range(0, len(chunks), self.batch_size):
batch = chunks[i:i + self.batch_size]
embs = self.encoder.encode(batch, show_progress_bar=False, convert_to_numpy=True)
embs = embs.astype(np.float32)
for j, vec in enumerate(embs):
all_embeddings.append(vec)
metadata.append({
"id": next_id,
"source_folder": folder,
"file": filename,
"chunk_id": i + j,
"text": batch[j] # store chunk text for retrieval
})
next_id += 1
if not all_embeddings:
raise RuntimeError("No embeddings were produced. Check cleaned files and chunking.")
emb_matrix = np.vstack(all_embeddings).astype(np.float32)
faiss.normalize_L2(emb_matrix)
# Build FAISS index (IndexFlatIP over normalized vectors = cosine similarity)
dim = emb_matrix.shape[1]
self.index = faiss.IndexFlatIP(dim)
self.index.add(emb_matrix)
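        # IndexFlatIP assigns ids 0..ntotal-1 in insertion order, so metadata[i]
        # is the chunk behind the vector FAISS returns as id i during search.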
self.metadata = metadata
self.logger.info(f"βœ… Built FAISS index with {self.index.ntotal} vectors, dim={dim}")
return self.index, self.metadata
def run(self):
"""Main entry: load or build embeddings + FAISS index."""
        if self.load_index_and_metadata():
            # The encoder is still needed to embed queries in answer_query(),
            # even when the index itself was loaded from disk.
            self.load_encoder()
            return
files = self.list_cleaned_files()
if not files:
self.logger.error("❌ No cleaned text files found.")
raise SystemExit(1)
self.load_encoder()
self.building_embeddings_index(files)
self.save_index_and_metadata()
def load_summarizer(self, model_name="google/gemma-2b"):
"""
Load summarizer LLM once.
If already loaded, skip.
"""
if hasattr(self, "summarizer_pipeline"):
self.logger.info("ℹ️ Summarizer already loaded, skipping reload.")
return
try:
self.logger.info(f"⏳ Loading summarizer model '{model_name}'...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.hf_token)
self.summarizer_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
device_map=self.device,
token=self.hf_token
)
self.summarizer_pipeline = pipeline(
"text-generation",
model=self.summarizer_model,
tokenizer=self.tokenizer
)
self.logger.info(f"βœ… Summarizer model '{model_name}' loaded successfully.")
except RuntimeError as e:
if "CUDA out of memory" in str(e):
self.logger.warning("⚠️ CUDA OOM while loading summarizer. Retrying on CPU...")
self.device = "cpu"
torch.cuda.empty_cache()
return self.load_summarizer(model_name=model_name)
else:
self.logger.error(f"❌ Failed to load summarizer: {e}")
raise
def summarize_chunks(self, chunks, max_content_tokens=2048, max_output_tokens=256):
"""
Summarize list of text chunks using LLM.
        - Chunks are joined into one context and truncated (word-wise) to fit max_content_tokens.
- Generates a concise summary.
"""
if not hasattr(self, "summarizer_pipeline"):
self.load_summarizer()
self.logger.info("Summarizer not initialized. Called load_summarizer(). pipeline will work with default parameters.")
# Join chunks into one context, respecting token budget
context = " ".join(chunks)
input_tokens = len(self.tokenizer.encode(context))
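        # Note: the truncation below is word-based, so it only approximates the
        # tokenizer's budget; multi-token words can still push it slightly over.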
if input_tokens > max_content_tokens:
# Trim to fit context window
context = " ".join(context.split()[:max_content_tokens])
self.logger.warning("⚠️ Context truncated to fit within model token limit.")
# Build summarization prompt
prompt = f"""
Summarize the following financial report excerpts into a concise answer.
Keep it factual, short, and grounded in the text.
Excerpts:
{context}
Summary:
"""
try:
output = self.summarizer_pipeline(
prompt,
max_new_tokens=max_output_tokens,
do_sample=False
)[0]["generated_text"]
if "Summary:" in output:
summary = output.split("Summary:")[-1].strip()
else:
summary = output.strip()
return summary
except RuntimeError as e:
if "CUDA out of memory" in str(e):
self.logger.warning("⚠️ CUDA OOM during summarization. Retrying on CPU...")
self.device = "cpu"
torch.cuda.empty_cache()
return self.summarize_chunks(chunks, max_content_tokens, max_output_tokens)
else:
self.logger.error(f"❌ Summarizer failed: {e}. Falling back to raw chunks.")
return " ".join(chunks[:2]) # fallback: return first 2 chunks
def answer_query(self, query, top_k=3):
"""
End-to-end QA:
- Retrieve relevant chunks from FAISS
- Summarize into a final answer.
"""
try:
            # Step 1: Retrieve
self.logger.info(f"πŸ” searching vector DB for query: {query}")
q_emb = self.encoder.encode(query, show_progress_bar=False, convert_to_numpy=True).reshape(1, -1)
faiss.normalize_L2(q_emb)
scores, idxs = self.index.search(q_emb, k=top_k)
            chunks = [self.metadata[idx]["text"] for idx in idxs[0] if idx != -1]  # -1 means "no result" when the index holds fewer than top_k vectors
# Step 2: Summarize
summary = self.summarize_chunks(chunks)
# Log results
self.logger.info(f"βœ… Final Answer: {summary}")
return summary
except Exception as e:
self.logger.error(f"Error in answer_query: {e}")
return None
# Example usage (uncomment to build the index and run end-to-end QA)
ge = GetEmbeddings()
# ge.run()
# ge.load_summarizer("google/gemma-2b")
# answer = ge.answer_query("What are the key highlights from Q2 financial report?")
# print(answer)