S6six committed
Commit 1265584 · 1 Parent(s): 9719f08

Add Gradio app and update dependencies

.gitignore CHANGED
@@ -18,4 +18,7 @@ data/processed/*
 # Virtual environment
 venv/
 env/
-.venv/
+.venv/
+
+# Temporary files
+temp_merged_data.csv
app.py ADDED
@@ -0,0 +1,260 @@
import gradio as gr
import pandas as pd
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import os
import sys
import io
import base64

# Add src directory to path to import modules
module_path = os.path.abspath(os.path.join('.'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Import functions from your src directory
try:
    from src.data_fetcher import get_stock_data, get_news_articles, load_api_keys
    from src.sentiment_analyzer import analyze_sentiment
except ImportError as e:
    # Handle error gracefully if run from a different directory or modules missing
    print(f"Error importing modules from src: {e}. Ensure app.py is in the project root and src/* exists.")
    # Define dummy functions if imports fail, so Gradio interface can still load
    def get_stock_data(*args, **kwargs): return None
    def get_news_articles(*args, **kwargs): return None
    def analyze_sentiment(*args, **kwargs): return None, None, None
    def load_api_keys(): return None, None


# --- Data Fetching and Processing Logic ---
# (Similar to the Streamlit version, but adapted for Gradio outputs)
def perform_analysis(ticker_symbol, start_date_str, end_date_str):  # Renamed date inputs
    """Fetches data, analyzes sentiment, merges, and prepares outputs for Gradio."""
    if not ticker_symbol:
        return None, "Please enter a stock ticker.", None, None, None

    # Ensure API keys are loaded (needed for news)
    news_key, _ = load_api_keys()
    if not news_key:
        return None, "Error: NEWS_API_KEY not found in .env file. Cannot fetch news.", None, None, None

    # Validate and parse date strings
    try:
        start_date_obj = datetime.strptime(start_date_str, '%Y-%m-%d').date()
        end_date_obj = datetime.strptime(end_date_str, '%Y-%m-%d').date()
    except ValueError:
        return None, "Error: Invalid date format. Please use YYYY-MM-DD.", None, None, None

    if start_date_obj >= end_date_obj:
        return None, "Error: Start date must be before end date.", None, None, None

    status_updates = f"Fetching data for {ticker_symbol} from {start_date_str} to {end_date_str}...\n"

    # 1. Fetch Stock Data
    stock_df = get_stock_data(ticker_symbol, start_date_str, end_date_str)
    if stock_df is None or stock_df.empty:
        status_updates += "Could not fetch stock data.\n"
        # Return early if essential data is missing
        return None, status_updates, None, None, None
    else:
        status_updates += f"Successfully fetched {len(stock_df)} days of stock data.\n"
        stock_df['Date'] = pd.to_datetime(stock_df['Date'])

    # 2. Fetch News Articles
    articles_list = get_news_articles(ticker_symbol, start_date_str, end_date_str)
    if articles_list is None or not articles_list:
        status_updates += "Could not fetch news articles or none found.\n"
        news_df = pd.DataFrame()
    else:
        status_updates += f"Found {len(articles_list)} potential news articles.\n"
        news_df = pd.DataFrame(articles_list)
        if 'publishedAt' in news_df.columns:
            news_df['publishedAt'] = pd.to_datetime(news_df['publishedAt'])
            news_df['date'] = news_df['publishedAt'].dt.date
            news_df['date'] = pd.to_datetime(news_df['date'])  # Convert date to datetime for merging
        else:
            status_updates += "Warning: News articles missing 'publishedAt' field.\n"
            news_df['date'] = None

    # 3. Sentiment Analysis (if news available)
    daily_sentiment = pd.DataFrame(columns=['date', 'avg_sentiment_score'])  # Default empty
    if not news_df.empty and 'date' in news_df.columns and news_df['date'].notna().any():
        status_updates += f"Performing sentiment analysis on {len(news_df)} articles...\n"
        news_df['text_to_analyze'] = news_df['title'].fillna('') + ". " + news_df['description'].fillna('')
        # --- Apply sentiment analysis ---
        # This can be slow, consider progress updates if possible or running async
        sentiment_results = news_df['text_to_analyze'].apply(lambda x: analyze_sentiment(x) if pd.notna(x) else (None, None, None))
        news_df['sentiment_label'] = sentiment_results.apply(lambda x: x[0])
        news_df['sentiment_score'] = sentiment_results.apply(lambda x: x[1])
        status_updates += "Sentiment analysis complete.\n"

        # 4. Aggregate Sentiment
        valid_sentiment_df = news_df.dropna(subset=['sentiment_score', 'date'])
        if not valid_sentiment_df.empty:
            daily_sentiment = valid_sentiment_df.groupby('date')['sentiment_score'].mean().reset_index()
            daily_sentiment.rename(columns={'sentiment_score': 'avg_sentiment_score'}, inplace=True)
            status_updates += "Aggregated daily sentiment scores.\n"
        else:
            status_updates += "No valid sentiment scores found to aggregate.\n"

    # 5. Merge Data
    if not daily_sentiment.empty:
        merged_df = pd.merge(stock_df, daily_sentiment, left_on='Date', right_on='date', how='left')
        if 'date' in merged_df.columns:
            merged_df.drop(columns=['date'], inplace=True)
        status_updates += "Merged stock data with sentiment scores.\n"
    else:
        merged_df = stock_df.copy()  # Keep stock data even if no sentiment
        merged_df['avg_sentiment_score'] = None  # Add column with None
        status_updates += "No sentiment data to merge.\n"

    # 6. Calculate Price Change and Lagged Sentiment for Correlation
    merged_df['price_pct_change'] = merged_df['Close'].pct_change()
    merged_df['sentiment_lagged'] = merged_df['avg_sentiment_score'].shift(1)

    # --- Generate Outputs ---

    # Plot
    plot_object = None
    if not merged_df.empty:
        fig, ax1 = plt.subplots(figsize=(12, 6))  # Adjusted size for Gradio

        color = 'tab:blue'
        ax1.set_xlabel('Date')
        ax1.set_ylabel('Stock Close Price', color=color)
        ax1.plot(merged_df['Date'], merged_df['Close'], color=color, label='Stock Price')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.tick_params(axis='x', rotation=45)
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax1.xaxis.set_major_locator(mdates.AutoDateLocator(maxticks=10))  # Auto ticks

        if 'avg_sentiment_score' in merged_df.columns and merged_df['avg_sentiment_score'].notna().any():
            ax2 = ax1.twinx()
            color = 'tab:red'
            ax2.set_ylabel('Average Sentiment Score', color=color)
            ax2.plot(merged_df['Date'], merged_df['avg_sentiment_score'], color=color, linestyle='--', marker='o', markersize=4, label='Avg Sentiment')
            ax2.tick_params(axis='y', labelcolor=color)
            ax2.axhline(0, color='grey', linestyle='--', linewidth=0.8)
            ax2.set_ylim(-1.1, 1.1)  # Fix sentiment axis range

            # Combine legends
            lines, labels = ax1.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax2.legend(lines + lines2, labels + labels2, loc='upper left')
        else:
            ax1.legend(loc='upper left')  # Only stock legend

        plt.title(f"{ticker_symbol} Stock Price vs. Average Daily News Sentiment")
        plt.grid(True, which='major', linestyle='--', linewidth='0.5', color='grey')
        fig.tight_layout()
        plot_object = fig  # Return the figure object for Gradio plot component
        status_updates += "Generated plot.\n"

    # Correlation & Insights Text
    insights_text = "## Analysis Results\n\n"
    correlation = None
    if 'sentiment_lagged' in merged_df.columns and merged_df['sentiment_lagged'].notna().any() and merged_df['price_pct_change'].notna().any():
        correlation_df = merged_df[['sentiment_lagged', 'price_pct_change']].dropna()
        if not correlation_df.empty and len(correlation_df) > 1:
            correlation = correlation_df['sentiment_lagged'].corr(correlation_df['price_pct_change'])
            insights_text += f"**Correlation (Lagged Sentiment vs Price Change):** {correlation:.4f}\n"
            insights_text += "_Measures correlation between the previous day's average sentiment and the current day's price percentage change._\n\n"
        else:
            insights_text += "Correlation: Not enough overlapping data points to calculate.\n\n"
    else:
        insights_text += "Correlation: Sentiment or price change data missing.\n\n"

    # Simple Insights
    insights_text += "**Potential Insights (Not Financial Advice):**\n"
    if 'avg_sentiment_score' in merged_df.columns and merged_df['avg_sentiment_score'].notna().any():
        avg_sentiment_overall = merged_df['avg_sentiment_score'].mean()
        insights_text += f"- Average Sentiment (Overall Period): {avg_sentiment_overall:.3f}\n"

        if correlation is not None and pd.notna(correlation):
            if correlation > 0.15:
                insights_text += "- Positive correlation detected. Higher sentiment yesterday tended to correlate with price increases today.\n"
            elif correlation < -0.15:
                insights_text += "- Negative correlation detected. Higher sentiment yesterday tended to correlate with price decreases today (or vice-versa).\n"
            else:
                insights_text += "- Weak correlation detected. Sentiment may not be a strong short-term driver for this period.\n"
    else:
        insights_text += "- No sentiment data available to generate insights.\n"

    insights_text += "\n**Disclaimer:** This analysis is automated and NOT financial advice. Many factors influence stock prices."
    status_updates += "Generated insights.\n"

    # Recent News DataFrame
    recent_news_df = pd.DataFrame()
    if not news_df.empty and 'publishedAt' in news_df.columns:
        # Select and format columns for display
        cols_to_show = ['publishedAt', 'title', 'sentiment_label', 'sentiment_score']
        # Ensure all columns exist before selecting
        cols_exist = [col for col in cols_to_show if col in news_df.columns]
        if cols_exist:
            recent_news_df = news_df.sort_values(by='publishedAt', ascending=False)[cols_exist].head(10)
            # Format date for display
            recent_news_df['publishedAt'] = recent_news_df['publishedAt'].dt.strftime('%Y-%m-%d %H:%M')
            status_updates += "Prepared recent news table.\n"

    return plot_object, insights_text, recent_news_df, status_updates, merged_df  # Return merged_df for potential download


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Stock Sentiment Analysis Dashboard")

    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(label="Stock Ticker", value="AAPL", placeholder="e.g., AAPL, GOOGL")
            # Use Textbox for dates, value should be string
            start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", value=(datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d'))
            end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", value=datetime.now().strftime('%Y-%m-%d'))
            analyze_button = gr.Button("Analyze", variant="primary")
            status_output = gr.Textbox(label="Analysis Status", lines=5, interactive=False)
            # Optional: Add download button for the merged data
            download_data = gr.File(label="Download Merged Data (CSV)")

        with gr.Column(scale=3):
            plot_output = gr.Plot(label="Stock Price vs. Sentiment")
            insights_output = gr.Markdown(label="Analysis & Insights")
            news_output = gr.DataFrame(label="Recent News Headlines", headers=['Date', 'Title', 'Sentiment', 'Score'], wrap=True)

    # Hidden state to store the merged dataframe for download
    merged_df_state = gr.State(None)

    def run_analysis_and_prepare_download(ticker, start_date_str, end_date_str):  # Use string names
        """Wrapper function to run analysis and prepare CSV for download."""
        # Parse dates inside the wrapper or ensure perform_analysis handles strings robustly
        try:
            start_date_obj = datetime.strptime(start_date_str, '%Y-%m-%d').date()
            end_date_obj = datetime.strptime(end_date_str, '%Y-%m-%d').date()
        except ValueError:
            # Handle invalid date format input from textbox
            return None, "Error: Invalid date format. Please use YYYY-MM-DD.", None, "Error processing dates.", None, None

        plot, insights, news, status, merged_df = perform_analysis(ticker, start_date_str, end_date_str)  # Pass strings

        csv_path = None
        if merged_df is not None and not merged_df.empty:
            # Save to a temporary CSV file for Gradio download
            csv_path = "temp_merged_data.csv"
            merged_df.to_csv(csv_path, index=False)

        return plot, insights, news, status, merged_df, csv_path  # Return path for download

    analyze_button.click(
        fn=run_analysis_and_prepare_download,
        inputs=[ticker_input, start_date_input, end_date_input],  # Inputs are now textboxes
        outputs=[plot_output, insights_output, news_output, status_output, merged_df_state, download_data]  # Update state and file output
    )

# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
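Note: app.py imports get_stock_data, get_news_articles and load_api_keys from src/data_fetcher.py and analyze_sentiment from src/sentiment_analyzer.py, but none of those modules are touched by this commit. The sketch below is a hypothetical stand-in for the interfaces the app appears to expect: a DataFrame with Date/Close columns, a list of article dicts, and a (label, score, ...) tuple with the score roughly in [-1, 1]. It assumes yfinance for prices and NLTK's VADER for scoring; the real modules may well use the transformers dependency instead.

# Hypothetical stand-ins for the src/ modules imported by app.py.
# The real implementations are not part of this commit; the libraries, key
# names and return shapes below are assumptions inferred from how app.py
# uses these functions.
import os

import yfinance as yf                                    # assumed price source
from dotenv import load_dotenv
from nltk.sentiment import SentimentIntensityAnalyzer    # needs nltk.download('vader_lexicon')


def load_api_keys():
    """Return (news_api_key, other_key); app.py only uses the first value."""
    load_dotenv()
    return os.getenv("NEWS_API_KEY"), None


def get_stock_data(ticker, start_date, end_date):
    """Return a DataFrame with at least 'Date' and 'Close' columns, or None."""
    df = yf.Ticker(ticker).history(start=start_date, end=end_date)
    if df.empty:
        return None
    return df.reset_index()  # move the DatetimeIndex into a 'Date' column


def get_news_articles(ticker, start_date, end_date):
    """Return a list of dicts with 'publishedAt', 'title' and 'description' keys."""
    return []  # e.g. NewsAPI's /v2/everything filtered by ticker and date range


_vader = SentimentIntensityAnalyzer()


def analyze_sentiment(text):
    """Return (label, score, raw_score) with score in [-1, 1], matching app.py's unpacking."""
    compound = _vader.polarity_scores(text)["compound"]
    label = "positive" if compound > 0.05 else "negative" if compound < -0.05 else "neutral"
    return label, compound, compound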
notebooks/financial_sentiment_analysis.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -7,4 +7,5 @@ transformers
 scikit-learn
 matplotlib
 nltk
-python-dotenv
+python-dotenv
+gradio
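With gradio added to requirements.txt, the analysis pipeline can also be exercised without the UI by importing perform_analysis from the new app.py. A rough smoke test, assuming the dependencies are installed, a NEWS_API_KEY is present in .env, and the src modules import cleanly:

# Rough smoke test for the new app.py, run from the project root.
# Assumes `pip install -r requirements.txt` has been run, .env provides
# NEWS_API_KEY, and src/data_fetcher.py / src/sentiment_analyzer.py exist.
from datetime import datetime, timedelta

from app import perform_analysis

end = datetime.now().strftime("%Y-%m-%d")
start = (datetime.now() - timedelta(days=30)).strftime("%Y-%m-%d")

plot, insights, news, status, merged = perform_analysis("AAPL", start, end)
print(status or "no status returned")  # progress log assembled by perform_analysis
print(insights)                        # markdown summary, incl. lagged-sentiment correlation
if merged is not None:
    print(merged[["Date", "Close", "avg_sentiment_score"]].tail())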