rasulbrur committed on
Commit
f1fce37
·
1 Parent(s): ebe592c

Changed classification model

Browse files
Files changed (2) hide show
  1. main.py +3 -1
  2. voice/classifier.py +228 -0
main.py CHANGED
@@ -3,6 +3,7 @@ import asyncio
3
  import importlib
4
  from voice.speech_to_text import SpeechToText
5
  from voice.intent_classifier import IntentClassifier
 
6
  from api.endpoints import FMPEndpoints
7
  from rag.retriever import Retriever
8
  from rag.sql_db import SQL_Key_Pair
@@ -11,7 +12,8 @@ from rag.web_search import duckduckgo_web_search
11
  async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
12
  # Step 1: Initialize components
13
  stt = SpeechToText(model_path=vosk_model_path)
14
- classifier = IntentClassifier()
 
15
  endpoints = FMPEndpoints()
16
  # initialize rag tools
17
  retriever = Retriever(file_path="./data/financial_data.csv")
 
3
  import importlib
4
  from voice.speech_to_text import SpeechToText
5
  from voice.intent_classifier import IntentClassifier
6
+ from voice.classifier import TextClassifier
7
  from api.endpoints import FMPEndpoints
8
  from rag.retriever import Retriever
9
  from rag.sql_db import SQL_Key_Pair
 
12
  async def process_query(vosk_model_path, audio_data=None, query_text=None, use_retriever=False):
13
  # Step 1: Initialize components
14
  stt = SpeechToText(model_path=vosk_model_path)
15
+ # classifier = IntentClassifier()
16
+ classifier = TextClassifier()
17
  endpoints = FMPEndpoints()
18
  # initialize rag tools
19
  retriever = Retriever(file_path="./data/financial_data.csv")
