import pickle
import threading
import time

import faiss
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
class FinancialChatbot:
    def __init__(self):
        # Load the prebuilt FAISS index and the FAISS-id -> document-text map
        # (one way to build these two files is sketched after the class)
        self.faiss_index = faiss.read_index("financial_faiss.index")
        with open("index_map.pkl", "rb") as f:
            self.index_map = pickle.load(f)

        # Build the BM25 index over the same documents at startup
        self.documents = list(self.index_map.values())
        self.bm25_corpus = [doc.lower().split() for doc in self.documents]  # whitespace tokenization
        self.bm25 = BM25Okapi(self.bm25_corpus)

        # SentenceTransformer for dense (embedding-based) retrieval
        self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
        # Load the Qwen model for answer generation
        model_name = "Qwen/Qwen2.5-1.5B"
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Guardrail: queries containing any of these words are rejected
        self.BLOCKED_WORDS = [
            "hack", "bypass", "illegal", "exploit", "scam", "kill", "laundering",
            "murder", "suicide", "self-harm", "assault", "bomb", "terrorism",
            "attack", "genocide", "mass shooting", "credit card number",
        ]

        # Minimum retrieval score required to treat a query as answerable
        self.min_similarity_threshold = 0.7
    def moderate_query(self, query):
        """Check whether the query contains inappropriate words."""
        query_lower = query.lower()
        for word in self.BLOCKED_WORDS:
            if word in query_lower:
                return False  # Block query
        return True  # Allow query
    def query_faiss(self, query, top_k=5):
        """Retrieve relevant documents using FAISS and compute confidence scores."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        results = []
        confidence_scores = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx in self.index_map:  # FAISS pads missing results with -1
                similarity = 1 / (1 + dist)  # Map L2 distance to a (0, 1] similarity
                results.append(self.index_map[idx])
                confidence_scores.append(similarity)
        return results, confidence_scores
    def query_bm25(self, query, top_k=5):
        """Retrieve relevant documents using BM25 keyword-based search."""
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        confidence_scores = []
        for idx in top_indices:
            if scores[idx] > 0:  # Ignore zero-score matches
                results.append(self.documents[idx])
                # Raw BM25 scores are unbounded, unlike the FAISS similarities
                confidence_scores.append(scores[idx])
        return results, confidence_scores
    def generate_answer(self, context, question):
        """Generate an answer with the Qwen model, conditioned on the retrieved context."""
        input_text = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qwen_tokenizer(input_text, return_tensors="pt").to(self.qwen_model.device)
        # max_new_tokens bounds only the generated text; max_length would also
        # count the (potentially long) prompt and could truncate the answer
        outputs = self.qwen_model.generate(**inputs, max_new_tokens=100)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
    def get_answer(self, query, timeout=200):
        """Fetch an answer via hybrid retrieval and the Qwen model, with a hard timeout."""
        result = ["No relevant information found", 0.0]  # Default response

        def task():
            if query.lower() in ["hi", "hello", "hey"]:
                result[:] = ["Hi, how can I help you?", 1.0]
                return
            # Crude out-of-scope filter for a few known off-topic keywords
            if query.lower() in ["france", "capital", "air", "rainbow", "water", "sun"]:
                result[:] = ["No relevant information found", 1.0]
                return
            if not self.moderate_query(query):
                result[:] = ["I'm unable to process your request due to inappropriate language.", 0.0]
                return

            faiss_results, faiss_conf = self.query_faiss(query)
            bm25_results, bm25_conf = self.query_bm25(query)
            all_results = faiss_results + bm25_results
            all_conf = faiss_conf + bm25_conf

            # Check whether any retrieved document clears the relevance threshold
            # (note: the threshold is calibrated for the FAISS similarities,
            # while BM25 scores live on a different, unbounded scale)
            if not all_results or max(all_conf, default=0) < self.min_similarity_threshold:
                result[:] = ["No relevant information found", 0.0]
                return

            context = " ".join(all_results)
            answer = self.generate_answer(context, query)
            # Keep only the text after the final "Answer:" marker; the decoded
            # output echoes the prompt, so the marker is always present
            last_index = answer.rfind("Answer:")
            extracted_answer = answer[last_index + len("Answer:"):].strip() if last_index != -1 else ""

            # Reject empty or purely numeric answers as likely ungrounded
            if not extracted_answer or extracted_answer.isnumeric():
                result[:] = ["No relevant information found", 0.0]
            else:
                result[:] = [extracted_answer, max(all_conf, default=0.9)]
        thread = threading.Thread(target=task)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            return "No relevant information found", 0.0  # Timeout case
        return tuple(result)
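

# ---------------------------------------------------------------------------
# Index-building sketch. FinancialChatbot assumes "financial_faiss.index" and
# "index_map.pkl" already exist on disk; the original build script is not part
# of this file. The function below is a minimal sketch of one way to produce
# compatible files, assuming the same all-MiniLM-L6-v2 embeddings and a flat
# L2 index (consistent with the 1 / (1 + dist) similarity in query_faiss).
# `build_index` and its `docs` argument are hypothetical names, not from the
# original Space.
# ---------------------------------------------------------------------------
def build_index(docs):
    """Embed a list of text chunks and persist the FAISS index and id -> text map."""
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(docs, convert_to_numpy=True)

    index = faiss.IndexFlatL2(embeddings.shape[1])  # L2 metric, as query_faiss expects
    index.add(embeddings)

    faiss.write_index(index, "financial_faiss.index")
    with open("index_map.pkl", "wb") as f:
        pickle.dump({i: doc for i, doc in enumerate(docs)}, f)  # FAISS id -> document text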
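

# ---------------------------------------------------------------------------
# Usage sketch: assumes the two index files are present in the working
# directory (models are downloaded on first run); the example question is
# illustrative, not from the original app.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bot = FinancialChatbot()
    answer, confidence = bot.get_answer("What was the total revenue in FY2023?")
    print(f"{answer} (confidence: {confidence:.2f})")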