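"""Retrieval-augmented financial chatbot.

Embeds spreadsheet rows with SBERT, retrieves matches with FAISS, and answers
queries with a small Qwen instruct model. Index files ("financial_faiss.index",
"index_map.txt") are cached in the working directory.
"""
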
import torch
import pandas as pd
import faiss
import numpy as np
import re
import os
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

class FinancialChatbot:
    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2-0.5B-Instruct"):
        self.device = "cpu"
        self.data_path = data_path  # Store data path

        # Load SBERT for embeddings
        self.sbert_model = SentenceTransformer(model_name, device=self.device)
        self.sbert_model = self.sbert_model.half()

        # Load Qwen model for text generation
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            qwen_model_name, torch_dtype=torch.float16, trust_remote_code=True
        ).to(self.device)

        self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)

        # Load or create FAISS index
        self.load_or_create_index()

    def load_or_create_index(self):
        """Loads FAISS index and index_map if they exist, otherwise creates new ones."""
        if os.path.exists("financial_faiss.index") and os.path.exists("index_map.txt"):
            try:
                self.faiss_index = faiss.read_index("financial_faiss.index")
                with open("index_map.txt", "r", encoding="utf-8") as f:
                    self.index_map = {i: line.strip() for i, line in enumerate(f)}
                print("FAISS index and index_map loaded successfully.")
            except Exception as e:
                print(f"Error loading FAISS index: {e}. Recreating index...")
                self.create_faiss_index()
        else:
            print("FAISS index or index_map not found. Creating a new one...")
            self.create_faiss_index()


    def create_faiss_index(self):
        """Creates a FAISS index from the provided Excel file."""
        df = pd.read_excel(self.data_path)
        sentences = []
        self.index_map = {}  # Initialize index_map

        for row_idx, row in df.iterrows():
            for col in df.columns[1:]:  # Ignore the first column (assumed to be labels)
                sentence = f"{row[df.columns[0]]} - year {col} is: {row[col]}"
                sentences.append(sentence)
                self.index_map[len(self.index_map)] = sentence  # Store mapping

        # Encode the sentences into embeddings
        embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)

        # Create FAISS index (FlatL2 for simplicity)
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(embeddings)

        # Save index and index map
        faiss.write_index(self.faiss_index, "financial_faiss.index")
        with open("index_map.txt", "w", encoding="utf-8") as f:
            for sentence in self.index_map.values():
                f.write(sentence + "\n")

    def query_faiss(self, query, top_k=3):
        """Retrieves the top_k closest sentences from FAISS index."""
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        # FAISS pads with -1 when fewer than top_k neighbors exist; filter those
        # out of both lists so results and confidences stay aligned.
        valid = [(idx, dist) for idx, dist in zip(indices[0], distances[0]) if idx in self.index_map]
        results = [self.index_map[idx] for idx, _ in valid]

        # Rough 0-10 confidence: the farthest hit maps to 0, closer hits toward 10;
        # `or 1` guards against division by zero when every distance is 0.
        max_dist = np.max(distances[0]) or 1
        confidences = [(1 - dist / max_dist) * 10 for _, dist in valid]

        return results, confidences

    def moderate_query(self, query):
        """Blocks inappropriate queries containing restricted words."""
        BLOCKED_WORDS = re.compile(r"\b(hack|bypass|illegal|exploit|scam|kill|laundering|murder|suicide|self-harm)\b", re.IGNORECASE)
        return not bool(BLOCKED_WORDS.search(query))

    # def generate_answer(self, context, question):
    #     messages = [
    #         {"role": "system", "content": "You are a financial assistant. Answer only finance-related questions. If the question is not related to finance, reply: 'I'm sorry, but I can only answer financial-related questions.' If the user greets you (e.g., 'Hello', 'Hi', 'Good morning'), respond politely with 'Hello! How can I assist you today?'."},
    #         {"role": "user", "content": f"{question} - related contect extracted form db {context}"}
    #     ]

    #     # Use Qwen's chat template
    #     input_text = self.qwen_tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )

    #     # Tokenize and move input to device
    #     inputs = self.qwen_tokenizer([input_text], return_tensors="pt").to(self.device)
    #     self.qwen_model.config.pad_token_id = self.qwen_tokenizer.eos_token_id

    #     # Generate response
    #     outputs = self.qwen_model.generate(
    #         inputs.input_ids,
    #         max_new_tokens=50,
    #         pad_token_id=self.qwen_tokenizer.eos_token_id,
    #     )

    #     # Extract only the newly generated part
    #     generated_ids = outputs[:, inputs.input_ids.shape[1]:]  # Remove prompt part
    #     response = self.qwen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    #     return response

    # The plain-prompt generate() call originally used max_length=100 and raised:
    #   ValueError: Input length of input_ids is 127, but `max_length` is set to 100.
    #   This can lead to unexpected behavior. You should consider increasing
    #   `max_length` or, better yet, setting `max_new_tokens`.
    # The rewrite below sets max_new_tokens, so the prompt length no longer matters.

    
    def generate_answer(self, context, question):
        prompt = f"""
        You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context. 

        For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

        Context: {context}
        User Query: {question}
        Answer:
        """

        inputs = self.qwen_tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        outputs = self.qwen_model.generate(inputs, max_new_tokens=50)

        # Decode only the newly generated tokens, not the echoed prompt, so the
        # refusal checks in get_answer() see just the model's answer.
        return self.qwen_tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)


    def get_answer(self, query):
        """Main function to process a user query and return an answer."""
        
        # Check if query is appropriate
        if not self.moderate_query(query):
            return "Inappropriate request.", 0.0

        # Retrieve relevant documents and their confidence scores
        retrieved_docs, confidences = self.query_faiss(query)
        if not retrieved_docs:
            return "No relevant information found.", 0.0

        # Combine retrieved documents as context
        context = " ".join(retrieved_docs)
        avg_confidence = round(sum(confidences) / len(confidences), 2)

        # Generate model response
        model_response = self.generate_answer(context, query)

        # Extract only the relevant part of the response
        model_response = model_response.strip()
        
        # Ensure only the actual answer is returned
        # The model may append punctuation, so match by prefix rather than equality.
        if model_response.lower().startswith(("i don't know", "no relevant information found")):
            return "I don't know.", avg_confidence

        if avg_confidence == 0.0:
            return "Not relevant.", avg_confidence

        
        return model_response, avg_confidence
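

# Example usage -- a minimal sketch. "financial_data.xlsx" is a hypothetical
# path; point it at the actual spreadsheet the index should be built from.
if __name__ == "__main__":
    bot = FinancialChatbot("financial_data.xlsx")
    answer, confidence = bot.get_answer("What was the total revenue in 2022?")
    print(f"Answer: {answer}")
    print(f"Confidence (0-10): {confidence}")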