voice/classifier.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+ from transformers import pipeline
3
+ from dateutil.parser import parse
4
+ import re
5
+ import pandas as pd
6
+ from difflib import SequenceMatcher
7
+
8
class TextClassifier:
    """Classify the intent of a financial query and extract its entities.

    Intent classification uses a Hugging Face zero-shot pipeline when it can
    be loaded and falls back to keyword matching otherwise. Entity extraction
    combines spaCy NER (ORG and DATE entities) with keyword tables for the
    requested metric.
    """

    # Candidate labels handed to the zero-shot classifier.
    DEFAULT_INTENTS = (
        "get_net_income",
        "get_revenue",
        "get_stock_price",
        "get_profit_margin",
        "get_company_profile",
        "get_market_cap",
        "get_historical_stock_price",
        "get_dividend_info",
        "get_balance_sheet",
        "get_cash_flow",
        "get_financial_ratios",
        "get_earnings_per_share",
        "get_interest",
        "get_research_info",
        "get_cost_info",
        "get_income_tax",
    )

    # Lower-cased company name -> ticker symbol.
    DEFAULT_COMPANY_TO_TICKER = {
        "apple": "AAPL",
        "microsoft corporation": "MSFT",
        "microsoft": "MSFT",
        "nvidia corporation": "NVDA",
        "nvidia": "NVDA",
        "amazon": "AMZN",
        "alphabet inc": "GOOGL",
        "google": "GOOGL",
        "meta platforms": "META",
        "meta": "META",
        "facebook": "META",
        "tesla": "TSLA",
        "walmart inc": "WMT",
        "walmart": "WMT",
        "visa inc": "V",
        "visa": "V",
        "coca cola": "KO",
    }

    # Intent -> lower-cased keywords used by the keyword fallback classifier.
    DEFAULT_INTENT_TO_KEYWORDS = {
        "get_net_income": ("net income", "income", "earnings"),
        "get_revenue": ("revenue", "sales", "turnover", "gross income"),
        "get_stock_price": ("stock price", "stock", "price", "share price", "current price", "price now", "stock value"),
        "get_profit_margin": ("profit margin", "margin", "profit percentage", "net margin", "profit"),
        "get_company_profile": ("who is", "company profile", "about company", "company info"),
        "get_market_cap": ("market cap", "market capitalization", "company value", "valuation"),
        "get_historical_stock_price": ("historical stock price", "stock price on", "past stock price", "stock price in", "price on"),
        "get_dividend_info": ("dividend info", "dividend payout", "payout ratio", "dividend yield", "dividend"),
        "get_balance_sheet": ("balance sheet", "sheet", "financial position", "assets and liabilities", "balance"),
        "get_cash_flow": ("cash", "flow", "cash flow", "cashflow", "cash from operations", "operating cash"),
        "get_financial_ratios": ("financial ratios", "ratios", "current ratio", "liquidity ratio", "debt ratio"),
        "get_earnings_per_share": ("earnings per share", "eps", "per share earnings"),
    }

    # Ordered (metric, keywords) rules: the FIRST rule with a matching keyword
    # wins, so the order is part of the behaviour (e.g. "income tax" resolves
    # to "netIncome" because "income" is tested before "tax").
    METRIC_KEYWORDS = (
        ("netIncome", ("net income", "net", "income")),
        ("revenue", ("revenue",)),
        ("netProfitMargin", ("profit margin", "profit", "margin")),
        ("mktCap", ("market cap", "market capitalization", "market")),
        ("payoutRatio", ("payout ratio", "dividend payout")),
        ("currentRatio", ("current ratio", "liquidity ratio")),
        ("eps", ("eps", "earnings per share", "earnings")),
        ("price", ("stock", "stock price", "current price", "valuation", "price")),
        ("ceo", ("company info", "about company", "who is")),
        ("Assets&Liabilities", ("balance sheet", "sheet", "assets")),
        ("historical", ("historical",)),
        ("cashFlowFromOperatingActivities", ("cash", "flow", "cash flow")),
        ("IncomeTax", ("tax",)),
        ("InterestExpense", ("interest", "interest expense", "expense")),
        ("Research", ("research", "research development", "development")),
        ("TotalCost", ("cost", "total cost")),
    )

    # Short keywords that must match as whole words; as plain substrings they
    # fire inside unrelated words ("net" in "Netflix", "eps" in "keeps").
    WHOLE_WORD_KEYWORDS = frozenset({"net", "eps"})

    # CSV with 'firm' and 'Ticker' columns used for fuzzy company lookup.
    # NOTE(review): main.py passes "./data/financial_data.csv" to Retriever;
    # confirm which file is intended here.
    ticker_csv_path = "financial data sp500 companies.csv"

    # Minimum SequenceMatcher ratio (0..1) for a fuzzy firm-name match.
    min_firm_similarity = 0.5

    # Lazily loaded list of (lower-cased firm name, ticker) pairs.
    _ticker_table = None

    def __init__(self):
        """Load the spaCy pipeline and, if possible, the zero-shot classifier."""
        # Use a larger model for better NER (optional)
        self.nlp = spacy.load("en_core_web_sm")  # "en_core_web_lg"
        try:
            # Use a smaller, PyTorch-compatible model for zero-shot classification
            self.classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
            self.model_available = True
            print("Successfully loaded zero-shot classification model.")
        except Exception as e:
            # Deliberate best-effort: any load failure switches to keywords.
            print(f"Failed to load zero-shot classification model: {e}. Falling back to keyword-based classification.")
            self.classifier = None
            self.model_available = False

        # Per-instance copies so callers may customise them safely.
        self.intents = list(self.DEFAULT_INTENTS)
        self.company_to_ticker = dict(self.DEFAULT_COMPANY_TO_TICKER)
        self.intent_to_keywords = {
            intent: list(keywords)
            for intent, keywords in self.DEFAULT_INTENT_TO_KEYWORDS.items()
        }

    def _mentions(self, text_lower, keyword):
        """Return True if the lower-cased text contains the keyword.

        Keywords listed in WHOLE_WORD_KEYWORDS must appear as whole words;
        every other keyword is matched as a plain substring.
        """
        if keyword in self.WHOLE_WORD_KEYWORDS:
            return re.search(rf"\b{re.escape(keyword)}\b", text_lower) is not None
        return keyword in text_lower

    def classify_by_keywords(self, text):
        """
        Classify the intent based on keyword mapping.

        The intent owning the longest matching keyword wins, so a specific
        phrase ("stock price on") beats a generic one ("stock"). Ties go to
        the intent declared first.

        Args:
            text (str): The input text to classify. None is treated as "".

        Returns:
            str: The predicted intent, or None if no match is found.
        """
        text_lower = (text or "").lower()
        best_intent = None
        best_length = 0
        for intent, keywords in self.intent_to_keywords.items():
            for keyword in keywords:
                # Strictly longer only: keeps the earliest intent on ties.
                if len(keyword) > best_length and self._mentions(text_lower, keyword):
                    best_intent = intent
                    best_length = len(keyword)
        if best_intent is None:
            print("No intent matched based on keywords.")
            return None  # Fallback if no keywords match
        print(f"Classified intent: {best_intent} based on keywords: {self.intent_to_keywords[best_intent]}")
        return best_intent

    def classify_with_llm(self, text):
        """Classify the intent with the zero-shot model, else by keywords.

        Args:
            text (str): The input text to classify.

        Returns:
            str: The top-ranked label from self.intents, or the result of
            classify_by_keywords when the model is unavailable or fails.
        """
        if not self.model_available:
            print("Zero-shot classifier not available. Using keyword-based classification.")
            return self.classify_by_keywords(text)
        try:
            hypothesis_template = "This text is requesting {} information."
            result = self.classifier(text, candidate_labels=self.intents, hypothesis_template=hypothesis_template, multi_label=False)
            predicted_intent = result["labels"][0]
            print(f"Predicted intent: {predicted_intent} with scores: {dict(zip(result['labels'], result['scores']))}")
            return predicted_intent
        except Exception as e:
            # Deliberate best-effort: any model failure falls back to keywords.
            print(f"Error classifying intent with model: {e}. Falling back to keyword-based classification.")
            return self.classify_by_keywords(text)

    def _load_ticker_table(self):
        """Return cached (firm_lower, ticker) pairs read from ticker_csv_path.

        Returns None when the CSV lacks the 'firm' or 'Ticker' column. Errors
        raised by pandas.read_csv propagate to the caller. Only a successful
        load is cached, so a missing file is retried on the next call.
        """
        if self._ticker_table is None:
            df = pd.read_csv(self.ticker_csv_path)
            if "firm" not in df.columns or "Ticker" not in df.columns:
                return None
            rows = df.dropna(subset=["firm", "Ticker"])
            self._ticker_table = [
                (str(firm).lower(), ticker)
                for firm, ticker in zip(rows["firm"], rows["Ticker"])
            ]
        return self._ticker_table

    def _resolve_ticker(self, org_text):
        """Map an organisation name to a ticker symbol.

        Tries the exact company_to_ticker mapping first, then the most similar
        firm name in the CSV (ratio >= min_firm_similarity). When neither
        works, the upper-cased organisation text is returned as a fallback.
        """
        org_name = org_text.lower()
        ticker = self.company_to_ticker.get(org_name)
        if ticker:
            return ticker

        # If not found in the mapping, search in the CSV file
        try:
            table = self._load_ticker_table()
        except (OSError, ValueError) as e:
            # Covers a missing/unreadable file and pandas parse errors.
            print(f"Error searching CSV for ticker: {e}. Using fallback ticker.")
            return org_text.upper()
        if table is None:
            print("Required columns 'firm' or 'Ticker' not found in CSV. Using fallback ticker.")
            return org_text.upper()

        best_score = 0.0
        best_ticker = None
        for firm, candidate in table:
            score = SequenceMatcher(None, org_name, firm).ratio()
            if score > best_score:
                best_score, best_ticker = score, candidate

        if best_ticker is not None and best_score >= self.min_firm_similarity:
            print(f"Found ticker {best_ticker} for {org_name} with similarity {best_score:.2f}")
            return best_ticker
        print(f"No match found for {org_name} with >= {self.min_firm_similarity:.0%} similarity. Using fallback ticker.")
        return org_text.upper()

    @staticmethod
    def _interpret_date(date_text, default_date):
        """Turn a lower-cased DATE entity into a (year, date) pair.

        Exactly one element of the pair is not None. Relative phrases
        ("this year", "last year") are resolved against default_date.year
        before parsing, so the result does not depend on whether the date
        parser rejects them. default_date also supplies the fields missing
        from a partial date such as "Jan 5".
        """
        if "this year" in date_text:
            return str(default_date.year), None
        if "last year" in date_text:
            return str(default_date.year - 1), None

        names_a_year = "year" in date_text or date_text.isdigit()
        try:
            parsed_date = parse(date_text, fuzzy=True, default=default_date)
        except (ValueError, OverflowError):
            # Parsing failed: hand back the raw text in the matching slot.
            if names_a_year:
                return date_text, None
            return None, date_text

        # A bare year parses to January 1 because of default_date.
        if names_a_year or (parsed_date.day == 1 and parsed_date.month == 1):
            return parsed_date.strftime("%Y"), None
        # Otherwise, treat it as a specific date (e.g., "Jan 5")
        return None, parsed_date.strftime("%Y-%m-%d")

    @staticmethod
    def _normalize_year(year_text, current_year):
        """Return a four-digit year string, or None if year_text is not one.

        "this year" and "last year" are resolved relative to current_year.
        """
        year_text = year_text.lower()
        if "this year" in year_text:
            return str(current_year)
        if "last year" in year_text:
            return str(current_year - 1)
        if re.fullmatch(r"\d{4}", year_text):
            return year_text
        # If year is not a valid format, unset it
        return None

    def extract_entities(self, text):
        """Extract ticker, metric, year and date from a query.

        Args:
            text (str): The query text. None is treated as "".

        Returns:
            dict: Keys "ticker", "metric", "year" and "date"; each value is a
            string or None when nothing was found.
        """
        # Function-scope import: datetime is not imported at module level.
        from datetime import datetime

        text = text or ""
        text_lower = text.lower()
        current_year = datetime.now().year
        default_date = datetime(current_year, 1, 1)
        entities = {"ticker": None, "metric": None, "year": None, "date": None}

        # Step 1: Extract entities using spaCy NER (a later entity of the
        # same kind overrides an earlier one).
        for ent in self.nlp(text).ents:
            if ent.label_ == "ORG":
                entities["ticker"] = self._resolve_ticker(ent.text)
            elif ent.label_ == "DATE":
                year, date = self._interpret_date(ent.text.lower(), default_date)
                if year is not None:
                    entities["year"] = year
                if date is not None:
                    entities["date"] = date

        # Step 2: Fallback ticker extraction if spaCy fails to identify ORG
        if not entities["ticker"]:
            for company_name, ticker in self.company_to_ticker.items():
                if company_name in text_lower:
                    entities["ticker"] = ticker
                    break

        # Step 3: Extract metric using keyword matching (first rule wins)
        for metric, keywords in self.METRIC_KEYWORDS:
            if any(self._mentions(text_lower, keyword) for keyword in keywords):
                entities["metric"] = metric
                break

        # Step 4: Normalize year (handle "this year", "last year", etc.)
        if entities["year"]:
            entities["year"] = self._normalize_year(entities["year"], current_year)

        return entities