# Source: Hugging Face Space file (7,641 bytes, commit a2c10b6).
# NOTE(review): the original scrape header ("Spaces: / Sleeping / file size /
# line-number gutter") was not valid Python and has been folded into this comment.
import pandas as pd
import faiss
import numpy as np
from .embedder import Embedder
from fuzzywuzzy import fuzz
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
class Retriever:
    """Semantic retriever over a tabular (CSV/Excel) financial dataset.

    Builds a FAISS L2 index over the "Ticker" column's embeddings, then
    answers `retrieve()` queries by combining three similarity signals:
    ticker (FAISS nearest-neighbour), metric (cosine similarity against
    column-name embeddings) and year (fuzzy string ratio).
    """

    def __init__(self, file_path):
        """Load the data file and build the ticker index.

        Args:
            file_path (str): Path to a .csv, .xlsx, or .xls file that
                contains at least "Ticker" and "Year" columns.
        """
        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
        self.index = None          # FAISS index over ticker embeddings
        self.documents = []        # ticker strings, row-aligned with self.data
        self.data = None           # loaded DataFrame (empty on load failure)
        self.embeddings = None     # ticker embedding matrix
        self._metric_embeddings = None  # cached column-name embeddings (see retrieve)
        self.load_file(file_path)
        self.build_index()

    def load_file(self, file_path):
        """Read the CSV/Excel file into self.data and extract ticker strings.

        On any failure the retriever degrades gracefully: documents become
        empty and data becomes an empty DataFrame, so retrieve() returns [].
        """
        try:
            if file_path.endswith('.csv'):
                self.data = pd.read_csv(file_path)
            elif file_path.endswith(('.xlsx', '.xls')):
                self.data = pd.read_excel(file_path)
            else:
                raise ValueError("Unsupported file format. Use .csv, .xlsx, or .xls")
            self.documents = self.data["Ticker"].astype(str).tolist()
        except Exception as e:
            # Best-effort loader: report and fall back to an empty dataset.
            print(f"Error loading file: {e}")
            self.documents = []
            self.data = pd.DataFrame()

    def build_index(self):
        """Embed all ticker strings and index them with FAISS (flat L2)."""
        if not self.documents:
            return
        self.embeddings = self.embedder.embed(self.documents)
        dim = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dim)
        self.index.add(self.embeddings)

    def retrieve(self, query, entities, k=3, threshold=0.7):
        """Retrieve the best (ticker, metric, year) match and answer `query`.

        Args:
            query (str): Natural-language question, forwarded to the LLM.
            entities (dict): Must contain 'ticker', 'metric', 'year' keys.
            k (int): Number of ticker nearest-neighbours to consider.
            threshold (float): Minimum ticker/metric similarity to accept.

        Returns:
            str on success (LLM answer) or "No relevant data found.";
            [] when the index/data is unavailable or the query is malformed.
        """
        query_prompt = f"{entities['ticker']} {entities['metric']} {entities['year']}"
        if not self.index or not self.documents or self.data.empty:
            return []
        query_parts = query_prompt.split()
        if len(query_parts) != 3:
            print("Query must follow 'ticker metric year' pattern")
            return []
        query_ticker, query_metric, query_year = query_parts

        # --- Ticker similarity (FAISS nearest neighbours) ---
        query_ticker_embedding = self.embedder.embed([query_ticker])
        distances, indices = self.index.search(query_ticker_embedding, k)
        ticker_matches = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads missing results with -1; the previous check
            # (idx < len) let -1 through and iloc[-1] wrapped to the last row.
            if 0 <= idx < len(self.documents):
                ticker = self.data.iloc[idx]["Ticker"]
                # IndexFlatL2 returns squared L2 distance; for unit-norm
                # embeddings 1 - d/2 equals cosine similarity — TODO confirm
                # the Embedder normalizes its output.
                similarity_score = 1 - distances[0][i] / 2
                ticker_matches.append((ticker, similarity_score, idx))

        # --- Metric similarity (cosine vs. column-name embeddings) ---
        # Column names never change after load, so embed them once and cache.
        if self._metric_embeddings is None:
            self._metric_embeddings = self.embedder.embed(self.data.columns.tolist())
        query_metric_embedding = self.embedder.embed([query_metric])[0]
        metric_scores = []
        for col, col_embedding in zip(self.data.columns, self._metric_embeddings):
            if col.lower() in ["ticker", "year"]:
                continue
            cos_sim = np.dot(query_metric_embedding, col_embedding) / (
                np.linalg.norm(query_metric_embedding) * np.linalg.norm(col_embedding)
            )
            metric_scores.append((col, cos_sim))

        # --- Year similarity (fuzzy string match) ---
        if "Year" not in self.data.columns:
            print("No 'Year' column found in data")
            return []
        year_scores = []
        for year in self.data["Year"].astype(str).unique():
            similarity = fuzz.ratio(query_year, year) / 100.0
            year_scores.append((year, similarity))

        # --- Combine: cross-product of accepted candidates, dedup, score ---
        retrieved_data = []
        seen = set()
        for ticker, ticker_score, idx in ticker_matches:
            if ticker_score < threshold:
                continue
            for metric, metric_score in metric_scores:
                if metric_score < threshold:
                    continue
                for year, year_score in year_scores:
                    # Years use a fixed, looser cutoff than ticker/metric.
                    if year_score < 0.5:
                        continue
                    combined_score = (ticker_score + metric_score + year_score) / 3
                    match = self.data[
                        (self.data["Ticker"].str.lower() == ticker.lower()) &
                        (self.data["Year"].astype(str) == year) &
                        (self.data[metric].notnull())
                    ]
                    if not match.empty:
                        value = match[metric].iloc[0]
                        key = (ticker, metric, year)
                        if key not in seen:
                            seen.add(key)
                            retrieved_data.append({
                                "ticker": ticker,
                                "metric": metric,
                                "value": value,
                                "year": year,
                                "combined_score": combined_score
                            })
        if retrieved_data:
            retrieved_data.sort(key=lambda x: x["combined_score"], reverse=True)
            best_match = retrieved_data[0]
            answer = answer_question(query, best_match)
            return answer
        return "No relevant data found."
