File size: 7,641 Bytes
a2c10b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import pandas as pd
import faiss
import numpy as np
from .embedder import Embedder
from fuzzywuzzy import fuzz
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate

class Retriever:
    """Retrieve financial data rows matching a 'ticker metric year' query.

    Loads a tabular file (CSV/Excel) expected to contain at least 'Ticker'
    and 'Year' columns, embeds the tickers, and indexes them in a flat-L2
    FAISS index. ``retrieve`` combines ticker embedding similarity,
    metric-column cosine similarity, and fuzzy year matching to locate the
    best row, then delegates answer generation to ``answer_question``.
    """

    def __init__(self, file_path):
        """Load *file_path* and build the ticker index immediately.

        Args:
            file_path (str): Path to a .csv, .xlsx, or .xls file.
        """
        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
        self.index = None        # FAISS index over ticker embeddings
        self.documents = []      # ticker strings, row-aligned with self.data
        self.data = None         # full DataFrame loaded from file_path
        self.embeddings = None   # ticker embedding matrix
        self.load_file(file_path)
        self.build_index()

    def load_file(self, file_path):
        """Read the data file and extract the ticker list.

        On any failure (unsupported extension, missing 'Ticker' column,
        parse error) falls back to an empty DataFrame so later calls
        degrade gracefully instead of raising.
        """
        try:
            if file_path.endswith('.csv'):
                self.data = pd.read_csv(file_path)
            elif file_path.endswith(('.xlsx', '.xls')):
                self.data = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv, .xlsx, or .xls")
            self.documents = self.data["Ticker"].astype(str).tolist()
        except Exception as e:
            print(f"Error loading file: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """Embed the tickers and build a flat L2 FAISS index over them.

        No-op when the file failed to load (empty document list).
        """
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        # Exact (non-approximate) search; dataset sizes here are small.
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def retrieve(self, query, entities, k=3, threshold=0.7):
        """Find the best (ticker, metric, year) match and answer *query*.

        Args:
            query (str): Original natural-language question, forwarded to
                the LLM for answer generation.
            entities (dict): Must provide 'ticker', 'metric', and 'year'
                keys (each a single token — multi-word values break the
                three-part split below).
            k (int): Number of nearest-ticker candidates to consider.
            threshold (float): Minimum similarity for ticker and metric
                matches to be accepted.

        Returns:
            str on success or when nothing matches; [] on malformed input
            or an unbuilt index (kept for backward compatibility with
            existing callers).
        """
        query_prompt = f"{entities['ticker']} {entities['metric']} {entities['year']}"
        if not self.index or not self.documents or self.data.empty:
            return []

        query_parts = query_prompt.split()
        if len(query_parts) != 3:
            print("Query must follow 'ticker metric year' pattern")
            return []

        query_ticker, query_metric, query_year = query_parts

        # --- Ticker similarity (FAISS nearest neighbors) ---
        query_ticker_embedding = self.embedder.embed([query_ticker])
        distances, indices = self.index.search(query_ticker_embedding, k)
        ticker_matches = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads with -1 when fewer than k vectors exist; the old
            # `idx < len(...)` guard let -1 through and iloc[-1] silently
            # returned the LAST row. Require a valid non-negative index.
            if 0 <= idx < len(self.documents):
                ticker = self.data.iloc[idx]["Ticker"]
                # Maps L2 distance to a [0, 1]-ish score; exact only if the
                # embedder returns unit-normalized vectors — TODO confirm.
                similarity_score = 1 - distances[0][i] / 2
                ticker_matches.append((ticker, similarity_score, idx))

        # --- Metric similarity (cosine against column-name embeddings) ---
        metric_embeddings = self.embedder.embed(self.data.columns.tolist())
        query_metric_embedding = self.embedder.embed([query_metric])[0]
        metric_scores = []
        for col, col_embedding in zip(self.data.columns, metric_embeddings):
            if col.lower() in ["ticker", "year"]:
                continue  # key columns are not metrics
            cos_sim = np.dot(query_metric_embedding, col_embedding) / (
                np.linalg.norm(query_metric_embedding) * np.linalg.norm(col_embedding)
            )
            metric_scores.append((col, cos_sim))

        # --- Year similarity (fuzzy string ratio, scaled to [0, 1]) ---
        if "Year" not in self.data.columns:
            print("No 'Year' column found in data")
            return []
        year_scores = []
        for year in self.data["Year"].astype(str).unique():
            similarity = fuzz.ratio(query_year, year) / 100.0
            year_scores.append((year, similarity))

        # --- Combine: every (ticker, metric, year) triple above threshold ---
        retrieved_data = []
        seen = set()
        for ticker, ticker_score, idx in ticker_matches:
            if ticker_score < threshold:
                continue
            for metric, metric_score in metric_scores:
                if metric_score < threshold:
                    continue
                for year, year_score in year_scores:
                    # Years use a looser fixed cutoff than ticker/metric.
                    if year_score < 0.5:
                        continue
                    combined_score = (ticker_score + metric_score + year_score) / 3
                    match = self.data[
                        (self.data["Ticker"].str.lower() == ticker.lower()) &
                        (self.data["Year"].astype(str) == year) &
                        (self.data[metric].notnull())
                    ]
                    if not match.empty:
                        value = match[metric].iloc[0]
                        key = (ticker, metric, year)
                        if key not in seen:
                            seen.add(key)
                            retrieved_data.append({
                                "ticker": ticker,
                                "metric": metric,
                                "value": value,
                                "year": year,
                                "combined_score": combined_score
                            })

        if retrieved_data:
            # Answer from the single highest-scoring triple only.
            retrieved_data.sort(key=lambda x: x["combined_score"], reverse=True)
            best_match = retrieved_data[0]
            answer = answer_question(query, best_match)
            return answer

        return "No relevant data found."

def answer_question(question, retrieved_data):
    """
    Use a lightweight LLM to generate a natural-language answer on CPU.

    Args:
        question (str): The question to answer
        retrieved_data (dict): Mapping with 'ticker', 'metric', 'value',
            and 'year' keys describing the best-matching row

    Returns:
        str: Natural-language answer, or a fallback message on failure
    """
    # Guard first so empty input never constructs an LLM client.
    if not retrieved_data:
        return "No relevant data found."

    try:
        # Lightweight model, explicitly CPU-only for portability.
        llm = Ollama(model="gemma:2b", num_gpu=0)

        # Minimal prompt for CPU efficiency
        prompt_template = PromptTemplate(
            input_variables=["question", "ticker", "metric", "value", "year"],
            template=(
                "Question: {question}\n"
                "Data: Ticker={ticker}, Metric={metric}, Value={value}, Year={year}\n"
                "Answer concisely, formatting the value with commas."
            )
        )

        prompt = prompt_template.format(
            question=question,
            ticker=retrieved_data['ticker'],
            metric=retrieved_data['metric'],
            # Bug fix: pass only the matched value; previously the whole
            # dict was interpolated into the prompt.
            value=retrieved_data['value'],
            year=retrieved_data['year']
        )

        # Generate response
        response = llm.invoke(prompt)
        return response.strip()

    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Unable to generate answer."

# def main(file_path, query, question):
#     """
#     Main function to process a query, retrieve results, and answer a question.
    
#     Args:
#         file_path (str): Path to the CSV or Excel file
#         query (str): Query string in 'ticker metric year' format
#         question (str): Natural-language question to answer
    
#     Returns:
#         tuple: (retrieved data, answer)
#     """
#     try:
#         retriever = Retriever(file_path)
#         results = retriever.retrieve(query)
#         answer = answer_question(question, results)
#         return results, answer
#     except Exception as e:
#         print(f"Error processing query: {e}")
#         return [], "Unable to process query."

# if __name__ == "__main__":
#     file_path = "./financial_data.csv"
#     query = "AAPL InterestExpense 2024"
#     question = "What is the InterestExpense of AAPL 2024?"
#     results, answer = main(file_path, query, question)
#     for result in results:
#         print(f"Ticker: {result['ticker']}, Metric: {result['metric']}, Value: {result['value']}, Year: {result['year']}")
#     print(f"Answer: {answer}")