Upload 9 files

- README.md (+19 -11)
- app.py (+282 -0)
- dashboard.py (+221 -0)
- llm.py (+220 -0)
- logger.py (+15 -0)
- model.py (+710 -0)
- requirements.txt (+46 -0)
- ui.py (+132 -0)
- visualizations.py (+360 -0)
README.md CHANGED
@@ -1,12 +1,20 @@
- emoji: 💻
- colorFrom: pink
- colorTo: pink
- sdk: streamlit
- sdk_version: 1.43.2
- app_file: app.py
- pinned: false
- ---
+ # Stock_Analytics
+
+ Stock Prediction and Analysis Script Overview: This script predicts stock prices using several machine learning and statistical models. It fetches historical stock data, processes it, and applies each model; the results, including forecasts and model coefficients, are saved to an Excel file for further analysis.
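As a rough sketch of that last step, a forecast table can be written to Excel with pandas; the file name, column names, and numbers below are purely illustrative placeholders, not project output:

import pandas as pd

# Hypothetical forecast table; in the project these values come from the prediction models
forecast = pd.DataFrame({
    "Day": [1, 2, 3],
    "Predicted_Close": [245.1, 246.0, 247.3],
})

# Writing to Excel requires a writer backend such as openpyxl
forecast.to_excel("stock_forecast.xlsx", index=False, sheet_name="Forecast")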
+ 1. Data Ingestion and Preprocessing
+ Data Source: Historical stock data is fetched with the yfinance library, which provides financial data directly from Yahoo Finance.
+ Preprocessing: The data is then cleaned and processed with pandas and numpy. This includes handling missing values, calculating moving averages, and other necessary transformations (see the sketch below).
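A minimal sketch of this step, assuming daily OHLCV data; the ticker and date range are simply the defaults used elsewhere in the app:

import pandas as pd
import yfinance as yf

# Download daily OHLCV data from Yahoo Finance
data = yf.download("BHEL.NS", start="2020-01-01", end="2024-09-04")

# Newer yfinance versions may return MultiIndex columns; flatten them if needed
if isinstance(data.columns, pd.MultiIndex):
    data.columns = data.columns.get_level_values(0)

# Basic cleaning: drop rows with missing values
data = data.dropna()

# 50-day simple moving average of the closing price
data["SMA_50"] = data["Close"].rolling(window=50).mean()

print(data[["Close", "SMA_50"]].tail())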
+ 2. Technical Analysis and Machine Learning
+ Technical Indicators: Using libraries such as pandas and numpy, the project calculates technical indicators such as moving averages, RSI (Relative Strength Index), and Bollinger Bands.
+ Feature Engineering: Features are created and selected for training the machine learning models; they may include technical indicators and other stock-related metrics.
+ Machine Learning Models: The scikit-learn library is used to build predictive models, such as Linear Regression, Random Forest, or other algorithms that forecast future stock prices (see the sketch below).
+ Risk Assessment: The project assesses the risk level associated with each stock, for example by analyzing volatility, technical indicators, or other metrics.
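A rough sketch of this pipeline, continuing from the preprocessed DataFrame above and using the ta library the way model.py in this commit does; the feature set and model settings here are illustrative, not the project's exact configuration:

import ta
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

# RSI as an additional feature (14-period default in ta)
data["RSI"] = ta.momentum.RSIIndicator(data["Close"]).rsi()
data = data.dropna()

# Simple feature set: open price, 50-day SMA, and RSI; target is the close price
X = data[["Open", "SMA_50", "RSI"]]
y = data["Close"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
print("MAE:", mean_absolute_error(y_test, model.predict(X_test)))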
+ 3. Visualization
+ Matplotlib and Seaborn: These libraries are used for static visualizations such as line plots of stock prices, candlestick charts, and bar plots of feature importance.
+ Plotly: An optional tool for creating interactive visualizations, especially within the Streamlit app.
+ Candlestick Charts: mplfinance is used to generate candlestick charts that visualize open, high, low, and close prices (see the sketch below).
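For example, a candlestick chart can be drawn with mplfinance, assuming an OHLCV DataFrame with a DatetimeIndex such as the one downloaded above; the 60-session window is arbitrary:

import mplfinance as mpf

# Candlestick chart of the most recent sessions, with a volume panel
mpf.plot(data.tail(60), type="candle", volume=True, title="BHEL.NS")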
+ 4. Web Application with Streamlit
+ User Interface: The entire analysis and visualization pipeline is wrapped in a web-based UI built with Streamlit. Users enter a stock ticker and get visualized results, including price predictions, technical indicators, and risk levels.
+ Custom Styling: The Streamlit app is styled to user preferences, including setting backgrounds, coloring text and numeric values by risk level, and displaying buy signals.
+ Tabs and Layout: Multiple tabs or sections separate the different kinds of visualizations, such as technical indicators, feature importance, and future predictions (a minimal layout sketch follows).
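A minimal sketch of that layout; the inputs and tab names mirror app.py in this commit, while the tab bodies are placeholders:

import streamlit as st
import pandas as pd

st.title("Stock Analysis and Prediction")

# User inputs, as in app.py
ticker = st.text_input("Stock Ticker", "BHEL.NS")
start_date = st.date_input("Start Date", pd.to_datetime("2020-01-01"))
end_date = st.date_input("End Date", pd.to_datetime("2024-09-04"))

# Separate tabs for analysis and visualization
tab1, tab2 = st.tabs(["Analyze", "Visualization"])

with tab1:
    st.write(f"Analysis for {ticker} from {start_date} to {end_date} goes here.")

with tab2:
    st.write("Charts (technical indicators, feature importance, predictions) go here.")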
+ 5. LLM Integration
app.py ADDED
@@ -0,0 +1,282 @@

import streamlit as st
import pandas as pd
import numpy as np
import logging

from model import fetch_data, calculate_indicators, calculate_support_resistance, predict_future_prices
from visualizations import (
    plot_stock_price, plot_predictions, plot_technical_indicators, plot_risk_levels,
    plot_feature_importance, plot_candlestick, plot_volume, plot_moving_averages,
    plot_feature_correlations
)
from sklearn.metrics import ConfusionMatrixDisplay
from ui import display_analysis
from logger import get_logger

from dashboard import display_dashboard, display_profile, fetch_stock_profile, display_quarterly_results, display_shareholding_pattern, display_financial_ratios
from llm import display_recommendation  # analyze_stock_with_llm

# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument('--token', required=True)
# args = parser.parse_args()
# API_TOKEN = args.token

logger = get_logger(__name__)

st.title("Stock Analysis and Prediction")

# Sidebar for navigation
st.sidebar.title("Navigation")

# Initialize `page` to "Analytics" by default
if 'page' not in st.session_state:
    st.session_state['page'] = "Analytics"

if st.sidebar.button("Analytics"):
    st.session_state['page'] = "Analytics"
if st.sidebar.button("Ask to AI"):
    st.session_state['page'] = "Ask to AI"
if st.sidebar.button("Dashboard"):
    st.session_state['page'] = "Dashboard"
if st.sidebar.button("Profile"):
    st.session_state['page'] = "Profile"
page = st.session_state['page']

# Function to fetch and prepare data
def get_data():
    ticker = st.session_state.get('ticker')
    start_date = st.session_state.get('start_date')
    end_date = st.session_state.get('end_date')

    try:
        data = fetch_data(ticker, start_date, end_date)
        if data is not None:
            data = calculate_indicators(data)
            return data
        else:
            st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
            return None
    except Exception as e:
        st.error(f"An error occurred: {e}")
        return None

# Display content based on selected page
if page == "Analytics":
    st.header("Analytics")

    # Data input section
    ticker = st.text_input("Stock Ticker", "BHEL.NS")
    start_date = st.date_input("Start Date", pd.to_datetime("2020-01-01"))
    end_date = st.date_input("End Date", pd.to_datetime("2024-09-04"))
    algorithm = st.selectbox(
        "Choose an Algorithm",
        ['Linear Regression', 'LSTM', 'ARIMA', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost', 'SARIMA']
    )
    st.session_state['ticker'] = ticker
    st.session_state['start_date'] = start_date
    st.session_state['end_date'] = end_date
    st.session_state['algorithm'] = algorithm

    # Tabs for Analyze and Visualization under Analytics
    tab1, tab2 = st.tabs(["Analyze", "Visualization"])

    # Analyze Tab
    with tab1:
        if st.button("Analyze"):
            data = get_data()
            if data is not None:
                display_analysis(data, st.session_state.get('algorithm'))

    # Visualization Tab
    with tab2:
        st.write("### Visualizations")

        # Fetch and prepare data for visualization
        data = get_data()
        if data is not None:
            indicators = {
                'SMA_50': data['SMA_50'],
                'EMA_50': data['EMA_50'],
                'RSI': data['RSI'],
                'MACD': data['MACD'],
                'MACD_Signal': data['MACD_Signal'],
                'Bollinger_High': data['Bollinger_High'],
                'Bollinger_Low': data['Bollinger_Low'],
                'ATR': data['ATR'],
                'OBV': data['OBV']
            }

            # Visualization choices
            choice = st.selectbox(
                "Choose a type of visualization",
                [
                    "Stock Price", "Volume",
                    "Moving Averages",
                    "Feature Correlations",
                    "Predictions vs Actual",
                    "Technical Indicators",
                    "Risk Levels",
                    "Feature Importance",
                    "Candlestick"
                ]
            )

            try:
                if choice == "Stock Price":
                    plot_stock_price(data, st.session_state.get('ticker'), indicators)
                elif choice == "Predictions vs Actual":
                    future_prices, _, _, _, _ = predict_future_prices(data, st.session_state.get('algorithm'))
                    if future_prices is not None:
                        # Place the forecast after the historical series so both fit in one DataFrame
                        actual = data['Close'].reset_index(drop=True)
                        predicted = pd.Series(np.asarray(future_prices), index=range(len(actual), len(actual) + len(future_prices)))
                        st.line_chart(pd.DataFrame({'Actual Prices': actual, 'Predicted Prices': predicted}))
                    else:
                        st.error("Failed to fetch predictions.")
                        logger.error("Failed to fetch predictions.")
                elif choice == "Technical Indicators":
                    plot_technical_indicators(data, indicators)
                elif choice == "Risk Levels":
                    plot_risk_levels(data)
                elif choice == "Feature Importance":
                    plot_feature_importance()
                elif choice == "Candlestick":
                    plot_candlestick(data)
                elif choice == "Volume":
                    plot_volume(data)
                elif choice == "Moving Averages":
                    plot_moving_averages(data)
                elif choice == "Feature Correlations":
                    plot_feature_correlations(data)
            except Exception as e:
                logger.error(f"An error occurred during visualization: {e}")
                st.error(f"An error occurred during visualization: {e}")
        else:
            st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
            logger.error("Failed to fetch data. Please check the stock ticker symbol and date range.")

elif page == "Dashboard":
    st.title("Stock analysis and screening tool for investors in India")

    ticker = st.text_input("Enter stock ticker (e.g., TATAMOTORS.NS):").upper()
    days = st.sidebar.slider("Select number of days for top movers:", 1, 30, 30)

    profile = {}
    if ticker:
        profile = fetch_stock_profile(ticker)
        if profile:  # Only display profile if it's not empty
            display_profile(profile)
            display_quarterly_results(ticker)
            display_shareholding_pattern(ticker)
            display_financial_ratios(ticker)
        else:
            st.write("No data available for the ticker entered.")

    st.sidebar.write("### Overview")
    st.sidebar.write(f"Showing top gainers and losers over the past {days} day(s).")

    display_dashboard()

    # display_profile()
    # # Display the main dashboard
    # display_dashboard()

    st.write("<div style='background-color: black; color: white; padding: 10px;'>More updates coming soon...</div>", unsafe_allow_html=True)

elif page == "Profile":
    st.image("https://via.placeholder.com/150", caption="User Profile Photo")
    st.write("### User Profile")
    st.write("Name: Nandan Dutta")
    st.write("Role: Data Analyst")
    st.write("Email: [email protected]")

elif page == "Ask to AI":
    st.title("Ask Stock Recommendation to AI")
    st.write("Model: Meta LLaMA 3.1")

    # Input fields for the user
    ticker = st.text_input("Enter Stock Ticker (e.g., BHEL.NS, RELIANCE.NS):")
    start_date = st.date_input("Start Date", value=None)
    end_date = st.date_input("End Date", value=None)

    if st.button("Get Recommendation"):
        if ticker and start_date and end_date:
            # Ensure dates are in the correct format
            start_date_str = start_date.strftime('%Y-%m-%d')
            end_date_str = end_date.strftime('%Y-%m-%d')

            st.write(f"Fetching recommendation for {ticker} from {start_date_str} to {end_date_str}...")

            try:
                # Fetch the recommendation using the LLaMA model
                recommendations = display_recommendation(ticker, start_date_str, end_date_str)
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
            st.error("Please enter a valid ticker and date range.")

# Footer (module-level, so it renders on every page)
st.markdown(
    """
    <style>
    @keyframes blink {
        0% { opacity: 1; }
        50% { opacity: 0; }
        100% { opacity: 1; }
    }
    .blinking-heart {
        animation: blink 1s infinite;
    }
    </style>
    <div style='background-color: #f1f1f1; color: #333; padding: 5px; text-align: center; border-top: 1px solid #ddd;'>
        <p>Made with <span class="blinking-heart">❤️</span> from Nandan</p>
    </div>
    """,
    unsafe_allow_html=True
)

# Display animated running disclaimer text
st.write(
    """
    <div style='background-color: black; color: white; padding: 10px; border-radius: 5px;'>
        <marquee behavior="scroll" direction="left" scrollamount="5" style="font-size: 14px;">
            This project is for educational purposes only. The information provided here should not be used for real investment decisions. Please perform your own research and consult with a financial advisor before making any investment decisions. Use this information at your own risk.
        </marquee>
    </div>
    """,
    unsafe_allow_html=True
)
dashboard.py ADDED
@@ -0,0 +1,221 @@

import yfinance as yf
import pandas as pd
import streamlit as st
from datetime import datetime, timedelta

# Fetch Nifty 50 tickers
def fetch_nifty50_tickers():
    return [
        "TATAMOTORS.NS", "RELIANCE.NS", "INFY.NS", "HDFCBANK.NS", "ICICIBANK.NS",
        "SBIN.NS", "ITC.NS", "AXISBANK.NS", "MARUTI.NS", "TATASTEEL.NS",
        "WIPRO.NS", "SUNPHARMA.NS", "HINDALCO.NS", "HCLTECH.NS", "NTPC.NS",
        "L&T.NS", "M&M.NS", "ONGC.NS", "HDFCLIFE.NS", "ULTRACEMCO.NS",
        "ADANIGREEN.NS", "BHARTIARTL.NS", "BAJAJFINSV.NS", "JSWSTEEL.NS", "DIVISLAB.NS",
        "POWERGRID.NS", "KOTAKBANK.NS", "HINDUNILVR.NS", "TCS.NS", "CIPLA.NS",
        "ASIANPAINT.NS", "GRASIM.NS", "BRITANNIA.NS", "SHREECEM.NS",
        "TECHM.NS", "INDUSINDBK.NS", "EICHERMOT.NS", "COALINDIA.NS", "GAIL.NS",
        "BOSCHLTD.NS", "M&MFIN.NS", "IDFCFIRSTB.NS", "HAVELLS.NS"
    ]

# Fetch large cap tickers
def fetch_large_cap_tickers():
    return fetch_nifty50_tickers()  # Assuming large caps are the same as Nifty 50

# Fetch small cap tickers
def fetch_small_cap_tickers():
    return [
        "ALOKINDS.NS", "ADANIENT.NS", "AARTIIND.NS", "AVANTIFEED.NS", "BLS.IN",
        "BHEL.NS", "BIRLACORP.NS", "CARBORUNIV.NS", "CENTRALBANK.NS", "EMAMILTD.NS",
        "FDC.NS", "GLAXO.NS", "GODFRYPHLP.NS", "GSKCONS.NS", "HAVELLS.NS",
        "HEMIPAPER.NS", "HIL.NS", "JINDALSAW.NS", "JUBLFOOD.NS", "KOTAKMAH.NS",
        "MSTCLAS.NS", "NCC.NS", "PAGEIND.NS", "PIIND.NS", "SBI.CN",
        "SISL.NS", "SOMANYCERA.NS", "STAR.NS", "SUNDARAM.NS", "TATAINVEST.NS",
        "VSTIND.NS", "WABCOINDIA.NS", "WELCORP.NS", "ZEELEARN.NS", "ZOMATO.NS"
    ]

# Get top movers
def get_top_movers(tickers, days=1):
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days)

    data = {}
    for ticker in tickers:
        try:
            df = yf.download(ticker, start=start_date, end=end_date)
            if not df.empty and 'Close' in df.columns:
                df['Ticker'] = ticker
                data[ticker] = df['Close'].pct_change().iloc[-1]  # Percentage change
        except Exception as e:
            st.error(f"Error fetching data for {ticker}: {e}")

    sorted_data = sorted(data.items(), key=lambda x: x[1], reverse=True)
    top_gainers = sorted_data[:10]
    top_losers = sorted_data[-10:]

    return top_gainers, top_losers

# Format DataFrame with color
def format_df(df):
    if not df.empty:
        df['Percentage Change'] = pd.to_numeric(df['Percentage Change'], errors='coerce')
        return df.style.applymap(lambda x: 'color: green' if x > 0 else 'color: red', subset=['Percentage Change'])
    return df

# Display dashboard
def display_dashboard():
    st.header("Dashboard")

    # Fetch tickers
    nifty50_tickers = fetch_nifty50_tickers()
    large_cap_tickers = fetch_large_cap_tickers()
    small_cap_tickers = fetch_small_cap_tickers()

    # Get top gainers and losers
    top_gainers_nifty50, top_losers_nifty50 = get_top_movers(nifty50_tickers)
    top_gainers_large_cap, top_losers_large_cap = get_top_movers(large_cap_tickers)
    top_gainers_small_cap, top_losers_small_cap = get_top_movers(small_cap_tickers)

    # Create columns for tables
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.write("### Nifty 50 Top Gainers")
        if top_gainers_nifty50:
            df_gainers_nifty50 = pd.DataFrame(top_gainers_nifty50, columns=['Ticker', 'Percentage Change'])
            st.dataframe(format_df(df_gainers_nifty50))

    with col2:
        st.write("### Nifty 50 Top Losers")
        if top_losers_nifty50:
            df_losers_nifty50 = pd.DataFrame(top_losers_nifty50, columns=['Ticker', 'Percentage Change'])
            st.dataframe(format_df(df_losers_nifty50))

    with col3:
        st.write("### Large Cap Top Gainers")
        if top_gainers_large_cap:
            df_gainers_large_cap = pd.DataFrame(top_gainers_large_cap, columns=['Ticker', 'Percentage Change'])
            st.dataframe(format_df(df_gainers_large_cap))

    with col4:
        st.write("### Large Cap Top Losers")
        if top_losers_large_cap:
            df_losers_large_cap = pd.DataFrame(top_losers_large_cap, columns=['Ticker', 'Percentage Change'])
            st.dataframe(format_df(df_losers_large_cap))

# Fetch and display stock profile
def fetch_stock_profile(ticker):
    try:
        stock = yf.Ticker(ticker)
        info = stock.info

        profile = {
            "Name": info.get('shortName', 'N/A'),
            "Current Price": f"₹ {info.get('currentPrice', 'N/A')}",
            "Market Cap": f"₹ {info.get('marketCap', 'N/A') / 1e7:.2f} Cr.",
            "P/E Ratio": info.get('forwardEps', 'N/A'),
            "Book Value": info.get('bookValue', 'N/A'),
            "Dividend Yield": info.get('dividendYield', 'N/A'),
            "ROCE": info.get('returnOnCapitalEmployed', 'N/A'),
            "ROE": info.get('returnOnEquity', 'N/A'),
            "Face Value": info.get('faceValue', 'N/A')
        }
        return profile
    except Exception as e:
        st.error(f"Error fetching profile for {ticker}: {e}")
        return {}

# Display stock profile as a table
def display_profile(profile):
    st.subheader("Stock Profile")
    profile_df = pd.DataFrame([profile])
    st.table(profile_df)

# Fetch and display quarterly results
def display_quarterly_results(ticker):
    st.subheader("Quarterly Results Summary")
    try:
        stock = yf.Ticker(ticker)
        financials = stock.quarterly_financials.T
        if not financials.empty:
            results = {
                'Sales': financials['Total Revenue'].iloc[-1] if 'Total Revenue' in financials.columns else 'N/A',
                'Operating Profit Margin': financials['Operating Income'].iloc[-1] if 'Operating Income' in financials.columns else 'N/A',
                'Net Profit': financials['Net Income'].iloc[-1] if 'Net Income' in financials.columns else 'N/A'
            }
            results_df = pd.DataFrame([results])
            st.table(results_df)
        else:
            st.write("No quarterly results available.")
    except Exception as e:
        st.write(f"Error fetching quarterly results: {e}")

# Fetch and display shareholding pattern
def display_shareholding_pattern(ticker):
    st.subheader("Shareholding Pattern")

    # Placeholder values; replace with actual data source or API call
    data = {
        'Category': ['Promoters', 'FIIs (Foreign Institutional Investors)', 'DIIs (Domestic Institutional Investors)', 'Public'],
        'Holding (%)': [45.0, 20.0, 15.0, 20.0]
    }

    df = pd.DataFrame(data)
    st.table(df)


def display_financial_ratios(ticker):
    st.subheader("Financial Ratios")
    stock = yf.Ticker(ticker)

    try:
        # Placeholder values, calculate actual values based on your requirements
        ratios = {
            'Debtor Days': 73,
            'Working Capital Days': 194,
            'Cash Conversion Cycle': 51
        }
        ratios_df = pd.DataFrame([ratios])
        st.table(ratios_df)
    except Exception as e:
        st.write("Error fetching financial ratios:", e)


# Main application
# def main():
#     st.title("Stock Analysis Dashboard")
#
#     # Select ticker input
#     ticker = st.text_input("Enter Stock Ticker (e.g., TATAMOTORS.NS)")
#
#     if ticker:
#         profile = fetch_stock_profile(ticker)
#         if profile:
#             display_profile(profile)
#
#         display_quarterly_results(ticker)
#         display_shareholding_pattern(ticker)
#
#     # Show dashboard
#     if st.button("Show Dashboard"):
#         display_dashboard()

# if __name__ == "__main__":
#     main()
llm.py ADDED
@@ -0,0 +1,220 @@

import streamlit as st
import requests
from model import fetch_data, calculate_indicators, calculate_support_resistance
import os

# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument('--token', required=True)
# args = parser.parse_args()
# API_TOKEN = args.token

# Hugging Face API token and model URL.
# Read the token from the environment; never hard-code secrets in source control.
API_TOKEN = os.environ.get('HUGGING_FACE_TOKEN')
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct"

def generate_prompt(ticker, start_date, end_date):
    """Fetch data, calculate indicators, and prepare the prompt."""
    data = fetch_data(ticker, start_date, end_date)

    if data is None:
        return "No data available for the given ticker and date range.", None

    data = calculate_indicators(data)
    support, resistance = calculate_support_resistance(data)

    # Additional statistics
    highest_close = data['Close'].max()
    lowest_close = data['Close'].min()
    average_close = data['Close'].mean()
    average_volume = data['Volume'].mean()
    highest_volume = data['Volume'].max()
    lowest_volume = data['Volume'].min()
    daily_returns = data['Close'].pct_change().dropna()
    volatility = daily_returns.std()

    recent_trend = "uptrend" if data['Close'].iloc[-1] > data['Close'].iloc[0] else "downtrend" if data['Close'].iloc[-1] < data['Close'].iloc[0] else "sideways"

    # Summarize the key statistics
    summary = {
        'latest_close': data['Close'].iloc[-1],
        'SMA_50': data['SMA_50'].iloc[-1],
        'EMA_50': data['EMA_50'].iloc[-1],
        'RSI': data['RSI'].iloc[-1],
        'MACD': data['MACD'].iloc[-1],
        'MACD_Signal': data['MACD_Signal'].iloc[-1],
        'Bollinger_High': data['Bollinger_High'].iloc[-1],
        'Bollinger_Low': data['Bollinger_Low'].iloc[-1],
        'ATR': data['ATR'].iloc[-1],
        'OBV': data['OBV'].iloc[-1],
        'Support': support,
        'Resistance': resistance,
        'Highest_Close': highest_close,
        'Lowest_Close': lowest_close,
        'Average_Close': average_close,
        'Average_Volume': average_volume,
        'Highest_Volume': highest_volume,
        'Lowest_Volume': lowest_volume,
        'Volatility': volatility,
        'Recent_Trend': recent_trend,
        'Percentage_Change': (data['Close'].iloc[-1] - data['Close'].iloc[-2]) / data['Close'].iloc[-2] * 100 if len(data) > 1 else 0
    }

    prompt = f"""
    Analyze the following stock data for {ticker} and provide a buy/sell recommendation:
    Latest Close Price: {summary['latest_close']}
    SMA 50: {summary['SMA_50']}
    EMA 50: {summary['EMA_50']}
    RSI: {summary['RSI']}
    MACD: {summary['MACD']}
    MACD Signal: {summary['MACD_Signal']}
    Bollinger Bands High: {summary['Bollinger_High']}
    Bollinger Bands Low: {summary['Bollinger_Low']}
    ATR: {summary['ATR']}
    OBV: {summary['OBV']}
    Support Level: {summary['Support']}
    Resistance Level: {summary['Resistance']}
    Highest Close Price: {summary['Highest_Close']}
    Lowest Close Price: {summary['Lowest_Close']}
    Average Close Price: {summary['Average_Close']}
    Average Volume: {summary['Average_Volume']}
    Highest Volume: {summary['Highest_Volume']}
    Lowest Volume: {summary['Lowest_Volume']}
    Volatility: {summary['Volatility']}
    Recent Trend: {summary['Recent_Trend']}
    Percentage Change: {summary['Percentage_Change']}%
    """

    return prompt, summary

def get_recommendation(prompt):
    """Get stock recommendation from Hugging Face API."""
    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json"
    }
    payload = {"inputs": prompt}

    response = requests.post(API_URL, headers=headers, json=payload)
    response.raise_for_status()  # Raise an error for HTTP issues
    result = response.json()

    return result[0]['generated_text'].strip()

def display_recommendation(ticker, start_date, end_date):
    """Fetch data, generate prompt, get recommendation, and display it in a nice format."""
    prompt, summary = generate_prompt(ticker, start_date, end_date)

    if summary is None:
        st.error(prompt)
        return

    try:
        recommendation = get_recommendation(prompt)
    except Exception as e:
        st.error(f"An error occurred while getting recommendation: {e}")
        return

    # Display in a box/table format using Streamlit
    st.markdown(f"### Stock Analysis & Recommendation for {ticker}")

    st.markdown(f"""
    <div style='border:2px solid #4CAF50; padding: 15px; border-radius: 10px;'>
        <table style='width:100%; border-collapse: collapse;'>
            <tr><th style='text-align: left;'>Indicator</th><th style='text-align: left;'>Value</th></tr>
            <tr><td>Latest Close Price</td><td>{summary['latest_close']}</td></tr>
            <tr><td>SMA 50</td><td>{summary['SMA_50']}</td></tr>
            <tr><td>EMA 50</td><td>{summary['EMA_50']}</td></tr>
            <tr><td>RSI</td><td>{summary['RSI']}</td></tr>
            <tr><td>MACD</td><td>{summary['MACD']}</td></tr>
            <tr><td>MACD Signal</td><td>{summary['MACD_Signal']}</td></tr>
            <tr><td>Bollinger Bands High</td><td>{summary['Bollinger_High']}</td></tr>
            <tr><td>Bollinger Bands Low</td><td>{summary['Bollinger_Low']}</td></tr>
            <tr><td>ATR</td><td>{summary['ATR']}</td></tr>
            <tr><td>OBV</td><td>{summary['OBV']}</td></tr>
            <tr><td>Support Level</td><td>{summary['Support']}</td></tr>
            <tr><td>Resistance Level</td><td>{summary['Resistance']}</td></tr>
            <tr><td>Highest Close Price</td><td>{summary['Highest_Close']}</td></tr>
            <tr><td>Lowest Close Price</td><td>{summary['Lowest_Close']}</td></tr>
            <tr><td>Average Close Price</td><td>{summary['Average_Close']}</td></tr>
            <tr><td>Average Volume</td><td>{summary['Average_Volume']}</td></tr>
            <tr><td>Highest Volume</td><td>{summary['Highest_Volume']}</td></tr>
            <tr><td>Lowest Volume</td><td>{summary['Lowest_Volume']}</td></tr>
            <tr><td>Volatility</td><td>{summary['Volatility']}</td></tr>
            <tr><td>Recent Trend</td><td>{summary['Recent_Trend']}</td></tr>
            <tr><td>Percentage Change</td><td>{summary['Percentage_Change']}%</td></tr>
        </table>
    </div>
    """, unsafe_allow_html=True)

    st.write('AI Recommendation:')
    st.write(recommendation)
logger.py ADDED
@@ -0,0 +1,15 @@

# logger.py
import logging

def get_logger(name):
    """
    Create and return a logger instance with a specified name.
    """
    logger = logging.getLogger(name)
    if not logger.hasHandlers():
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
model.py ADDED
@@ -0,0 +1,710 @@

import yfinance as yf
import pandas as pd
import ta
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score
import xgboost as xgb
from catboost import CatBoostRegressor
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.preprocessing import MinMaxScaler
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX

from logger import get_logger

logger = get_logger(__name__)
# logger.setLevel(logging.DEBUG)
# handler = logging.StreamHandler()
# handler.setLevel(logging.DEBUG)
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# handler.setFormatter(formatter)
# logger.addHandler(handler)

# # Example usage of logger
# logger.info("This is an info message")

# Fetch historical data
def fetch_data(ticker, start_date, end_date):
    logger.info(f"Fetching data for {ticker} from {start_date} to {end_date}")
    data = yf.download(ticker, start=start_date, end=end_date)
    if data.empty:
        logger.warning(f"No data returned for {ticker}.")
        return None

    # Reset index to ensure Date is a column
    data.reset_index(inplace=True)
    logger.info(f"Data fetched successfully for {ticker}.")
    return data

def calculate_indicators(data: pd.DataFrame) -> pd.DataFrame:
    logger.info("Calculating indicators with fixed parameters.")

    # Check if required columns are present
    required_columns = ['Close', 'High', 'Low', 'Volume']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    # Calculate fixed moving averages
    ma_period = 50  # Fixed period for moving averages
    try:
        data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
        data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
    except Exception as e:
        logger.error(f"Error calculating moving averages: {e}")
        raise

    # Calculate other indicators
    try:
        data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
        macd = ta.trend.MACD(data['Close'])
        data['MACD'] = macd.macd()
        data['MACD_Signal'] = macd.macd_signal()
        bollinger = ta.volatility.BollingerBands(data['Close'])
        data['Bollinger_High'] = bollinger.bollinger_hband()
        data['Bollinger_Low'] = bollinger.bollinger_lband()
        data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
        data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
    except Exception as e:
        logger.error(f"Error calculating other indicators: {e}")
        raise

    # Debugging line to check the columns
    logger.debug("Columns after calculating indicators: %s", data.columns)

    data = data.dropna()
    logger.info("Indicators calculated successfully.")
    return data


# def calculate_indicators(data: pd.DataFrame, ma_type='SMA', ma_period=50) -> pd.DataFrame:
#     logger.info(f"Calculating indicators with {ma_type} of period {ma_period}.")

#     # Check if required columns are present
#     required_columns = ['Close', 'High', 'Low', 'Volume']
#     missing_columns = [col for col in required_columns if col not in data.columns]
#     if missing_columns:
#         logger.error(f"Missing columns in data: {', '.join(missing_columns)}")
#         raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

#     # Calculate moving averages
#     if ma_type == 'SMA':
#         data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
#     elif ma_type == 'EMA':
#         data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
#     else:
#         logger.error(f"Unknown moving average type: {ma_type}")
#         raise ValueError(f"Unknown moving average type: {ma_type}")

#     # Calculate other indicators
#     try:
#         data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
#         macd = ta.trend.MACD(data['Close'])
#         data['MACD'] = macd.macd()
#         data['MACD_Signal'] = macd.macd_signal()
#         bollinger = ta.volatility.BollingerBands(data['Close'])
#         data['Bollinger_High'] = bollinger.bollinger_hband()
#         data['Bollinger_Low'] = bollinger.bollinger_lband()
#         data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
#         data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
#     except Exception as e:
#         logger.error(f"Error calculating indicators: {e}")
#         raise

#     # Debugging line to check the columns
#     # (These trailing lines belong to the commented-out variant above and are kept
#     # commented so the module stays importable.)
#     logger.debug("Columns after calculating indicators: %s", data.columns)
#
#     data = data.dropna()
#     logger.info("Indicators calculated successfully.")
#     return data


# # Calculate technical indicators
# def calculate_indicators(data, ma_type='SMA', ma_period=50):
#     logger.info(f"Calculating indicators with {ma_type} of period {ma_period}.")

#     if ma_type == 'SMA':
#         data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
#     elif ma_type == 'EMA':
#         data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()

#     data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
#     macd = ta.trend.MACD(data['Close'])
#     data['MACD'] = macd.macd()
#     data['MACD_Signal'] = macd.macd_signal()
#     bollinger = ta.volatility.BollingerBands(data['Close'])
#     data['Bollinger_High'] = bollinger.bollinger_hband()
#     data['Bollinger_Low'] = bollinger.bollinger_lband()
#     data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
#     data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()

#     # Debugging line to check the columns
#     logger.debug("Columns after calculating indicators: %s", data.columns)

#     data = data.dropna()
#     logger.info("Indicators calculated successfully.")
#     return data

# Calculate support and resistance levels
def calculate_support_resistance(data, window=30):
    logger.info(f"Calculating support and resistance with a window of {window}.")

    recent_data = data.tail(window)
    rolling_max = data['Close'].rolling(window=window).max()
    rolling_min = data['Close'].rolling(window=window).min()
    recent_max = recent_data['Close'].max()
    recent_min = recent_data['Close'].min()

    support = min(rolling_min.iloc[-1], recent_min)
    resistance = max(rolling_max.iloc[-1], recent_max)

    logger.debug("Support: %f, Resistance: %f", support, resistance)
    return support, resistance

# Prepare data for LSTM model
def prepare_lstm_data(data):
    logger.info("Preparing data for LSTM model.")

    features = data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']].values
    target = data['Close'].values
    scaler = MinMaxScaler()
    features = scaler.fit_transform(features)

    X, y = [], []
    for i in range(len(features) - 60):
        X.append(features[i:i+60])
        y.append(target[i+60])

    logger.info("Data preparation for LSTM completed.")
    return np.array(X), np.array(y)


def predict_future_prices(data, algorithm, days=10):
    logger.info(f"Predicting future prices using {algorithm}.")

    # Check if required columns are present
    required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
    missing_columns = [col for col in required_columns if col not in data.columns]

    if missing_columns:
        logger.error("Missing columns in data: %s", ', '.join(missing_columns))
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    features = data[required_columns]
    target = data['Close']

    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)

    mae, r2 = None, None  # Initialize variables for metrics

    if algorithm == 'Linear Regression':
        model = LinearRegression()

    elif algorithm == 'Decision Tree':
        model = DecisionTreeRegressor()

    elif algorithm == 'Random Forest':
        model = RandomForestRegressor(n_estimators=100)

    elif algorithm == 'XGBoost':
        model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')

    elif algorithm == 'CatBoost':
        model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)

    elif algorithm == 'LSTM':
        X, y = prepare_lstm_data(data)
        model = Sequential()
        model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
        model.add(LSTM(50))
        model.add(Dense(1))
        model.compile(optimizer='adam', loss='mean_squared_error')
        model.fit(X, y, epochs=10, batch_size=32, verbose=0)
        last_data_point = np.expand_dims(X[-1], axis=0)
        future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]
        logger.info("Future prices predicted using LSTM model.")
        return future_prices, None, None, None, None

    elif algorithm == 'ARIMA':
        model = ARIMA(data['Close'], order=(5, 1, 0))
        model_fit = model.fit()
        future_prices = model_fit.forecast(steps=days)

    elif algorithm == 'SARIMA':
        model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
        model_fit = model.fit()
        future_prices = model_fit.forecast(steps=days)

    else:
        logger.error("Algorithm not recognized: %s", algorithm)
        return None, None, None, None, None

    if algorithm in ['Linear Regression', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost']:
        model.fit(X_train, y_train)
        predictions = model.predict(X_test)
        mae = mean_absolute_error(y_test, predictions)
        r2 = r2_score(y_test, predictions)

        future_prices = []
        last_data_point = features.iloc[-1].values.reshape(1, -1)  # Ensure it's 2D

        for _ in range(days):
            future_price = model.predict(last_data_point)[0]
            future_prices.append(future_price)
            last_data_point = last_data_point + 1  # Update last data point (simplified, better methods should be used)

    logger.info("Future prices predicted using %s model.", algorithm)
    return future_prices, mae, r2, None, None


# def predict_future_prices(data, algorithm, days=10):
#     logger.info(f"Predicting future prices using {algorithm}.")

#     # Check if required columns are present
#     required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
#     missing_columns = [col for col in required_columns if col not in data.columns]

#     if missing_columns:
#         logger.error("Missing columns in data: %s", ', '.join(missing_columns))
#         raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

#     features = data[required_columns]
#     target = data['Close']

#     X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)

#     if algorithm == 'Linear Regression':
#         model = LinearRegression()

#     elif algorithm == 'Decision Tree':
#         model = DecisionTreeRegressor()

#     elif algorithm == 'Random Forest':
#         model = RandomForestRegressor(n_estimators=100)

#     elif algorithm == 'XGBoost':
#         model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')

#     elif algorithm == 'CatBoost':
#         model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)

#     elif algorithm == 'LSTM':
#         X, y = prepare_lstm_data(data)
#         model = Sequential()
#         model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
#         model.add(LSTM(50))
#         model.add(Dense(1))
#         model.compile(optimizer='adam', loss='mean_squared_error')
#         model.fit(X, y, epochs=10, batch_size=32, verbose=0)
#         last_data_point = np.expand_dims(X[-1], axis=0)
#         future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]
#         logger.info("Future prices predicted using LSTM model.")
#         return future_prices, None, None, None, None

#     elif algorithm == 'ARIMA':
#         model = ARIMA(data['Close'], order=(5, 1, 0))
#         model_fit = model.fit()
#         future_prices = model_fit.forecast(steps=days)

#     elif algorithm == 'SARIMA':
#         model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
#         model_fit = model.fit()
#         future_prices = model_fit.forecast(steps=days)

#     else:
#         logger.error("Algorithm not recognized: %s", algorithm)
#         return None, None, None, None, None

#     if algorithm in ['Linear Regression', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost']:
#         model.fit(X_train, y_train)
#         predictions = model.predict(X_test)
#         mae = mean_absolute_error(y_test, predictions)
#         r2 = r2_score(y_test, predictions)

#     future_prices = []
#     last_data_point = features.iloc[-1].values.reshape(1, -1)  # Ensure it's 2D

#     for _ in range(days):
#         future_price = model.predict(last_data_point)[0]
#         future_prices.append(future_price)
#         last_data_point = last_data_point + 1  # Update last data point (simplified, better methods should be used)

#     logger.info("Future prices predicted using %s model.", algorithm)
#     return future_prices, mae, r2, None, None


# # Predict future prices using the selected algorithm
# def predict_future_prices(data, algorithm, days=10):
#     logger.info(f"Predicting future prices using {algorithm}.")

#     # Check if required columns are present
#     required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
#     missing_columns = [col for col in required_columns if col not in data.columns]

#     if missing_columns:
#         logger.error("Missing columns in data: %s", ', '.join(missing_columns))
#         raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

#     features = data[required_columns]
#     target = data['Close']

#     X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)

#     if algorithm == 'Linear Regression':
#         model = LinearRegression()
#     elif algorithm == 'Decision Tree':
#         model = DecisionTreeRegressor()
#     elif algorithm == 'Random Forest':
#         model = RandomForestRegressor(n_estimators=100)
#     elif algorithm == 'XGBoost':
#         model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')
#     elif algorithm == 'CatBoost':
#         model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)
#     elif algorithm == 'LSTM':

#         X, y = prepare_lstm_data(data)
#         model = Sequential()
#         model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
#         model.add(LSTM(50))
#         model.add(Dense(1))
#         model.compile(optimizer='adam', loss='mean_squared_error')
#         model.fit(X, y, epochs=10, batch_size=32, verbose=0)
#         last_data_point = np.expand_dims(X[-1], axis=0)
#         future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]

#     elif algorithm == 'ARIMA':
#         model = ARIMA(data['Close'], order=(5, 1, 0))
#         model_fit = model.fit()
#         future_prices = model_fit.forecast(steps=10)

#     elif algorithm == 'SARIMA':
#         model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
#         model_fit = model.fit()
#         forecast = model_fit.forecast(steps=10)

#         logger.info("Future prices predicted using LSTM model.")
#         return future_prices, None, None, None, None
#     else:
#         logger.error("Algorithm not recognized: %s", algorithm)
#         return None, None, None, None, None

#     model.fit(X_train, y_train)

#     predictions = model.predict(X_test)
#     mae = mean_absolute_error(y_test, predictions)
#     r2 = r2_score(y_test, predictions)

#     future_prices = []
#     last_data_point = features.iloc[-1].values.reshape(1, -1)  # Ensure it's 2D

#     for _ in range(days):
#         future_price = model.predict(last_data_point)[0]
#         future_prices.append(future_price)
#         last_data_point = last_data_point + 1  # Update last data point (simplified, better methods should be used)

#     logger.info("Future prices predicted using %s model.", algorithm)
#     return future_prices, mae, r2, None, None

# import pandas as pd
# import numpy as np
# import yfinance as yf
# import ta
# from sklearn.model_selection import train_test_split
# from sklearn.linear_model import LinearRegression
# from sklearn.tree import DecisionTreeRegressor
# from sklearn.ensemble import RandomForestRegressor
# from sklearn.metrics import mean_absolute_error, r2_score
# import xgboost as xgb
# from catboost import CatBoostRegressor
# from tensorflow.keras.models import Sequential
# from tensorflow.keras.layers import LSTM, Dense
# from sklearn.preprocessing import MinMaxScaler
# from statsmodels.tsa.arima_model import ARIMA
# from statsmodels.tsa.statespace.sarimax import SARIMAX

# from logger import get_logger

# logger = get_logger(__name__)

# # Fetch historical data
# def fetch_data(ticker, start_date, end_date):
#     logger.info(f"Fetching data for {ticker} from {start_date} to {end_date}")
#     data = yf.download(ticker, start=start_date, end=end_date)
#     if data.empty:
#         logger.warning(f"No data returned for {ticker}.")
#         return None

#     # Reset index to ensure Date is a column
#     data.reset_index(inplace=True)
#     logger.info(f"Data fetched successfully for {ticker}.")
#     return data

# def calculate_indicators(data: pd.DataFrame) -> pd.DataFrame:
#     logger.info("Calculating indicators with fixed parameters.")

#     # Check if required columns are present
#     required_columns = ['Close', 'High', 'Low', 'Volume']
#     missing_columns = [col for col in required_columns if col not in data.columns]
#     if missing_columns:
#         logger.error(f"Missing columns in data: {', '.join(missing_columns)}")
#         raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

#     # Calculate fixed moving averages
#     ma_period = 50  # Fixed period for moving averages
#     try:
#         data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
#         data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
#     except Exception as e:
#         logger.error(f"Error calculating moving averages: {e}")
#         raise

#     # Calculate other indicators
#     try:
#         data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
#         macd = ta.trend.MACD(data['Close'])
#         data['MACD'] = macd.macd()
#         data['MACD_Signal'] = macd.macd_signal()
#         bollinger = ta.volatility.BollingerBands(data['Close'])
#         data['Bollinger_High'] = bollinger.bollinger_hband()
#         data['Bollinger_Low'] = bollinger.bollinger_lband()
#         data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
#         data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
#     except Exception as e:
#         logger.error(f"Error calculating other indicators: {e}")
#         raise

#     # Debugging line to check the columns
#     logger.debug("Columns after calculating indicators: %s", data.columns)

#     data = data.dropna()
#     logger.info("Indicators calculated successfully.")
#     return data

# # Calculate support and resistance levels
# def calculate_support_resistance(data, window=30):
#     logger.info(f"Calculating support and resistance with a window of {window}.")

#     recent_data = data.tail(window)
#     rolling_max = data['Close'].rolling(window=window).max()
#     rolling_min = data['Close'].rolling(window=window).min()
#     recent_max = recent_data['Close'].max()
#     recent_min = recent_data['Close'].min()

#     support = min(rolling_min.iloc[-1], recent_min
|
513 |
+
# resistance = max(rolling_max.iloc[-1], recent_max)
|
514 |
+
|
515 |
+
# logger.debug("Support: %f, Resistance: %f", support, resistance)
|
516 |
+
# return support, resistance
|
517 |
+
|
518 |
+
# # Prepare data for LSTM model
|
519 |
+
# def prepare_lstm_data(data):
|
520 |
+
# logger.info("Preparing data for LSTM model.")
|
521 |
+
|
522 |
+
# features = data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']].values
|
523 |
+
# target = data['Close'].values
|
524 |
+
# scaler = MinMaxScaler()
|
525 |
+
# features = scaler.fit_transform(features)
|
526 |
+
|
527 |
+
# X, y = [], []
|
528 |
+
# for i in range(len(features) - 60):
|
529 |
+
# X.append(features[i:i+60])
|
530 |
+
# y.append(target[i+60])
|
531 |
+
|
532 |
+
# logger.info("Data preparation for LSTM completed.")
|
533 |
+
# return np.array(X), np.array(y)
|
534 |
+
|
535 |
+
# # Predict future prices using the selected algorithm
|
536 |
+
# def predict_future_prices(data, algorithm, days=10):
|
537 |
+
# logger.info(f"Predicting future prices using {algorithm}.")
|
538 |
+
|
539 |
+
# # Check if required columns are present
|
540 |
+
# required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
|
541 |
+
# missing_columns = [col for col in required_columns if col not in data.columns]
|
542 |
+
|
543 |
+
# if missing_columns:
|
544 |
+
# logger.error("Missing columns in data: %s", ', '.join(missing_columns))
|
545 |
+
# raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
|
546 |
+
|
547 |
+
# features = data[required_columns]
|
548 |
+
# target = data['Close']
|
549 |
+
|
550 |
+
# X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)
|
551 |
+
|
552 |
+
# if algorithm == 'Linear Regression':
|
553 |
+
# model = LinearRegression()
|
554 |
+
# elif algorithm == 'Decision Tree':
|
555 |
+
# model = DecisionTreeRegressor()
|
556 |
+
# elif algorithm == 'Random Forest':
|
557 |
+
# model = RandomForestRegressor(n_estimators=100)
|
558 |
+
# elif algorithm == 'XGBoost':
|
559 |
+
# model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')
|
560 |
+
# elif algorithm == 'CatBoost':
|
561 |
+
# model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)
|
562 |
+
# elif algorithm == 'LSTM':
|
563 |
+
# X, y = prepare_lstm_data(data)
|
564 |
+
# model = Sequential()
|
565 |
+
# model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
|
566 |
+
# model.add(LSTM(50))
|
567 |
+
# model.add(Dense(1))
|
568 |
+
# model.compile(optimizer='adam', loss='mean_squared_error')
|
569 |
+
# model.fit(X, y, epochs=10, batch_size=32, verbose=0)
|
570 |
+
# last_data_point = np.expand_dims(X[-1], axis=0)
|
571 |
+
# future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]
|
572 |
+
|
573 |
+
# logger.info("Future prices predicted using LSTM model.")
|
574 |
+
# return future_prices, None, None, None, None
|
575 |
+
# elif algorithm == 'ARIMA':
|
576 |
+
# model = ARIMA(data['Close'], order=(5, 1, 0))
|
577 |
+
# model_fit = model.fit(disp=0)
|
578 |
+
# forecast = model_fit.forecast(steps=days)[0]
|
579 |
+
|
580 |
+
# mae = mean_absolute_error(target[-days:], forecast[:days])
|
581 |
+
# r2 = r2_score(target[-days:], forecast[:days])
|
582 |
+
|
583 |
+
# logger.info("Future prices predicted using ARIMA model.")
|
584 |
+
# return forecast.tolist(), mae, r2, None, None
|
585 |
+
# elif algorithm == 'SARIMA':
|
586 |
+
# model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
|
587 |
+
# model_fit = model.fit(disp=0)
|
588 |
+
# forecast = model_fit.forecast(steps=days)
|
589 |
+
|
590 |
+
# mae = mean_absolute_error(target[-days:], forecast[:days])
|
591 |
+
# r2 = r2_score(target[-days:], forecast[:days])
|
592 |
+
|
593 |
+
# logger.info("Future prices predicted using SARIMA model.")
|
594 |
+
# return forecast.tolist(), mae, r2, None, None
|
595 |
+
# else:
|
596 |
+
# logger.error("Algorithm not recognized: %s", algorithm)
|
597 |
+
# return None, None, None, None, None
|
598 |
+
|
599 |
+
# model.fit(X_train, y_train)
|
600 |
+
|
601 |
+
# predictions = model.predict(X_test)
|
602 |
+
# mae = mean_absolute_error(y_test, predictions)
|
603 |
+
# r2 = r2_score(y_test, predictions)
|
604 |
+
|
605 |
+
# future_prices = []
|
606 |
+
# last_data_point = features.iloc[-1].values.reshape(1, -1) # Ensure it's 2D
|
607 |
+
|
608 |
+
# for _ in range(days):
|
609 |
+
# future_price = model.predict(last_data_point)[0]
|
610 |
+
# future_prices.append(future_price)
|
611 |
+
# last_data_point = last_data_point + 1 # Update last data point (simplified, better methods should be used)
|
612 |
+
|
613 |
+
# logger.info("Future prices predicted using %s model.", algorithm)
|
614 |
+
# return future_prices, mae, r2, predictions, y_test
|
615 |
+
|
616 |
+
|
617 |
+
|
618 |
+
|
619 |
+
# # model.py
|
620 |
+
|
621 |
+
# import pandas as pd
|
622 |
+
# import numpy as np
|
623 |
+
# import yfinance as yf
|
624 |
+
# import statsmodels.api as sm
|
625 |
+
# from statsmodels.tsa.arima.model import ARIMA
|
626 |
+
# from statsmodels.tsa.statespace.sarimax import SARIMAX
|
627 |
+
# from sklearn.metrics import mean_absolute_error, r2_score
|
628 |
+
|
629 |
+
# def fetch_data(ticker, start_date, end_date):
|
630 |
+
# try:
|
631 |
+
# df = yf.download(ticker, start=start_date, end=end_date)
|
632 |
+
# return df
|
633 |
+
# except Exception as e:
|
634 |
+
# print(f"An error occurred while fetching data: {e}")
|
635 |
+
# return None
|
636 |
+
|
637 |
+
# def calculate_indicators(data):
|
638 |
+
# # Example indicators - these should be tailored to your requirements
|
639 |
+
# data['SMA_50'] = data['Close'].rolling(window=50).mean()
|
640 |
+
# data['EMA_50'] = data['Close'].ewm(span=50, adjust=False).mean()
|
641 |
+
# data['RSI'] = calculate_rsi(data['Close'])
|
642 |
+
# data['MACD'], data['MACD_Signal'] = calculate_macd(data['Close'])
|
643 |
+
# data['Bollinger_High'], data['Bollinger_Low'] = calculate_bollinger_bands(data['Close'])
|
644 |
+
# data['ATR'] = calculate_atr(data)
|
645 |
+
# data['OBV'] = calculate_obv(data)
|
646 |
+
# return data
|
647 |
+
|
648 |
+
# def calculate_rsi(series, period=14):
|
649 |
+
# delta = series.diff()
|
650 |
+
# gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
651 |
+
# loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
652 |
+
# rs = gain / loss
|
653 |
+
# return 100 - (100 / (1 + rs))
|
654 |
+
|
655 |
+
# def calculate_macd(series):
|
656 |
+
# macd = series.ewm(span=12, adjust=False).mean() - series.ewm(span=26, adjust=False).mean()
|
657 |
+
# macd_signal = macd.ewm(span=9, adjust=False).mean()
|
658 |
+
# return macd, macd_signal
|
659 |
+
|
660 |
+
# def calculate_bollinger_bands(series, window=20):
|
661 |
+
# rolling_mean = series.rolling(window=window).mean()
|
662 |
+
# rolling_std = series.rolling(window=window).std()
|
663 |
+
# high = rolling_mean + (rolling_std * 2)
|
664 |
+
# low = rolling_mean - (rolling_std * 2)
|
665 |
+
# return high, low
|
666 |
+
|
667 |
+
# def calculate_atr(data, window=14):
|
668 |
+
# high_low = data['High'] - data['Low']
|
669 |
+
# high_close = np.abs(data['High'] - data['Close'].shift())
|
670 |
+
# low_close = np.abs(data['Low'] - data['Close'].shift())
|
671 |
+
# tr = np.max(np.array([high_low, high_close, low_close]), axis=0)
|
672 |
+
# atr = tr.rolling(window=window).mean()
|
673 |
+
# return atr
|
674 |
+
|
675 |
+
# def calculate_obv(data):
|
676 |
+
# obv = (data['Volume'] * np.sign(data['Close'].diff())).fillna(0).cumsum()
|
677 |
+
# return obv
|
678 |
+
|
679 |
+
# def calculate_support_resistance(data):
|
680 |
+
# # Example calculation - you may need to refine this based on your requirements
|
681 |
+
# support = data['Close'].min()
|
682 |
+
# resistance = data['Close'].max()
|
683 |
+
# return support, resistance
|
684 |
+
|
685 |
+
# def predict_future_prices(data, model_type='ARIMA'):
|
686 |
+
# try:
|
687 |
+
# # Use ARIMA
|
688 |
+
# if model_type == 'ARIMA':
|
689 |
+
# model = ARIMA(data['Close'], order=(5, 1, 0))
|
690 |
+
# model_fit = model.fit()
|
691 |
+
# forecast = model_fit.forecast(steps=10)
|
692 |
+
# # Use SARIMA
|
693 |
+
# elif model_type == 'SARIMA':
|
694 |
+
# model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
|
695 |
+
# model_fit = model.fit()
|
696 |
+
# forecast = model_fit.forecast(steps=10)
|
697 |
+
# else:
|
698 |
+
# raise ValueError("Unsupported model type. Use 'ARIMA' or 'SARIMA'.")
|
699 |
+
|
700 |
+
# # Calculate MAE and R2 for evaluation
|
701 |
+
# y_true = data['Close'][-10:] # last 10 days as true values for comparison
|
702 |
+
# mae = mean_absolute_error(y_true, forecast[:len(y_true)])
|
703 |
+
# r2 = r2_score(y_true, forecast[:len(y_true)])
|
704 |
+
|
705 |
+
# # Return results
|
706 |
+
# return forecast, mae, r2
|
707 |
+
# except Exception as e:
|
708 |
+
# print(f"An error occurred while predicting future prices: {e}")
|
709 |
+
# return None, None, None
|
710 |
+
|
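The commented-out blocks above are retained older drafts; the live functions earlier in model.py are the ones ui.py imports. As a quick sanity check of that pipeline outside Streamlit, a minimal sketch (not part of the commit; the ticker and date range are illustrative, and it assumes the live definitions keep the signatures shown) could be:

# sketch: fetch data, derive indicators, and run one predictor end to end
from model import fetch_data, calculate_indicators, calculate_support_resistance, predict_future_prices

data = fetch_data('SBILIFE.NS', '2021-01-01', '2024-09-01')   # illustrative ticker and dates
if data is not None:
    data = calculate_indicators(data)                         # adds SMA_50, EMA_50, RSI, MACD, Bollinger bands, ATR, OBV
    support, resistance = calculate_support_resistance(data)
    future_prices, mae, r2, _, _ = predict_future_prices(data, 'Random Forest', days=10)
    print(f"support={support:.2f} resistance={resistance:.2f}")
    print(f"next 10 predicted closes: {future_prices} (MAE={mae:.2f}, R2={r2:.2f})")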
requirements.txt
ADDED
@@ -0,0 +1,46 @@
# yfinance==0.2.28
# pandas==2.0.3
# ta==0.9.0
# scikit-learn==1.3.0
# xgboost==2.0.0
# catboost==1.2.0
# numpy==1.24.2
# tensorflow==2.13.0
# streamlit==1.23.0
# mplfinance

matplotlib
pandas
seaborn
numpy
mplfinance
scipy
yfinance
scikit-learn
streamlit
plotly
ta
xgboost
catboost
tensorflow
statsmodels

# streamlit
# mplfinance

# matplotlib==3.7.1
# pandas==2.0.3
# seaborn==0.12.2
# numpy==1.25.2
# mplfinance==0.12.9b7
# scipy==1.11.1
# yfinance==0.2.30
# scikit-learn==1.3.0
# streamlit==1.25.0
# plotly==5.17.0
# pytorch==2.0.1

# transformers==4.30.2
# chroma==1.2.0
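The unpinned names in the middle of the file are the active dependency set; the pinned lines above and below are kept only as comments. A small, hedged sketch (not part of the commit) for checking that these resolve in the current environment:

# sketch: import each runtime dependency and report its version
import importlib

for pkg in ['pandas', 'numpy', 'scipy', 'matplotlib', 'seaborn', 'mplfinance',
            'yfinance', 'sklearn', 'streamlit', 'plotly', 'ta', 'xgboost',
            'catboost', 'tensorflow', 'statsmodels']:
    module = importlib.import_module(pkg)   # scikit-learn installs as the 'sklearn' module
    print(pkg, getattr(module, '__version__', 'unknown'))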
ui.py
ADDED
@@ -0,0 +1,132 @@
import streamlit as st
import pandas as pd
from sklearn.metrics import ConfusionMatrixDisplay
from model import fetch_data, calculate_indicators, calculate_support_resistance, predict_future_prices
from visualizations import (
    plot_stock_price, plot_predictions, plot_technical_indicators, plot_risk_levels,
    plot_feature_importance, plot_candlestick, plot_volume, plot_moving_averages,
    plot_feature_correlations
)

def sidebar():
    st.sidebar.title("Stock Analysis Dashboard")
    ticker = st.sidebar.text_input("Enter Stock Ticker Symbol:", value='SBILIFE.NS')
    start_date = st.sidebar.date_input("Start Date", value=pd.to_datetime('2021-01-01'))
    end_date = st.sidebar.date_input("End Date", value=pd.to_datetime('2024-09-01'))
    algorithm = st.sidebar.selectbox("Select Prediction Algorithm", ['Linear Regression', 'ARIMA', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost', 'LSTM', 'SARIMA'])
    return ticker, start_date, end_date, algorithm

def display_analysis(data, algorithm):
    if data is not None:
        try:
            support_price, resistance_price = calculate_support_resistance(data)
            future_prices, mae, r2, accuracy, conf_matrix = predict_future_prices(data, algorithm)

            if future_prices is not None:
                st.write("### Technical Indicators")
                indicators = {
                    'SMA_50': data['SMA_50'].iloc[-1],
                    'EMA_50': data['EMA_50'].iloc[-1],
                    'RSI': data['RSI'].iloc[-1],
                    'MACD': data['MACD'].iloc[-1],
                    'MACD_Signal': data['MACD_Signal'].iloc[-1],
                    'Bollinger_High': data['Bollinger_High'].iloc[-1],
                    'Bollinger_Low': data['Bollinger_Low'].iloc[-1],
                    'ATR': data['ATR'].iloc[-1],
                    'OBV': data['OBV'].iloc[-1]
                }
                for key, value in indicators.items():
                    with st.expander(f"{key} Description"):
                        st.write(f"{key}: {value:.2f}")
                        st.write(get_indicator_description(key))

                st.write("### Support and Resistance Levels")
                st.write(f"Support Price: {support_price:.2f}")
                st.write(f"Resistance Price: {resistance_price:.2f}")

                st.write("### Future Price Predictions")
                st.write(pd.DataFrame({'Day': range(1, len(future_prices) + 1), 'Predicted Price': future_prices}))

                if accuracy is not None and conf_matrix is not None:
                    st.write(f"**Model Accuracy:** {accuracy:.2f}")
                    st.write("**Confusion Matrix:**")
                    # Plot the confusion matrix and hand its figure to Streamlit
                    disp = ConfusionMatrixDisplay(conf_matrix)
                    disp.plot()
                    st.pyplot(disp.figure_)

                if mae is not None and r2 is not None:
                    st.write(f"**Mean Absolute Error (MAE):** {mae:.2f}")
                    st.write(f"**R-squared (R2):** {r2:.2f}")
            else:
                st.error("Model selection or prediction failed. Please check your inputs and try again.")
        except Exception as e:
            st.error(f"An error occurred during analysis: {e}")
    else:
        st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")

def get_indicator_description(indicator):
    descriptions = {
        'SMA_50': "SMA_50 (50-day Simple Moving Average): A 50-day average that shows the stock's long-term trend. When it sits above the price line, the trend is upward.",
        'EMA_50': "EMA_50 (50-day Exponential Moving Average): Also an average, but it gives more weight to recent prices and reflects the stock's short-term trend.",
        'RSI': "RSI (Relative Strength Index): Shows whether the stock is overbought or oversold. Above 70 is overbought, below 30 is oversold.",
        'MACD': "MACD: Shows the difference between the short-term and long-term moving averages.",
        'MACD_Signal': "MACD Signal: The signal line for the MACD. When the MACD line crosses it, the trend changes.",
        'Bollinger_High': "Bollinger High: The upper boundary for the stock price. If the price is above it, the stock may be overbought.",
        'Bollinger_Low': "Bollinger Low: The lower boundary for the stock price. If the price is below it, the stock may be oversold.",
        'ATR': "ATR (Average True Range): Measures the stock's volatility. A higher ATR means larger price fluctuations.",
        'OBV': "OBV (On-Balance Volume): Shows the relationship between volume and price. A rising OBV means demand for the stock is rising."
    }
    return descriptions.get(indicator, "Description not available")

def display_visualizations(data, algorithm):
    if data is not None:
        choice = st.sidebar.selectbox(
            "Choose a type of visualization",
            [
                "Stock Price",
                "Predictions vs Actual",
                "Technical Indicators",
                "Risk Levels",
                "Feature Importance",
                "Candlestick",
                "Volume",
                "Moving Averages",
                "Feature Correlations"
            ]
        )

        try:
            if choice == "Stock Price":
                plot_stock_price(data)
            elif choice == "Predictions vs Actual":
                future_prices, _, _, _, _ = predict_future_prices(data, algorithm)
                if future_prices is not None:
                    st.line_chart(pd.DataFrame({'Actual Prices': data['Close'], 'Predicted Prices': pd.Series(future_prices).values}))
                else:
                    st.error("Failed to fetch predictions.")
            elif choice == "Technical Indicators":
                indicators = {
                    'SMA_50': data['SMA_50'],
                    'EMA_50': data['EMA_50'],
                    'RSI': data['RSI'],
                    'MACD': data['MACD'],
                    'MACD_Signal': data['MACD_Signal'],
                    'Bollinger_High': data['Bollinger_High'],
                    'Bollinger_Low': data['Bollinger_Low'],
                    'ATR': data['ATR'],
                    'OBV': data['OBV']
                }
                plot_technical_indicators(data, indicators)
            elif choice == "Risk Levels":
                plot_risk_levels(data)
            elif choice == "Feature Importance":
                plot_feature_importance()
            elif choice == "Candlestick":
                plot_candlestick(data)
            elif choice == "Volume":
                plot_volume(data)
            elif choice == "Moving Averages":
                plot_moving_averages(data)
            elif choice == "Feature Correlations":
                plot_feature_correlations(data)
        except Exception as e:
            st.error(f"An error occurred during visualization: {e}")
    else:
        st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
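ui.py only defines the Streamlit building blocks; a minimal sketch of how an entry point might wire them together is shown below (illustrative only — the committed app.py is not reproduced here and may differ):

# sketch: assemble the sidebar, analysis section and visualization section
from model import fetch_data, calculate_indicators
from ui import sidebar, display_analysis, display_visualizations

ticker, start_date, end_date, algorithm = sidebar()
data = fetch_data(ticker, start_date, end_date)
if data is not None:
    data = calculate_indicators(data)   # ui.py expects the indicator columns to exist
display_analysis(data, algorithm)
display_visualizations(data, algorithm)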
visualizations.py
ADDED
@@ -0,0 +1,360 @@
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.dates as mdates
from mplfinance.original_flavor import candlestick_ohlc
import logging
import plotly.express as px
import streamlit as st

from model import predict_future_prices

from logger import get_logger

logger = get_logger(__name__)


def plot_stock_price(data: pd.DataFrame, ticker: str, indicators: dict = None,
                     color='blue', line_style='-', title=None):
    """
    Plot the stock price with optional indicators and customization.
    """
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_stock_price: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting stock price for {ticker}.")

    # Matplotlib Plot
    plt.figure(figsize=(14, 7))
    plt.plot(data['Date'], data['Close'], label='Close Price', color=color, linestyle=line_style)

    if indicators:
        for name, values in indicators.items():
            plt.plot(data['Date'], values, label=name)

    plt.title(title if title else f'{ticker} Stock Price')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.xticks(rotation=45)

    # Render the plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.line(data, x='Date', y='Close', title=title if title else f'{ticker} Stock Price')
    if indicators:
        for name, values in indicators.items():
            fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)

    # Render the interactive plot using Streamlit
    st.plotly_chart(fig)

def plot_predictions(data: pd.DataFrame, predictions: pd.Series, ticker: str,
                     actual_color='blue', predicted_color='red', line_style_actual='-', line_style_predicted='--'):
    """
    Plot actual vs predicted stock prices with customization.
    """
    logger.info(f"Plotting actual vs predicted prices for {ticker}.")

    # Matplotlib Plot
    plt.figure(figsize=(14, 7))
    plt.plot(data['Date'], data['Close'], label='Actual Prices', color=actual_color, linestyle=line_style_actual)
    plt.plot(data['Date'], predictions, label='Predicted Prices', color=predicted_color, linestyle=line_style_predicted)

    plt.title(f'{ticker} Actual vs Predicted Prices')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.xticks(rotation=45)

    # Render the plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.line(data, x='Date', y='Close', title=f'{ticker} Actual vs Predicted Prices')
    fig.add_scatter(x=data['Date'], y=predictions, mode='lines', name='Predicted Prices', line=dict(color=predicted_color))

    # Render the interactive plot using Streamlit
    st.plotly_chart(fig)

def generate_predictions(model, test_data):
    """
    Generate predictions using the model for the given test data.
    """
    try:
        # Extract relevant features for the model
        features = test_data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']]  # Adjust features based on your model
        predictions = model.predict(features)
        return predictions
    except KeyError as e:
        logger.error(f"Feature key error: {e}")
        st.error(f"Feature key error: {e}")
    except Exception as e:
        logger.error(f"An error occurred during prediction: {e}")
        st.error(f"An error occurred during prediction: {e}")


def plot_technical_indicators(data: pd.DataFrame, indicators: dict, model, days=10):
    """
    Plot technical indicators along with the stock price and predictions.
    """
    logger.info("Plotting stock price with technical indicators and predictions.")

    # Ensure all indicators have the same length as the data
    for name, values in indicators.items():
        if len(values) != len(data):
            logger.error(f"Indicator '{name}' length {len(values)} does not match data length {len(data)}.")
            st.error(f"Indicator '{name}' length {len(values)} does not match data length {len(data)}.")
            return

    # Generate the last 30 days' dates
    end_date = data['Date'].max()
    start_date = end_date - pd.Timedelta(days=30)
    date_range = pd.date_range(start=start_date, end=end_date, freq='D')

    # Filter data for the last 30 days
    last_30_days_data = data[data['Date'].isin(date_range)]

    # Prepare test data for predictions
    test_data = last_30_days_data.copy()

    # Generate future predictions
    future_prices, _, _, _, _ = predict_future_prices(data, model, days)

    if future_prices is not None:
        # Generate future dates
        future_dates = pd.date_range(start=end_date + pd.Timedelta(days=1), periods=days, freq='D')

        # Create a DataFrame for future predictions
        future_df = pd.DataFrame({
            'Date': future_dates,
            'Predicted_Close': future_prices
        })

        # Matplotlib Plot
        plt.figure(figsize=(14, 7))
        plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
        plt.plot(future_df['Date'], future_df['Predicted_Close'], label='Predicted Price', color='orange', linestyle='--')

        for name, values in indicators.items():
            plt.plot(data['Date'], values, label=name)

        plt.title('Stock Price with Technical Indicators and Predictions')
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.xticks(rotation=45)

        # Render the plot using Streamlit
        st.pyplot(plt)

        # Plotly Plot (interactive)
        fig = px.line(data, x='Date', y='Close', title='Stock Price with Technical Indicators and Predictions')
        fig.add_scatter(x=future_df['Date'], y=future_df['Predicted_Close'], mode='lines', name='Predicted Price', line=dict(color='orange', dash='dash'))

        for name, values in indicators.items():
            fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)

        # Render the interactive plot using Streamlit
        st.plotly_chart(fig)
    else:
        st.error("No predictions available.")


def plot_risk_levels(data: pd.DataFrame, risk_levels: pd.Series, cmap='coolwarm'):
    """
    Plot risk levels with stock prices and customization.
    """
    logger.info("Plotting stock prices with risk levels.")
    plt.figure(figsize=(14, 7))
    plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
    plt.scatter(data['Date'], data['Close'], c=risk_levels, cmap=cmap, label='Risk Levels', alpha=0.7)

    plt.title('Stock Prices with Risk Levels')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.colorbar(label='Risk Level')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.xticks(rotation=45)

    # Render Matplotlib plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.scatter(data, x='Date', y='Close', color=risk_levels, color_continuous_scale=cmap,
                     title='Stock Prices with Risk Levels', labels={'color': 'Risk Level'})

    # Render the interactive Plotly plot using Streamlit
    st.plotly_chart(fig)

def plot_feature_importance(importances: pd.Series, feature_names: list):
    """
    Plot feature importance for machine learning models.
    """
    logger.info("Plotting feature importance.")
    plt.figure(figsize=(10, 6))
    sns.barplot(x=importances, y=feature_names, palette='viridis')

    plt.title('Feature Importances')
    plt.xlabel('Importance')
    plt.ylabel('Feature')
    plt.grid(True)
    plt.tight_layout()

    # Render Matplotlib plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.bar(x=importances, y=feature_names, orientation='h',
                 title='Feature Importances', labels={'x': 'Importance', 'y': 'Feature'})
    fig.update_layout(yaxis={'categoryorder': 'total ascending'})

    # Render the interactive Plotly plot using Streamlit
    st.plotly_chart(fig)

def plot_candlestick(data: pd.DataFrame, ticker: str):
    """
    Plot candlestick chart for stock prices.
    """
    required_columns = ['Date', 'Open', 'High', 'Low', 'Close']

    # Check if all required columns are present
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_candlestick: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting candlestick chart for {ticker}.")
    data = data[required_columns].copy()  # copy so the caller's frame is not modified
    data['Date'] = pd.to_datetime(data['Date'])
    data['Date'] = mdates.date2num(data['Date'])

    fig, ax = plt.subplots(figsize=(14, 7))
    candlestick_ohlc(ax, data.values, width=0.6, colorup='green', colordown='red')

    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.title(f'{ticker} Candlestick Chart')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Render Matplotlib plot using Streamlit
    st.pyplot(fig)

    # Plotly Plot (interactive)
    fig = px.line(data, x='Date', y=['Open', 'High', 'Low', 'Close'],
                  title=f'{ticker} Candlestick Chart')

    # Render the interactive Plotly plot using Streamlit
    st.plotly_chart(fig)

def plot_volume(data: pd.DataFrame):
    """
    Plot trading volume alongside stock price.
    """
    logger.info("Plotting stock price and trading volume.")

    plt.figure(figsize=(14, 7))
    plt.subplot(2, 1, 1)
    plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
    plt.title('Stock Price and Trading Volume')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.bar(data['Date'], data['Volume'], color='grey', alpha=0.5)
    plt.xlabel('Date')
    plt.ylabel('Volume')

    plt.tight_layout()
    plt.xticks(rotation=45)

    # Render Matplotlib plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.bar(data, x='Date', y='Volume', title='Trading Volume',
                 labels={'Volume': 'Volume', 'Date': 'Date'})

    # Render the interactive Plotly plot using Streamlit
    st.plotly_chart(fig)

def plot_moving_averages(data: pd.DataFrame, short_window: int = 20, long_window: int = 50):
    """
    Plot moving averages along with the stock price.
    """
    logger.info("Calculating and plotting moving averages.")
    data['Short_MA'] = data['Close'].rolling(window=short_window).mean()
    data['Long_MA'] = data['Close'].rolling(window=long_window).mean()

    plt.figure(figsize=(14, 7))
    plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
    plt.plot(data['Date'], data['Short_MA'], label=f'Short {short_window}-day MA', color='orange')
    plt.plot(data['Date'], data['Long_MA'], label=f'Long {long_window}-day MA', color='purple')

    plt.title('Stock Price with Moving Averages')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.xticks(rotation=45)

    # Render Matplotlib plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.line(data, x='Date', y=['Close', 'Short_MA', 'Long_MA'],
                  title='Stock Price with Moving Averages')

    # Render the interactive Plotly plot using Streamlit
    st.plotly_chart(fig)

def plot_feature_correlations(data: pd.DataFrame):
    """
    Plot correlation heatmap of features.
    """
    logger.info("Plotting feature correlations heatmap.")
    plt.figure(figsize=(12, 10))
    correlation_matrix = data.corr(numeric_only=True)  # skip non-numeric columns such as Date
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f')

    plt.title('Feature Correlations')
    plt.tight_layout()

    # Render Matplotlib plot using Streamlit
    st.pyplot(plt)

    # Plotly Plot (interactive)
    fig = px.imshow(correlation_matrix, text_auto=True,
                    title='Feature Correlations', labels={'color': 'Correlation'})

    # Render the interactive Plotly plot using Streamlit
    st.plotly_chart(fig)
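These helpers render through Streamlit (st.pyplot / st.plotly_chart), so they are meant to run inside a Streamlit session. A hedged usage sketch, not part of the commit, assuming the frame comes from model.fetch_data plus model.calculate_indicators so the expected columns are present:

# sketch: render a few of the charts above for one ticker (run with: streamlit run this_script.py)
from model import fetch_data, calculate_indicators
from visualizations import plot_stock_price, plot_volume, plot_moving_averages

ticker = 'SBILIFE.NS'   # illustrative ticker
data = fetch_data(ticker, '2023-01-01', '2024-09-01')
data = calculate_indicators(data)

plot_stock_price(data, ticker, indicators={'SMA_50': data['SMA_50'], 'EMA_50': data['EMA_50']})
plot_volume(data)
plot_moving_averages(data, short_window=20, long_window=50)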