def answer_question(question, retrieved_data):
    """
    Use a lightweight LLM to generate a natural-language answer on CPU.

    Args:
        question (str): The question to answer
        retrieved_data (dict): Dict with 'ticker', 'metric', 'value', 'year'
            keys (the best match produced by Retriever.retrieve)

    Returns:
        str: Natural-language answer, or a fallback message on empty input
            or any LLM/formatting error.
    """
    # Guard before touching the LLM: no data means no model spin-up needed.
    # (Previously this check ran after Ollama initialization.)
    if not retrieved_data:
        return "No relevant data found."
    try:
        # Lightweight CPU-only model; GPU explicitly disabled.
        llm = Ollama(model="gemma:2b", num_gpu=0)
        # Minimal prompt for CPU efficiency.
        prompt_template = PromptTemplate(
            input_variables=["question", "ticker", "metric", "value", "year"],
            template=(
                "Question: {question}\n"
                "Data: Ticker={ticker}, Metric={metric}, Value={value}, Year={year}\n"
                "Answer concisely, formatting the value with commas."
            )
        )
        prompt = prompt_template.format(
            question=question,
            ticker=retrieved_data['ticker'],
            metric=retrieved_data['metric'],
            # BUG FIX: the whole dict was previously interpolated here
            # instead of the retrieved value.
            value=retrieved_data['value'],
            year=retrieved_data['year']
        )
        # Generate response
        response = llm.invoke(prompt)
        return response.strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Unable to generate answer."
# def main(file_path, query, question):
# """
# Main function to process a query, retrieve results, and answer a question.
# Args:
# file_path (str): Path to the CSV or Excel file
# query (str): Query string in 'ticker metric year' format
# question (str): Natural-language question to answer
# Returns:
# tuple: (retrieved data, answer)
# """
# try:
# retriever = Retriever(file_path)
# results = retriever.retrieve(query)
# answer = answer_question(question, results)
# return results, answer
# except Exception as e:
# print(f"Error processing query: {e}")
# return [], "Unable to process query."
# if __name__ == "__main__":
# file_path = "./financial_data.csv"
# query = "AAPL InterestExpense 2024"
# question = "What is the InterestExpense of AAPL 2024?"
# results, answer = main(file_path, query, question)
# for result in results:
# print(f"Ticker: {result['ticker']}, Metric: {result['metric']}, Value: {result['value']}, Year: {result['year']}")
# print(f"Answer: {answer}")