import torch
import pandas as pd
import faiss
import numpy as np
import re
import os
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
class FinancialChatbot:
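    """Retrieval-augmented financial Q&A bot: SBERT embeddings over rows of an
    Excel sheet, FAISS for nearest-neighbor retrieval, and Qwen2-0.5B-Instruct
    to generate answers from the retrieved rows."""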
    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2-0.5B-Instruct"):
        self.device = "cpu"
        self.data_path = data_path  # Store data path

        # Load SBERT for embeddings; half precision halves memory, but note
        # that fp16 inference on CPU can be slow or unsupported for some ops
        self.sbert_model = SentenceTransformer(model_name, device=self.device)
        self.sbert_model = self.sbert_model.half()

        # Load Qwen model for text generation
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            qwen_model_name, torch_dtype=torch.float16, trust_remote_code=True
        ).to(self.device)
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)

        # Load or create the FAISS index
        self.load_or_create_index()
    def load_or_create_index(self):
        """Loads FAISS index and index_map if they exist, otherwise creates new ones."""
        if os.path.exists("financial_faiss.index") and os.path.exists("index_map.txt"):
            try:
                self.faiss_index = faiss.read_index("financial_faiss.index")
                with open("index_map.txt", "r", encoding="utf-8") as f:
                    self.index_map = {i: line.strip() for i, line in enumerate(f)}
                print("FAISS index and index_map loaded successfully.")
            except Exception as e:
                print(f"Error loading FAISS index: {e}. Recreating index...")
                self.create_faiss_index()
        else:
            print("FAISS index or index_map not found. Creating a new one...")
            self.create_faiss_index()
    def create_faiss_index(self):
        """Creates a FAISS index from the provided Excel file."""
        df = pd.read_excel(self.data_path)
        sentences = []
        self.index_map = {}  # Maps FAISS row id -> source sentence

        # Flatten the sheet into one sentence per (label, year) cell,
        # e.g. "Total revenue - year 2021 is: 5000"
        for row_idx, row in df.iterrows():
            for col in df.columns[1:]:  # Skip the first column (assumed to hold labels)
                sentence = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
                sentences.append(sentence)
                self.index_map[len(self.index_map)] = sentence  # Store mapping

        # Encode the sentences into embeddings; FAISS requires float32, and the
        # half-precision SBERT model returns float16, so cast explicitly
        embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True).astype(np.float32)

        # Create FAISS index (flat L2 for simplicity)
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(embeddings)

        # Persist the index and the id -> sentence map
        faiss.write_index(self.faiss_index, "financial_faiss.index")
        with open("index_map.txt", "w", encoding="utf-8") as f:
            for sentence in self.index_map.values():
                f.write(sentence + "\n")
    def query_faiss(self, query, top_k=3):
        """Retrieves the top_k closest sentences from the FAISS index."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True).astype(np.float32)
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        results = [self.index_map[idx] for idx in indices[0] if idx in self.index_map]
        # Scale distances to a rough 0-10 confidence: the farthest hit scores 0
        # and a zero-distance hit scores 10; `or 1` guards against division by
        # zero when the max distance is itself 0
        confidences = [(1 - (dist / (np.max(distances[0]) or 1))) * 10 for dist in distances[0]]
        return results, confidences
    def moderate_query(self, query):
        """Blocks inappropriate queries containing restricted words."""
        BLOCKED_WORDS = re.compile(r"\b(hack|bypass|illegal|exploit|scam|kill|laundering|murder|suicide|self-harm)\b", re.IGNORECASE)
        return not bool(BLOCKED_WORDS.search(query))
    # Earlier chat-template version, kept for reference:
    # def generate_answer(self, context, question):
    #     messages = [
    #         {"role": "system", "content": "You are a financial assistant. Answer only finance-related questions. If the question is not related to finance, reply: 'I'm sorry, but I can only answer financial-related questions.' If the user greets you (e.g., 'Hello', 'Hi', 'Good morning'), respond politely with 'Hello! How can I assist you today?'."},
    #         {"role": "user", "content": f"{question} - related context extracted from the DB: {context}"},
    #     ]
    #     # Use Qwen's chat template
    #     input_text = self.qwen_tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )
    #     # Tokenize and move input to device
    #     inputs = self.qwen_tokenizer([input_text], return_tensors="pt").to(self.device)
    #     self.qwen_model.config.pad_token_id = self.qwen_tokenizer.eos_token_id
    #     # Generate response
    #     outputs = self.qwen_model.generate(
    #         inputs.input_ids,
    #         max_new_tokens=50,
    #         pad_token_id=self.qwen_tokenizer.eos_token_id,
    #     )
    #     # Extract only the newly generated part
    #     generated_ids = outputs[:, inputs.input_ids.shape[1]:]  # Remove prompt part
    #     response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    #     return response
    # An earlier call used generate(..., max_length=100) and failed once the
    # prompt alone exceeded that limit:
    #   ValueError: Input length of input_ids is 127, but `max_length` is set
    #   to 100. This can lead to unexpected behavior. You should consider
    #   increasing `max_length` or, better yet, setting `max_new_tokens`.
    # Hence the version below sets `max_new_tokens` instead.
    def generate_answer(self, context, question):
        """Generates an answer to `question` grounded in `context` using Qwen."""
        prompt = f"""
        You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context.
        For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."
        Context: {context}
        User Query: {question}
        Answer:
        """
        inputs = self.qwen_tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        outputs = self.qwen_model.generate(inputs, max_new_tokens=50)
        # Decode only the newly generated tokens so the prompt is not echoed
        # back as part of the answer
        generated_ids = outputs[0][inputs.shape[1]:]
        return self.qwen_tokenizer.decode(generated_ids, skip_special_tokens=True)
    def get_answer(self, query):
        """Main function to process a user query and return (answer, confidence)."""
        # Reject queries containing blocked words
        if not self.moderate_query(query):
            return "Inappropriate request.", 0.0

        # Retrieve relevant documents and their confidence scores
        retrieved_docs, confidences = self.query_faiss(query)
        if not retrieved_docs:
            return "No relevant information found.", 0.0

        # Combine retrieved documents into a single context string
        context = " ".join(retrieved_docs)
        avg_confidence = round(sum(confidences) / len(confidences), 2)

        # Generate the model response and trim whitespace
        model_response = self.generate_answer(context, query).strip()

        # Normalize "don't know"-style answers (ignoring a trailing period)
        if model_response.lower().strip(".") in ["i don't know", "no relevant information found"]:
            return "I don't know.", avg_confidence
        if avg_confidence == 0.0:
            return "Not relevant.", avg_confidence
        return model_response, avg_confidence
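
# Minimal usage sketch, assuming an Excel sheet with metric labels in the first
# column and one column per year. "financial_data.xlsx" and the sample question
# are placeholders, not values taken from the class above.
if __name__ == "__main__":
    bot = FinancialChatbot("financial_data.xlsx")
    answer, confidence = bot.get_answer("What was the total revenue in 2021?")
    print(f"{answer} (confidence: {confidence}/10)")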