File size: 7,199 Bytes
43697b4
 
e2b8671
 
 
43697b4
 
e2b8671
 
43697b4
e2b8671
 
43697b4
 
 
 
 
 
 
 
 
 
e2b8671
 
43697b4
e2b8671
43697b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2b8671
 
 
43697b4
 
e2b8671
43697b4
 
e2b8671
43697b4
 
 
 
 
e2b8671
43697b4
 
 
 
 
 
78ad15e
43697b4
 
e2b8671
78ad15e
e2b8671
78ad15e
 
43697b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8d190b
43697b4
 
f8d190b
43697b4
e2b8671
43697b4
 
 
 
f8d190b
43697b4
 
 
f8d190b
43697b4
 
 
f8d190b
43697b4
 
f8d190b
43697b4
 
 
 
e2b8671
43697b4
 
e2b8671
43697b4
 
 
 
e2b8671
43697b4
 
 
e2b8671
43697b4
e2b8671
 
43697b4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# import time
import pickle
import re
import threading

import faiss
import numpy as np
# import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
# import torch

class FinancialChatbot:
    """Retrieval-augmented chatbot over a spreadsheet of financial figures.

    The spreadsheet at ``data_path`` is flattened into one sentence per cell,
    the sentences are embedded with a SentenceTransformer model and stored in
    a FAISS L2 index.  At query time the nearest sentences are passed as
    context to a causal language model which writes the answer.

    The FAISS index and the id -> sentence map are cached in the current
    working directory (``financial_faiss.index`` / ``index_map.pkl``).  The
    cache is not checked against ``data_path``; delete both files to force a
    rebuild after the spreadsheet changes.
    """

    # Cache files, relative to the current working directory.
    INDEX_FILE = "financial_faiss.index"
    INDEX_MAP_FILE = "index_map.pkl"

    # Terms that cause a query to be rejected by ``moderate_query``.
    _BLOCKED_WORDS = (
        "hack", "bypass", "illegal", "exploit", "scam", "kill",
        "laundering", "murder", "suicide", "self-harm",
    )
    # A term must start at a word boundary so that e.g. "skills" is not
    # rejected for containing "kill"; suffixes ("killing", "scams") still match.
    _BLOCKED_PATTERN = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _BLOCKED_WORDS)) + r")",
        re.IGNORECASE,
    )

    def __init__(self, data_path, model_name="all-MiniLM-L6-v2", qwen_model_name="Qwen/Qwen2.5-1.5b"):
        """Load the embedding model, the language model and the FAISS index.

        Args:
            data_path: Path of the Excel file read when the index has to be built.
            model_name: SentenceTransformer model used for embeddings.
            qwen_model_name: Hugging Face id of the causal LM used for answers.
        """
        self.data_path = data_path
        self.sbert_model = SentenceTransformer(model_name)
        self.index_map = {}
        self.faiss_index = None
        # The language model is pinned to the CPU.
        self.qwen_model = AutoModelForCausalLM.from_pretrained(
            qwen_model_name, torch_dtype="auto", device_map="cpu", trust_remote_code=True
        )
        self.qwen_tokenizer = AutoTokenizer.from_pretrained(qwen_model_name, trust_remote_code=True)
        self.load_or_create_index()

    def load_or_create_index(self):
        """Load the cached index and sentence map, rebuilding both on any failure."""
        try:
            loaded_index = faiss.read_index(self.INDEX_FILE)
            with open(self.INDEX_MAP_FILE, "rb") as f:
                # NOTE(security): unpickling can execute arbitrary code; only
                # load a cache file that this application wrote itself.
                loaded_map = pickle.load(f)
        except Exception:
            # Missing or unreadable cache.  ``Exception`` (not a bare except)
            # lets KeyboardInterrupt / SystemExit propagate.
            print("Creating new FAISS index...")
            self._build_index()
            print("Indexing completed!")
        else:
            # Assign only after both files loaded, so index and map always match.
            self.faiss_index = loaded_index
            self.index_map = loaded_map
            print("Index loaded successfully!")

    def _build_index(self):
        """Embed every spreadsheet cell as a sentence and persist index + map."""
        df = pd.read_excel(self.data_path)
        label_column = df.columns[0]
        sentences = []
        for _, row in df.iterrows():
            for col in df.columns[1:]:
                sentences.append(f"{row[label_column]} - year {col} is: {row[col]}")

        # Position in ``sentences`` is the id FAISS assigns to the vector.
        self.index_map = dict(enumerate(sentences))

        embeddings = self.sbert_model.encode(sentences, convert_to_numpy=True)
        self.faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
        self.faiss_index.add(embeddings)

        faiss.write_index(self.faiss_index, self.INDEX_FILE)
        with open(self.INDEX_MAP_FILE, "wb") as f:
            pickle.dump(self.index_map, f)

    def query_faiss(self, query, top_k=5):
        """Retrieve the top-k sentences for ``query`` with relative confidences.

        Confidence is ``1 - dist / max_dist`` over the returned hits, rounded
        to two decimals: the closest hit scores highest and the farthest hit
        scores 0.0 (all hits score 1.0 when every distance is zero).

        Returns:
            ``(results, confidences)`` - two parallel lists, possibly empty.
        """
        query_embedding = self.sbert_model.encode([query], convert_to_numpy=True)
        distances, indices = self.faiss_index.search(query_embedding, top_k)

        # Keep only ids we know.  FAISS pads missing results with label -1;
        # those padding entries must not take part in the normalisation.
        hits = [
            (int(idx), float(dist))
            for idx, dist in zip(indices[0], distances[0])
            if int(idx) in self.index_map
        ]
        if not hits:
            return [], []

        max_dist = max(dist for _, dist in hits)
        if max_dist <= 0:
            max_dist = 1.0  # Avoid division by zero

        results = [self.index_map[idx] for idx, _ in hits]
        confidences = [round(1 - dist / max_dist, 2) for _, dist in hits]
        return results, confidences

    def moderate_query(self, query):
        """Return True when ``query`` contains none of the blocked terms."""
        return self._BLOCKED_PATTERN.search(query) is None

    def generate_answer(self, context, question):
        """Run the language model on the prompt built from context and question.

        Returns:
            The decoded model output, which includes the prompt text followed
            by the generated continuation.
        """
        prompt = f"""
        You are a financial assistant. If the user greets you (e.g., "Hello," "Hi," "Good morning"), respond politely without requiring context. 

        For financial-related questions, answer based on the context provided. If the context lacks information, say "I don't know."

        Context: {context}
        User Query: {question}
        Answer:
        """

        inputs = self.qwen_tokenizer.encode(prompt, return_tensors="pt")
        # ``max_new_tokens`` bounds only the continuation.  ``max_length``
        # counts the prompt tokens as well, and the prompt alone is already
        # around that size, leaving no room for an answer.
        outputs = self.qwen_model.generate(inputs, max_new_tokens=100)
        return self.qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _extract_answer(generated):
        """Return the trailing ``Answer...`` section of ``generated``.

        Returns None when that section is empty or starts with the ``--``
        placeholder, i.e. when the model produced no usable answer.
        """
        marker = "Answer"
        marker_pos = generated.rfind(marker)
        if marker_pos == -1:
            # No marker at all: treat the whole text as the answer.
            section, body = generated, generated
        else:
            section = generated[marker_pos:]
            body = section[len(marker):]
        body = body.lstrip(": \t\r\n")
        if not body.strip() or body.startswith("--"):
            return None
        return section

    def _build_answer(self, query):
        """Moderate, retrieve and generate; return ``(answer_text, confidence)``."""
        if not self.moderate_query(query):
            return "I'm unable to process your request due to inappropriate language.", 0.0

        retrieved_docs, confidences = self.query_faiss(query)
        if not retrieved_docs:
            return "No relevant information found", 0.0

        context = " ".join(retrieved_docs)
        avg_confidence = round(sum(confidences) / len(confidences), 2)

        answer = self._extract_answer(self.generate_answer(context, query))
        if answer is None:
            return "No relevant information found", 0.0
        return answer, avg_confidence

    def get_answer(self, query, timeout=150):
        """Answer ``query`` and report a confidence, giving up after ``timeout`` seconds.

        Args:
            query: The user's question.
            timeout: Seconds to wait for the worker thread.

        Returns:
            ``(text, confidence)``.  ``text`` is the model's answer section, or
            a fixed message when the query is rejected, nothing relevant is
            found, or the time limit is exceeded (confidence 0.0 in those cases).
        """
        # Stays in place if the worker never finishes.  It is also what is
        # returned when the worker raises: the traceback is reported by the
        # threading machinery and no result is stored.
        result = ["Execution exceeded time limit. Stopping function.", 0.0]

        def task():
            result[:] = self._build_answer(query)

        # Threads cannot be killed; a daemon thread at least does not keep
        # the interpreter alive once a timed-out generation is abandoned.
        thread = threading.Thread(target=task, daemon=True)
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            return "Execution exceeded time limit. Stopping function.", 0.0

        return tuple(result)


# if __name__ == "__main__":
#     chatbot = FinancialChatbot("C:\\Users\\Dell\\Downloads\\CAI_RAG\\DATA\\Nestle_Financtial_report_till2023.xlsx")
#     query = "What is the Employees Cost in Dec'20?"
#     print(chatbot.get_answer(query))