Changed classification model
- main.py +3 -1
- voice/classifier.py +228 -0
main.py
CHANGED

@@ -3,6 +3,7 @@ import asyncio
 import importlib
 from voice.speech_to_text import SpeechToText
 from voice.intent_classifier import IntentClassifier
+from voice.classifier import TextClassifier
 from api.endpoints import FMPEndpoints
 from rag.retriever import Retriever
 from rag.sql_db import SQL_Key_Pair
@@ -11,7 +12,8 @@ from rag.web_search import duckduckgo_web_search
 async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
     # Step 1: Initialize components
     stt = SpeechToText(model_path=vosk_model_path)
-    classifier = IntentClassifier()
+    # classifier = IntentClassifier()
+    classifier = TextClassifier()
     endpoints = FMPEndpoints()
     # initialize rag tools
     retriever = Retriever(file_path="./data/financial_data.csv")
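
This commit only swaps which classifier object main.py constructs; the downstream calls inside process_query are unchanged and not shown in the diff. A minimal sketch of how the new object could be exercised later in the function, assuming the transcribed query text is available and using the classify_with_llm and extract_entities methods defined in voice/classifier.py below (stt.transcribe and the variable names here are illustrative, not taken from this diff):

    # Hypothetical continuation of process_query -- not part of this commit
    text = query_text or stt.transcribe(audio_data)   # assumes SpeechToText exposes a transcribe() method
    intent = classifier.classify_with_llm(text)       # e.g. "get_stock_price"
    entities = classifier.extract_entities(text)      # e.g. {"ticker": "AAPL", "metric": "price", "year": None, "date": None}
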
voice/classifier.py
ADDED
@@ -0,0 +1,228 @@
+import spacy
+from transformers import pipeline
+from dateutil.parser import parse
+import re
+import pandas as pd
+from difflib import SequenceMatcher
+
+class TextClassifier:
+    def __init__(self):
+        # Use a larger model for better NER (optional)
+        self.nlp = spacy.load("en_core_web_sm")  # "en_core_web_lg"
+        try:
+            # Use a smaller, PyTorch-compatible model for zero-shot classification
+            self.classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
+            self.model_available = True
+            print("Successfully loaded zero-shot classification model.")
+        except Exception as e:
+            print(f"Failed to load zero-shot classification model: {e}. Falling back to keyword-based classification.")
+            self.classifier = None
+            self.model_available = False
+
+        self.intents = [
+            "get_net_income",
+            "get_revenue",
+            "get_stock_price",
+            "get_profit_margin",
+            "get_company_profile",
+            "get_market_cap",
+            "get_historical_stock_price",
+            "get_dividend_info",
+            "get_balance_sheet",
+            "get_cash_flow",
+            "get_financial_ratios",
+            "get_earnings_per_share",
+            "get_interest",
+            "get_research_info",
+            "get_cost_info",
+            "get_income_tax"
+        ]
+
+        # Mapping of company names to ticker symbols (case-insensitive)
+        self.company_to_ticker = {
+            "apple": "AAPL",
+            "microsoft corporation": "MSFT",
+            "microsoft": "MSFT",
+            "nvidia corporation": "NVDA",
+            "nvidia": "NVDA",
+            "amazon": "AMZN",
+            "alphabet inc": "GOOGL",
+            "google": "GOOGL",
+            "meta platforms": "META",
+            "meta": "META",
+            "facebook": "META",
+            "tesla": "TSLA",
+            "walmart inc": "WMT",
+            "walmart": "WMT",
+            "visa inc": "V",
+            "visa": "V",
+            "coca cola": "KO"
+        }
+
+        # Mapping of keywords to intents (case-insensitive)
+        self.intent_to_keywords = {
+            "get_net_income": ["net income", "income", "earnings"],
+            "get_revenue": ["revenue", "sales", "turnover", "gross income"],
+            "get_stock_price": ["stock price", "stock", "price", "share price", "current price", "price now", "stock value"],
+            "get_profit_margin": ["profit margin", "margin", "profit percentage", "net margin", "profit"],
+            "get_company_profile": ["who is", "company profile", "about company", "company info"],
+            "get_market_cap": ["market cap", "market capitalization", "company value", "valuation"],
+            "get_historical_stock_price": ["historical stock price", "stock price on", "past stock price", "stock price in", "price on"],
+            "get_dividend_info": ["dividend info", "dividend payout", "payout ratio", "dividend yield", "dividend"],
+            "get_balance_sheet": ["balance sheet", "sheet", "financial position", "assets and liabilities", "balance"],
+            "get_cash_flow": ["cash", "flow", "cash flow", "cashflow", "cash from operations", "operating cash"],
+            "get_financial_ratios": ["financial ratios", "ratios", "current ratio", "liquidity ratio", "debt ratio"],
+            "get_earnings_per_share": ["earnings per share", "eps", "per share earnings"],
+        }
+
+    def classify_by_keywords(self, text):
+        """
+        Classify the intent based on keyword mapping.
+
+        Args:
+            text (str): The input text to classify.
+
+        Returns:
+            str: The predicted intent, or None if no match is found.
+        """
+        text_lower = text.lower()
+        for intent, keywords in self.intent_to_keywords.items():
+            if any(keyword in text_lower for keyword in keywords):
+                print(f"Classified intent: {intent} based on keywords: {keywords}")
+                return intent
+        print("No intent matched based on keywords.")
+        return None  # Fallback if no keywords match
+
+    def classify_with_llm(self, text):
+        if not self.model_available:
+            print("Zero-shot classifier not available. Using keyword-based classification.")
+            return self.classify_by_keywords(text)
+        try:
+            hypothesis_template = "This text is requesting {} information."
+            result = self.classifier(text, candidate_labels=self.intents, hypothesis_template=hypothesis_template, multi_label=False)
+            predicted_intent = result["labels"][0]
+            print(f"Predicted intent: {predicted_intent} with scores: {dict(zip(result['labels'], result['scores']))}")
+            return predicted_intent
+        except Exception as e:
+            print(f"Error classifying intent with model: {e}. Falling back to keyword-based classification.")
+            return self.classify_by_keywords(text)
+
+    def extract_entities(self, text):
+        doc = self.nlp(text)
+        entities = {"ticker": None, "metric": None, "year": None, "date": None}
+
+        # Step 1: Extract entities using spaCy NER
+        for ent in doc.ents:
+            if ent.label_ == "ORG":
+                org_name = ent.text.lower()
+                ticker = self.company_to_ticker.get(org_name)
+                if ticker:
+                    entities["ticker"] = ticker
+                else:
+                    # If not found in the mapping, search in the CSV file
+                    try:
+                        # Load the CSV file (adjust the path as needed)
+                        csv_path = "financial data sp500 companies.csv"  # Note: differs from the ./data/financial_data.csv path used by Retriever in main.py
+                        df = pd.read_csv(csv_path)
+
+                        # Ensure the required columns exist
+                        if "firm" not in df.columns or "Ticker" not in df.columns:
+                            print("Required columns 'firm' or 'Ticker' not found in CSV. Using fallback ticker.")
+                            entities["ticker"] = ent.text.upper()
+                        else:
+                            # Calculate similarity scores between org_name and each firm name
+                            df["similarity"] = df["firm"].apply(
+                                lambda x: SequenceMatcher(None, org_name, str(x).lower()).ratio()
+                            )
+
+                            # Find rows with similarity >= 50%
+                            matches = df[df["similarity"] >= 0.5]
+
+                            if not matches.empty:
+                                # Take the first match (highest similarity)
+                                best_match = matches.sort_values(by="similarity", ascending=False).iloc[0]
+                                ticker = best_match["Ticker"]
+                                print(f"Found ticker {ticker} for {org_name} with similarity {best_match['similarity']:.2f}")
+                                entities["ticker"] = ticker
+                            else:
+                                print(f"No match found for {org_name} with >= 50% similarity. Using fallback ticker.")
+                                entities["ticker"] = ent.text.upper()
+
+                    except Exception as e:
+                        print(f"Error searching CSV for ticker: {e}. Using fallback ticker.")
+                        entities["ticker"] = ent.text.upper()
+            elif ent.label_ == "DATE":
+                date_text = ent.text.lower()
+                try:
+                    parsed_date = parse(date_text, fuzzy=True, default=parse("2025-01-01"))
+                    # If the date is a year (e.g., "2023", "this year") or parsed as January 1
+                    if "year" in date_text or date_text.isdigit() or (parsed_date.day == 1 and parsed_date.month == 1):
+                        entities["year"] = parsed_date.strftime("%Y")
+                    else:
+                        # Otherwise, treat it as a specific date (e.g., "Jan 5")
+                        entities["date"] = parsed_date.strftime("%Y-%m-%d")
+                except ValueError:
+                    # Fallback if parsing fails
+                    if "year" in date_text or date_text.isdigit():
+                        entities["year"] = date_text
+                    else:
+                        entities["date"] = date_text
+
+        # Step 2: Fallback ticker extraction if spaCy fails to identify ORG
+        if not entities["ticker"]:
+            text_lower = text.lower()
+            for company_name, ticker in self.company_to_ticker.items():
+                if company_name in text_lower:
+                    entities["ticker"] = ticker
+                    break
+
+        # Step 3: Extract metric using keyword matching with synonyms
+        text_lower = text.lower()
+        if any(keyword in text_lower for keyword in ["net income", "net", "income"]):
+            entities["metric"] = "netIncome"
+        elif "revenue" in text_lower:
+            entities["metric"] = "revenue"
+        elif any(keyword in text_lower for keyword in ["profit margin", "profit", "margin"]):
+            entities["metric"] = "netProfitMargin"
+        elif any(keyword in text_lower for keyword in ["market cap", "market capitalization", "market"]):
+            entities["metric"] = "mktCap"
+        elif any(keyword in text_lower for keyword in ["payout ratio", "dividend payout"]):
+            entities["metric"] = "payoutRatio"
+        elif any(keyword in text_lower for keyword in ["current ratio", "liquidity ratio"]):
+            entities["metric"] = "currentRatio"
+        elif any(keyword in text_lower for keyword in ["eps", "earnings per share", "earnings"]):
+            entities["metric"] = "eps"
+        elif any(keyword in text_lower for keyword in ["stock", "stock price", "current price", "valuation", "price"]):
+            entities["metric"] = "price"
+        elif any(keyword in text_lower for keyword in ["company info", "about company", "who is"]):
+            entities["metric"] = "ceo"
+        elif any(keyword in text_lower for keyword in ["balance sheet", "sheet", "assets"]):
+            entities["metric"] = "Assets&Liabilities"
+        elif any(keyword in text_lower for keyword in ["historical", "earnings per share", "earnings"]):
+            entities["metric"] = "historical"
+        elif any(keyword in text_lower for keyword in ["cash", "flow", "cash flow"]):
+            entities["metric"] = "cashFlowFromOperatingActivities"
+        elif any(keyword in text_lower for keyword in ["tax"]):
+            entities["metric"] = "IncomeTax"
+        elif any(keyword in text_lower for keyword in ["interest", "interest expense", "expense"]):
+            entities["metric"] = "InterestExpense"
+        elif any(keyword in text_lower for keyword in ["research", "research development", "development"]):
+            entities["metric"] = "Research"
+        elif any(keyword in text_lower for keyword in ["cost", "total cost"]):
+            entities["metric"] = "TotalCost"
+
+        # Step 4: Normalize year (handle "this year", "last year", etc.)
+        if entities["year"]:
+            year_text = entities["year"].lower()
+            current_year = 2025  # Based on the current date (April 16, 2025)
+            if "this year" in year_text:
+                entities["year"] = str(current_year)
+            elif "last year" in year_text:
+                entities["year"] = str(current_year - 1)
+            elif re.match(r"^\d{4}$", year_text):
+                entities["year"] = year_text
+            else:
+                # If year is not a valid format, unset it
+                entities["year"] = None
+
+        return entities
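
For reference, a standalone usage sketch of the new TextClassifier (illustrative only: it assumes spaCy's en_core_web_sm model and the typeform/distilbert-base-uncased-mnli weights can be loaded, and the shown outputs are plausible values rather than captured from a run):

    from voice.classifier import TextClassifier

    clf = TextClassifier()
    query = "What was Apple's net income last year?"

    intent = clf.classify_with_llm(query)     # zero-shot prediction, or keyword fallback -> "get_net_income"
    entities = clf.extract_entities(query)
    # Plausible result: {"ticker": "AAPL", "metric": "netIncome", "year": "2024", "date": None}

    # The CSV fallback for companies missing from the dictionary relies on difflib similarity, e.g.:
    from difflib import SequenceMatcher
    SequenceMatcher(None, "alphabet", "alphabet inc.").ratio()   # ~0.76, above the 0.5 cutoff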