# rag.py — author: anejaprerna; last commit 43697b4 ("Update rag.py", verified); file size 7.2 kB.
# import time
import pickle
import re
import threading

import faiss
import numpy as np
# import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
# import torch
class FinancialChatbot:
    """Retrieval-augmented question answering over a financial spreadsheet.

    Every non-label cell of the workbook at ``data_path`` becomes a sentence of
    the form "<first-column value> - year <column name> is: <cell>".  The
    sentences are embedded with a SentenceTransformer model and stored in a
    FAISS ``IndexFlatL2``.  A question is answered by retrieving the nearest
    sentences and passing them as context to a Qwen causal language model.
    """

    # On-disk cache of the FAISS index and of the row-id -> sentence map.
    # These are relative paths, so they resolve against the current working directory.
    INDEX_PATH = "financial_faiss.index"
    INDEX_MAP_PATH = "index_map.pkl"

    # A query is refused if any word in it starts with one of these terms
    # (case-insensitive).
    BLOCKED_WORDS = (
        "hack", "bypass", "illegal", "exploit", "scam",
        "kill", "laundering", "murder", "suicide", "self-harm",
    )
    # Anchoring on a leading word boundary still blocks "hacker"/"killed", but
    # no longer rejects innocent words that merely contain a term ("skills").
    _BLOCKED_RE = re.compile(
        r"\b(?:" + "|".join(map(re.escape, BLOCKED_WORDS)) + r")", re.IGNORECASE
    )

    _TIMEOUT_MESSAGE = "Execution exceeded time limit. Stopping function."
    _REFUSAL_MESSAGE = "I'm unable to process your request due to inappropriate language."
    _NO_INFO_MESSAGE = "No relevant information found"

    # Prompt sent to the model.  It ends with "Answer:" so the completion
    # follows that marker, which _answer() uses to find the reply.
    _PROMPT_TEMPLATE = (
        "\n"
        "You are a financial assistant. If the user greets you (e.g., \"Hello,\" "
        "\"Hi,\" \"Good morning\"), respond politely without requiring context.\n"
        "For financial-related questions, answer based on the context provided. "
        "If the context lacks information, say \"I don't know.\"\n"
        "Context: {context}\n"
        "User Query: {question}\n"
        "Answer:\n"
    )

    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2.5-1.5b"):
        """Load the embedding and generation models, then load or build the index.

        Args:
            data_path: Excel workbook read with ``pd.read_excel`` when the cached
                index cannot be loaded.
            model_name: SentenceTransformer model used for embeddings.
            qwen_model_name: Hugging Face id of the causal LM that writes answers.
        """
        self.data_path = data_path
        self.sbert_model = SentenceTransformer(model_name)
        self.index_map = {}
        self.faiss_index = None
        # Pinned to CPU.  NOTE(review): trust_remote_code=True lets the model
        # repository run its own Python code at load time; confirm it is needed.
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            qwen_model_name, torch_dtype="auto", device_map="cpu", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
        self.load_or_create_index()

    def load_or_create_index(self):
        """Load the cached FAISS index and sentence map, or rebuild them from ``data_path``."""
        try:
            self.faiss_index = faiss.read_index(self.INDEX_PATH)
            # NOTE: pickle.load can execute arbitrary code.  Only load a map file
            # that this class wrote itself; never load one from an untrusted source.
            with open(self.INDEX_MAP_PATH, "rb") as f:
                self.index_map = pickle.load(f)
            print("Index loaded successfully!")
        except (RuntimeError, OSError, EOFError, ValueError, pickle.UnpicklingError):
            # Cache miss or corrupt cache: faiss.read_index raises RuntimeError,
            # and open/pickle raise OSError, EOFError or UnpicklingError.
            # Unlike a bare except, this does not swallow KeyboardInterrupt.
            print("Creating new FAISS index...")
            self._build_index()

    def _build_index(self):
        """Embed one sentence per (label, column) cell, index them and cache both to disk."""
        df = pd.read_excel(self.data_path)
        label_col, *year_cols = df.columns
        sentences = [
            f"{row[label_col]} - year {col} is: {row[col]}"
            for _, row in df.iterrows()
            for col in year_cols
        ]
        if not sentences:
            raise ValueError(f"No data cells found in {self.data_path!r}; cannot build an index.")
        self.index_map = dict(enumerate(sentences))
        # FAISS requires a contiguous float32 matrix.
        embeddings = np.ascontiguousarray(
            self.sbert_model.encode(sentences, convert_to_numpy=True), dtype=np.float32
        )
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(embeddings)
        faiss.write_index(self.faiss_index, self.INDEX_PATH)
        with open(self.INDEX_MAP_PATH, "wb") as f:
            pickle.dump(self.index_map, f)
        print("Indexing completed!")

    def query_faiss(self, query, top_k=5):
        """Retrieve the ``top_k`` nearest sentences for ``query`` with relative confidences.

        Each confidence is ``1 - dist / max_dist``, where ``max_dist`` is taken
        over the valid hits.  The nearest hit therefore scores closest to 1.0
        and the farthest returned hit scores 0.0, so a single hit at non-zero
        distance scores 0.0.  This is a ranking aid within one result set, not a
        calibrated probability.

        Returns:
            (results, confidences): parallel lists of sentences and floats
            rounded to 2 decimals; both are empty when nothing valid is found.
        """
        query_embedding = np.ascontiguousarray(
            self.sbert_model.encode([query], convert_to_numpy=True), dtype=np.float32
        )
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        # FAISS pads with id -1 (and a huge distance) when the index holds fewer
        # than top_k vectors.  Drop those before normalising, otherwise every
        # real hit ends up with a confidence of ~1.0.
        hits = [
            (int(idx), dist)
            for idx, dist in zip(indices[0], distances[0])
            if int(idx) in self.index_map
        ]
        if not hits:
            return [], []
        max_dist = max(dist for _, dist in hits) or 1  # avoid division by zero
        results = [self.index_map[idx] for idx, _ in hits]
        confidences = [round(float(1 - dist / max_dist), 2) for _, dist in hits]
        return results, confidences

    def moderate_query(self, query):
        """Return True if ``query`` contains no blocked term, False otherwise.

        Terms from ``BLOCKED_WORDS`` are matched case-insensitively at the start
        of a word, so "hacking" is blocked but "skills" is not.
        """
        return self._BLOCKED_RE.search(query) is None

    def generate_answer(self, context, question, max_new_tokens=100):
        """Prompt the Qwen model with ``context`` and ``question``.

        Returns:
            The decoded prompt followed by the model's completion.  The prompt
            ends with "Answer:", which ``get_answer`` uses to find the reply.
        """
        prompt = self._PROMPT_TEMPLATE.format(context=context, question=question)
        inputs = self.qwen_tokenizer(prompt, return_tensors="pt")
        # max_new_tokens bounds only the completion.  The old max_length=100
        # also counted the prompt, which with retrieved context is usually
        # longer than 100 tokens, so little or nothing was generated.
        outputs = self.qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _answer(self, query):
        """Run moderation, retrieval and generation synchronously; return [text, confidence]."""
        if not self.moderate_query(query):
            return [self._REFUSAL_MESSAGE, 0.0]
        retrieved_docs, confidences = self.query_faiss(query)
        if not retrieved_docs:
            return [self._NO_INFO_MESSAGE, 0.0]
        context = " ".join(retrieved_docs)
        avg_confidence = round(sum(confidences) / len(confidences), 2)
        answer = self.generate_answer(context, query)
        marker = answer.rfind("Answer")
        if marker == -1:
            # No marker: show the whole output rather than its last character.
            shown, reply = answer, answer
        else:
            shown = answer[marker:]
            reply = answer[marker + len("Answer"):].lstrip(":")
        reply = reply.strip()
        # An empty reply, or one starting with "--", means the model had nothing
        # to say.  (The old slice [last_index + 9:11] was always empty.)
        if not reply or reply.startswith("--"):
            return [self._NO_INFO_MESSAGE, 0.0]
        return [shown, avg_confidence]

    def get_answer(self, query, timeout=150):
        """Answer ``query`` and return ``(text, confidence)``, giving up after ``timeout`` seconds.

        Returns a refusal for blocked queries, and "No relevant information
        found" when retrieval finds nothing or the model's reply is empty or
        starts with "--".  Returns the timeout message if the worker is still
        running after ``timeout`` seconds, and an error message if the worker raised.

        Python threads cannot be killed: after a timeout the worker keeps running
        in the background until generation finishes.  It is a daemon thread, so
        it does not block interpreter exit.
        """
        result = [self._TIMEOUT_MESSAGE, 0.0]

        def task():
            try:
                result[:] = self._answer(query)
            except Exception as exc:  # report failures instead of a fake timeout
                result[:] = [f"Error while generating answer: {exc}", 0.0]

        worker = threading.Thread(target=task, daemon=True)
        worker.start()
        worker.join(timeout)
        if worker.is_alive():
            return self._TIMEOUT_MESSAGE, 0.0
        return tuple(result)
# if __name__ == "__main__":
# chatbot = FinancialChatbot("C:\\Users\\Dell\\Downloads\\CAI_RAG\\DATA\\Nestle_Financtial_report_till2023.xlsx")
# query = "What is the Employees Cost in Dec'20?"
# print(chatbot.get_answer(query))