NandanData committed on
Commit
42c2fbe
·
verified ·
1 Parent(s): 2d90e8b

Upload 9 files

Files changed (9)
  1. README.md +19 -11
  2. app.py +282 -0
  3. dashboard.py +221 -0
  4. llm.py +220 -0
  5. logger.py +15 -0
  6. model.py +710 -0
  7. requirements.txt +46 -0
  8. ui.py +132 -0
  9. visualizations.py +360 -0
README.md CHANGED
@@ -1,12 +1,20 @@
- ---
- title: Indian Stock Analysis
- emoji: 💻
- colorFrom: pink
- colorTo: pink
- sdk: streamlit
- sdk_version: 1.43.2
- app_file: app.py
- pinned: false
- ---
+ # Stock_Analytics
+ Stock Prediction and Analysis Script Overview This script is designed to predict stock prices using various machine learning and statistical models. It fetches historical stock data, processes it, and then applies several predictive models. The results, including forecasts and model coefficients, are saved to an Excel file for further analysis.
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ 1. Data Ingestion and Preprocessing
+ Data Source: Historical stock data is fetched using the yfinance library, which provides access to financial data directly from Yahoo Finance.
+ Preprocessing: The data is then cleaned and processed using pandas and numpy for further analysis. This includes handling missing values, calculating moving averages, and other necessary data transformations.
+ 2. Technical Analysis and Machine Learning
+ Technical Indicators: Using libraries like pandas and numpy, the project calculates various technical indicators such as moving averages, RSI (Relative Strength Index), Bollinger Bands, etc.
+ Feature Engineering: Features are created and selected for training machine learning models. These features may include technical indicators and other stock-related metrics.
+ Machine Learning Models: The scikit-learn library is used to build predictive models. These models might include Linear Regression, Random Forest, or other algorithms to predict future stock prices.
+ Risk Assessment: The project assesses risk levels associated with each stock, possibly by analyzing volatility, technical indicators, or other metrics.
+ 3. Visualization
+ Matplotlib and Seaborn: These libraries are used to create static visualizations such as line plots for stock prices, candlestick charts, and bar plots for feature importance.
+ Plotly: An optional tool for creating interactive visualizations, especially within the Streamlit app.
+ Candlestick Charts: mplfinance is used to generate candlestick charts that visualize open, high, low, and close prices.
+ 4. Web Application with Streamlit
+ User Interface: The entire analysis and visualization can be wrapped in a web-based UI using Streamlit. Users can input stock tickers and get visualized results, including price predictions, technical indicators, and risk levels.
+ Custom Styling: The Streamlit app is styled according to user preferences, including setting backgrounds, coloring text and numeric values based on risk levels, and displaying buy signals.
+ Tabs and Layout: Multiple tabs or sections can be created in Streamlit for different types of visualizations like technical indicators, feature importance, and future predictions.
+ 5. LLM Integration
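
The README names moving averages and RSI among the technical indicators but does not show how they are computed. The following is a minimal standard-library sketch of both, not the project's implementation (model.py uses the `ta` package, whose RSI applies exponential smoothing). The function names `sma` and `rsi_simple` are hypothetical, and `rsi_simple` uses plain averages over the last `window` price changes, so its values will generally differ from those of `ta`.

```python
from typing import List, Optional


def sma(values: List[float], window: int) -> List[Optional[float]]:
    """Simple moving average; None until a full window is available."""
    out: List[Optional[float]] = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            # Drop the value that just left the window
            running -= values[i - window]
        out.append(running / window if i >= window - 1 else None)
    return out


def rsi_simple(values: List[float], window: int = 14) -> Optional[float]:
    """RSI over the last `window` price changes, using plain (unsmoothed) averages."""
    if len(values) <= window:
        return None  # not enough prices to form `window` changes
    changes = [b - a for a, b in zip(values[-window - 1:-1], values[-window:])]
    avg_gain = sum(c for c in changes if c > 0) / window
    avg_loss = sum(-c for c in changes if c < 0) / window
    if avg_loss == 0:
        return 100.0  # no losses in the window
    rs = avg_gain / avg_loss
    return 100.0 - 100.0 / (1.0 + rs)
```

A reading of 50 means gains and losses balanced out over the window; values near 100 mean the window contained almost no losing days.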
app.py ADDED
@@ -0,0 +1,282 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import logging
+
+ from model import fetch_data, calculate_indicators, calculate_support_resistance, predict_future_prices
+ from visualizations import (
+     plot_stock_price, plot_predictions, plot_technical_indicators, plot_risk_levels,
+     plot_feature_importance, plot_candlestick, plot_volume, plot_moving_averages,
+     plot_feature_correlations
+ )
+ from sklearn.metrics import ConfusionMatrixDisplay
+ from ui import display_analysis
+ from logger import get_logger
+
+ from dashboard import display_dashboard, display_profile, fetch_stock_profile, display_quarterly_results, display_shareholding_pattern, display_financial_ratios
+ from llm import display_recommendation  # analyze_stock_with_llm
+
+ # import argparse
+
+ # parser = argparse.ArgumentParser()
+ # parser.add_argument('--token', required=True)
+ # args = parser.parse_args()
+
+ # API_TOKEN = args.token
+
+ logger = get_logger(__name__)
+
+ st.title("Stock Analysis and Prediction")
+
+ # Sidebar for navigation
+ st.sidebar.title("Navigation")
+
+ # Initialize `page` to "Analytics" by default
+ if 'page' not in st.session_state:
+     st.session_state['page'] = "Analytics"
+
+ if st.sidebar.button("Analytics"):
+     st.session_state['page'] = "Analytics"
+ if st.sidebar.button("Ask to AI"):
+     st.session_state['page'] = "Ask to AI"
+ if st.sidebar.button("Dashboard"):
+     st.session_state['page'] = "Dashboard"
+ if st.sidebar.button("Profile"):
+     st.session_state['page'] = "Profile"
+ page = st.session_state['page']
+
+ # Function to fetch and prepare data
+ def get_data():
+     ticker = st.session_state.get('ticker')
+     start_date = st.session_state.get('start_date')
+     end_date = st.session_state.get('end_date')
+
+     try:
+         data = fetch_data(ticker, start_date, end_date)
+         if data is not None:
+             data = calculate_indicators(data)
+             return data
+         else:
+             st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
+             return None
+     except Exception as e:
+         st.error(f"An error occurred: {e}")
+         return None
+
+ # Display content based on selected page
+ if page == "Analytics":
+     st.header("Analytics")
+
+     # Data input section
+     ticker = st.text_input("Stock Ticker", "BHEL.NS")
+     start_date = st.date_input("Start Date", pd.to_datetime("2020-01-01"))
+     end_date = st.date_input("End Date", pd.to_datetime("2024-09-04"))
+     algorithm = st.selectbox(
+         "Choose an Algorithm",
+         ['Linear Regression', 'LSTM', 'ARIMA', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost', 'SARIMA']
+     )
+     st.session_state['ticker'] = ticker
+     st.session_state['start_date'] = start_date
+     st.session_state['end_date'] = end_date
+     st.session_state['algorithm'] = algorithm
+
+     # Tabs for Analyze and Visualization under Analytics
+     tab1, tab2 = st.tabs(["Analyze", "Visualization"])
+
+     # Analyze Tab
+     with tab1:
+         if st.button("Analyze"):
+             data = get_data()
+             if data is not None:
+                 display_analysis(data, st.session_state.get('algorithm'))
+
+     # Visualization Tab
+     with tab2:
+         st.write("### Visualizations")
+
+         # Fetch and prepare data for visualization
+         data = get_data()
+         if data is not None:
+             indicators = {
+                 'SMA_50': data['SMA_50'],
+                 'EMA_50': data['EMA_50'],
+                 'RSI': data['RSI'],
+                 'MACD': data['MACD'],
+                 'MACD_Signal': data['MACD_Signal'],
+                 'Bollinger_High': data['Bollinger_High'],
+                 'Bollinger_Low': data['Bollinger_Low'],
+                 'ATR': data['ATR'],
+                 'OBV': data['OBV']
+             }
+
+             # Visualization choices
+             choice = st.selectbox(
+                 "Choose a type of visualization",
+                 [
+                     "Stock Price", "Volume",
+                     "Moving Averages",
+                     "Feature Correlations",
+                     "Predictions vs Actual",
+                     "Technical Indicators",
+                     "Risk Levels",
+                     "Feature Importance",
+                     "Candlestick"
+                 ]
+             )
+
+             try:
+                 if choice == "Stock Price":
+                     plot_stock_price(data, st.session_state.get('ticker'), indicators)
+                 elif choice == "Predictions vs Actual":
+                     future_prices, _, _, _, _ = predict_future_prices(data, st.session_state.get('algorithm'))
+                     if future_prices is not None:
+                         st.line_chart(pd.DataFrame({'Actual Prices': data['Close'], 'Predicted Prices': pd.Series(future_prices).values}))
+                     else:
+                         st.error("Failed to fetch predictions.")
+                         logger.error("Failed to fetch predictions.")
+                 elif choice == "Technical Indicators":
+                     plot_technical_indicators(data, indicators)
+                 elif choice == "Risk Levels":
+                     plot_risk_levels(data)
+                 elif choice == "Feature Importance":
+                     plot_feature_importance()
+                 elif choice == "Candlestick":
+                     plot_candlestick(data)
+                 elif choice == "Volume":
+                     plot_volume(data)
+                 elif choice == "Moving Averages":
+                     plot_moving_averages(data)
+                 elif choice == "Feature Correlations":
+                     plot_feature_correlations(data)
+             except Exception as e:
+                 logger.error(f"An error occurred during visualization: {e}")
+                 st.error(f"An error occurred during visualization: {e}")
+         else:
+             st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
+             logger.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
+
+ elif page == "Dashboard":
+
+     st.title("Stock analysis and screening tool for investors in India")
+
+     ticker = st.text_input("Enter stock ticker (e.g., TATAMOTORS.NS):").upper()
+     days = st.sidebar.slider("Select number of days for top movers:", 1, 30, 30)
+
+     profile = {}
+     if ticker:
+         profile = fetch_stock_profile(ticker)
+         if profile:  # Only display profile if it's not empty
+             display_profile(profile)
+             display_quarterly_results(ticker)
+             display_shareholding_pattern(ticker)
+             display_financial_ratios(ticker)
+         else:
+             st.write("No data available for the ticker entered.")
+
+     st.sidebar.write("### Overview")
+     st.sidebar.write(f"Showing top gainers and losers over the past {days} day(s).")
+
+     display_dashboard()
+
+     # display_profile()
+     # # Display the main dashboard
+     # display_dashboard()
+
+     st.write("<div style='background-color: black; color: white; padding: 10px;'>Many more updates coming soon...</div>", unsafe_allow_html=True)
+
+ elif page == "Profile":
+     st.image("https://via.placeholder.com/150", caption="User Profile Photo")
+     st.write("### User Profile")
+     st.write("Name: Nandan Dutta")
+     st.write("Role: Data Analyst")
+     st.write("Email: [email protected]")
+
+ elif page == "Ask to AI":
+     st.title("Ask Stock Recommendation to AI")
+     st.write("Model: Meta LLaMA 3.1")
+
+     # Input fields for the user
+     ticker = st.text_input("Enter Stock Ticker (e.g., BHEL.NS, RELIANCE.NS):")
+     start_date = st.date_input("Start Date", value=None)
+     end_date = st.date_input("End Date", value=None)
+
+     if st.button("Get Recommendation"):
+         if ticker and start_date and end_date:
+             # Ensure dates are in the correct format
+             start_date_str = start_date.strftime('%Y-%m-%d')
+             end_date_str = end_date.strftime('%Y-%m-%d')
+
+             st.write(f"Fetching recommendation for {ticker} from {start_date_str} to {end_date_str}...")
+
+             try:
+                 # Fetch the recommendation using the LLaMA model
+                 recommendations = display_recommendation(ticker, start_date_str, end_date_str)
+
+             except Exception as e:
+                 st.error(f"An error occurred: {e}")
+
+         else:
+             st.error("Please enter a valid ticker and date range.")
+
+
+ st.markdown(
+     """
+ <style>
+ @keyframes blink {
+ 0% { opacity: 1; }
+ 50% { opacity: 0; }
+ 100% { opacity: 1; }
+ }
+ .blinking-heart {
+ animation: blink 1s infinite;
+ }
+ </style>
+ <div style='background-color: #f1f1f1; color: #333; padding: 5px; text-align: center; border-top: 1px solid #ddd;'>
+ <p>Made with <span class="blinking-heart">❤️</span> from Nandan</p>
+ </div>
+ """,
+     unsafe_allow_html=True
+ )
+
+
+ # Display animated running disclaimer text
+ st.write(
+     """
+ <div style='background-color: black; color: white; padding: 10px; border-radius: 5px;'>
+ <marquee behavior="scroll" direction="left" scrollamount="5" style="font-size: 14px;">
+ This project is for educational purposes only. The information provided here should not be used for real investment decisions. Please perform your own research and consult with a financial advisor before making any investment decisions. Use this information at your own risk.
+ </marquee>
+ </div>
+ """,
+     unsafe_allow_html=True
+ )
+
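
The "Ask to AI" page in app.py only checks that a ticker and both dates are present before requesting a recommendation. A small validation helper could also reject an inverted date range before any data is fetched. This is a sketch under assumptions: `validate_request` is a hypothetical name that does not appear in the commit, and the messages are illustrative.

```python
from datetime import date
from typing import Optional


def validate_request(ticker: str, start: Optional[date], end: Optional[date]) -> Optional[str]:
    """Return an error message for invalid input, or None when the input is usable."""
    if not ticker or not ticker.strip():
        return "Please enter a stock ticker."
    if start is None or end is None:
        return "Please select both a start date and an end date."
    if start >= end:
        return "The start date must be earlier than the end date."
    return None
```

In the Streamlit page this could be called right after the button press, passing any returned message to `st.error` and skipping the fetch.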
dashboard.py ADDED
@@ -0,0 +1,221 @@
+ import yfinance as yf
+ import pandas as pd
+ import streamlit as st
+ from datetime import datetime, timedelta
+
+ # Fetch Nifty 50 tickers
+ def fetch_nifty50_tickers():
+     return [
+         "TATAMOTORS.NS", "RELIANCE.NS", "INFY.NS", "HDFCBANK.NS", "ICICIBANK.NS",
+         "SBIN.NS", "ITC.NS", "AXISBANK.NS", "MARUTI.NS", "TATASTEEL.NS",
+         "WIPRO.NS", "SUNPHARMA.NS", "HINDALCO.NS", "HCLTECH.NS", "NTPC.NS",
+         "LT.NS", "M&M.NS", "ONGC.NS", "HDFCLIFE.NS", "ULTRACEMCO.NS",  # Larsen & Toubro is listed as LT.NS
+         "ADANIGREEN.NS", "BHARTIARTL.NS", "BAJAJFINSV.NS", "JSWSTEEL.NS", "DIVISLAB.NS",
+         "POWERGRID.NS", "KOTAKBANK.NS", "HINDUNILVR.NS", "TCS.NS", "CIPLA.NS",
+         "ASIANPAINT.NS", "GRASIM.NS", "BRITANNIA.NS", "SHREECEM.NS",
+         "TECHM.NS", "INDUSINDBK.NS", "EICHERMOT.NS", "COALINDIA.NS", "GAIL.NS",
+         "BOSCHLTD.NS", "M&MFIN.NS", "IDFCFIRSTB.NS", "HAVELLS.NS"
+     ]
+
+ # Fetch large cap tickers
+ def fetch_large_cap_tickers():
+     return fetch_nifty50_tickers()  # Assuming large caps are the same as Nifty 50
+
+ # Fetch small cap tickers
+ def fetch_small_cap_tickers():
+     return [
+         "ALOKINDS.NS", "ADANIENT.NS", "AARTIIND.NS", "AVANTIFEED.NS", "BLS.IN",
+         "BHEL.NS", "BIRLACORP.NS", "CARBORUNIV.NS", "CENTRALBANK.NS", "EMAMILTD.NS",
+         "FDC.NS", "GLAXO.NS", "GODFRYPHLP.NS", "GSKCONS.NS", "HAVELLS.NS",
+         "HEMIPAPER.NS", "HIL.NS", "JINDALSAW.NS", "JUBLFOOD.NS", "KOTAKMAH.NS",
+         "MSTCLAS.NS", "NCC.NS", "PAGEIND.NS", "PIIND.NS", "SBI.CN",
+         "SISL.NS", "SOMANYCERA.NS", "STAR.NS", "SUNDARAM.NS", "TATAINVEST.NS",
+         "VSTIND.NS", "WABCOINDIA.NS", "WELCORP.NS", "ZEELEARN.NS", "ZOMATO.NS"
+     ]
+
+ # Get top movers
+ def get_top_movers(tickers, days=1):
+     end_date = datetime.now()
+     start_date = end_date - timedelta(days=days)
+
+     data = {}
+     for ticker in tickers:
+         try:
+             df = yf.download(ticker, start=start_date, end=end_date)
+             if not df.empty and 'Close' in df.columns:
+                 df['Ticker'] = ticker
+                 data[ticker] = df['Close'].pct_change().iloc[-1]  # Percentage change
+         except Exception as e:
+             st.error(f"Error fetching data for {ticker}: {e}")
+
+     sorted_data = sorted(data.items(), key=lambda x: x[1], reverse=True)
+     top_gainers = sorted_data[:10]
+     top_losers = sorted_data[-10:]
+
+     return top_gainers, top_losers
+
+ # Format DataFrame with color
+ def format_df(df):
+     if not df.empty:
+         df['Percentage Change'] = pd.to_numeric(df['Percentage Change'], errors='coerce')
+         return df.style.applymap(lambda x: 'color: green' if x > 0 else 'color: red', subset=['Percentage Change'])
+     return df
+
+ # Display dashboard
+ def display_dashboard():
+     st.header("Dashboard")
+
+     # Fetch tickers
+     nifty50_tickers = fetch_nifty50_tickers()
+     large_cap_tickers = fetch_large_cap_tickers()
+     small_cap_tickers = fetch_small_cap_tickers()
+
+     # Get top gainers and losers
+     top_gainers_nifty50, top_losers_nifty50 = get_top_movers(nifty50_tickers)
+     top_gainers_large_cap, top_losers_large_cap = get_top_movers(large_cap_tickers)
+     top_gainers_small_cap, top_losers_small_cap = get_top_movers(small_cap_tickers)
+
+     # Create columns for tables
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         st.write("### Nifty 50 Top Gainers")
+         if top_gainers_nifty50:
+             df_gainers_nifty50 = pd.DataFrame(top_gainers_nifty50, columns=['Ticker', 'Percentage Change'])
+             st.dataframe(format_df(df_gainers_nifty50))
+
+     with col2:
+         st.write("### Nifty 50 Top Losers")
+         if top_losers_nifty50:
+             df_losers_nifty50 = pd.DataFrame(top_losers_nifty50, columns=['Ticker', 'Percentage Change'])
+             st.dataframe(format_df(df_losers_nifty50))
+
+     with col3:
+         st.write("### Large Cap Top Gainers")
+         if top_gainers_large_cap:
+             df_gainers_large_cap = pd.DataFrame(top_gainers_large_cap, columns=['Ticker', 'Percentage Change'])
+             st.dataframe(format_df(df_gainers_large_cap))
+
+     with col4:
+         st.write("### Large Cap Top Losers")
+         if top_losers_large_cap:
+             df_losers_large_cap = pd.DataFrame(top_losers_large_cap, columns=['Ticker', 'Percentage Change'])
+             st.dataframe(format_df(df_losers_large_cap))
+
+ # Fetch and display stock profile
+ def fetch_stock_profile(ticker):
+     try:
+         stock = yf.Ticker(ticker)
+         info = stock.info
+         market_cap = info.get('marketCap')
+
+         profile = {
+             "Name": info.get('shortName', 'N/A'),
+             "Current Price": f"₹ {info.get('currentPrice', 'N/A')}",
+             # Guard against a missing market cap before converting to crores
+             "Market Cap": f"₹ {market_cap / 1e7:.2f} Cr." if market_cap else 'N/A',
+             # Read a P/E field; the committed code read 'forwardEps', which is earnings per share
+             "P/E Ratio": info.get('trailingPE', 'N/A'),
+             "Book Value": info.get('bookValue', 'N/A'),
+             "Dividend Yield": info.get('dividendYield', 'N/A'),
+             "ROCE": info.get('returnOnCapitalEmployed', 'N/A'),
+             "ROE": info.get('returnOnEquity', 'N/A'),
+             "Face Value": info.get('faceValue', 'N/A')
+         }
+         return profile
+     except Exception as e:
+         st.error(f"Error fetching profile for {ticker}: {e}")
+         return {}
+
+
+ # Display stock profile as a table
+ def display_profile(profile):
+     st.subheader("Stock Profile")
+     profile_df = pd.DataFrame([profile])
+     st.table(profile_df)
+
+ # Fetch and display quarterly results
+ def display_quarterly_results(ticker):
+     st.subheader("Quarterly Results Summary")
+     try:
+         stock = yf.Ticker(ticker)
+         financials = stock.quarterly_financials.T
+         if not financials.empty:
+             results = {
+                 'Sales': financials['Total Revenue'].iloc[-1] if 'Total Revenue' in financials.columns else 'N/A',
+                 # This is the operating income itself, not a margin, so it is labelled accordingly
+                 'Operating Profit': financials['Operating Income'].iloc[-1] if 'Operating Income' in financials.columns else 'N/A',
+                 'Net Profit': financials['Net Income'].iloc[-1] if 'Net Income' in financials.columns else 'N/A'
+             }
+             results_df = pd.DataFrame([results])
+             st.table(results_df)
+         else:
+             st.write("No quarterly results available.")
+     except Exception as e:
+         st.write(f"Error fetching quarterly results: {e}")
+
+ # Fetch and display shareholding pattern
+ def display_shareholding_pattern(ticker):
+     st.subheader("Shareholding Pattern")
+
+     # Placeholder values; replace with actual data source or API call
+     data = {
+         'Category': ['Promoters', 'FIIs (Foreign Institutional Investors)', 'DIIs (Domestic Institutional Investors)', 'Public'],
+         'Holding (%)': [45.0, 20.0, 15.0, 20.0]
+     }
+
+     df = pd.DataFrame(data)
+     st.table(df)
+
+
+ def display_financial_ratios(ticker):
+     st.subheader("Financial Ratios")
+     stock = yf.Ticker(ticker)
+
+     try:
+         # Placeholder values, calculate actual values based on your requirements
+         ratios = {
+             'Debtor Days': 73,
+             'Working Capital Days': 194,
+             'Cash Conversion Cycle': 51
+         }
+         ratios_df = pd.DataFrame([ratios])
+         st.table(ratios_df)
+     except Exception as e:
+         st.write("Error fetching financial ratios:", e)
+
+
+ # Main application
+ # def main():
+ #     st.title("Stock Analysis Dashboard")
+
+ #     # Select ticker input
+ #     ticker = st.text_input("Enter Stock Ticker (e.g., TATAMOTORS.NS)")
+
+ #     if ticker:
+ #         profile = fetch_stock_profile(ticker)
+ #         if profile:
+ #             display_profile(profile)
+
+ #         display_quarterly_results(ticker)
+ #         display_shareholding_pattern(ticker)
+
+ #     # Show dashboard
+ #     if st.button("Show Dashboard"):
+ #         display_dashboard()
+
+ # if __name__ == "__main__":
+ #     main()
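
`get_top_movers` in dashboard.py ranks tickers by the last value of `pct_change()` and takes the first and last ten entries of the same sorted list. With a short download window there may be fewer than two rows per ticker, so the change can be undefined, and with few tickers the gainers and losers lists overlap. The sketch below shows one way to separate the ranking from the download step so it can be checked on plain lists. `rank_movers` is a hypothetical helper that is not part of the commit, and it measures the change from the first to the last close of the window, which is an assumption about what "top movers over N days" should mean.

```python
from typing import Dict, List, Tuple

Mover = Tuple[str, float]


def rank_movers(closes: Dict[str, List[float]], top_n: int = 10) -> Tuple[List[Mover], List[Mover]]:
    """Rank tickers by percentage change between the first and last close."""
    changes: Dict[str, float] = {}
    for ticker, prices in closes.items():
        # Skip tickers without at least two prices (no change can be computed)
        if len(prices) < 2 or prices[0] == 0:
            continue
        changes[ticker] = (prices[-1] - prices[0]) / prices[0] * 100.0

    ranked = sorted(changes.items(), key=lambda kv: kv[1], reverse=True)
    # Gainers must be strictly positive and losers strictly negative,
    # so the two lists never share a ticker.
    gainers = [kv for kv in ranked if kv[1] > 0][:top_n]
    losers = [kv for kv in reversed(ranked) if kv[1] < 0][:top_n]
    return gainers, losers
```

The losers list is ordered from the largest drop upward, which is usually what a "top losers" table is expected to show.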
llm.py ADDED
@@ -0,0 +1,220 @@
+
+ import streamlit as st
+ import requests
+ from model import fetch_data, calculate_indicators, calculate_support_resistance
+ import os
+
+ # import argparse
+
+ # parser = argparse.ArgumentParser()
+ # parser.add_argument('--token', required=True)
+ # args = parser.parse_args()
+
+ # API_TOKEN = args.token
+ # Hugging Face API token and model URL
+ # Read the token from the environment; a real token must never be committed to the repository.
+ API_TOKEN = os.environ.get('HUGGING_FACE_TOKEN', '')
+ API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+ def generate_prompt(ticker, start_date, end_date):
+     """Fetch data, calculate indicators, and prepare the prompt."""
+     data = fetch_data(ticker, start_date, end_date)
+
+     if data is None:
+         return "No data available for the given ticker and date range.", None
+
+     data = calculate_indicators(data)
+     support, resistance = calculate_support_resistance(data)
+
+     # Additional statistics
+     highest_close = data['Close'].max()
+     lowest_close = data['Close'].min()
+     average_close = data['Close'].mean()
+     average_volume = data['Volume'].mean()
+     highest_volume = data['Volume'].max()
+     lowest_volume = data['Volume'].min()
+     daily_returns = data['Close'].pct_change().dropna()
+     volatility = daily_returns.std()
+
+     recent_trend = "uptrend" if data['Close'].iloc[-1] > data['Close'].iloc[0] else "downtrend" if data['Close'].iloc[-1] < data['Close'].iloc[0] else "sideways"
+
+     # Summarize the key statistics
+     summary = {
+         'latest_close': data['Close'].iloc[-1],
+         'SMA_50': data['SMA_50'].iloc[-1],
+         'EMA_50': data['EMA_50'].iloc[-1],
+         'RSI': data['RSI'].iloc[-1],
+         'MACD': data['MACD'].iloc[-1],
+         'MACD_Signal': data['MACD_Signal'].iloc[-1],
+         'Bollinger_High': data['Bollinger_High'].iloc[-1],
+         'Bollinger_Low': data['Bollinger_Low'].iloc[-1],
+         'ATR': data['ATR'].iloc[-1],
+         'OBV': data['OBV'].iloc[-1],
+         'Support': support,
+         'Resistance': resistance,
+         'Highest_Close': highest_close,
+         'Lowest_Close': lowest_close,
+         'Average_Close': average_close,
+         'Average_Volume': average_volume,
+         'Highest_Volume': highest_volume,
+         'Lowest_Volume': lowest_volume,
+         'Volatility': volatility,
+         'Recent_Trend': recent_trend,
+         'Percentage_Change': (data['Close'].iloc[-1] - data['Close'].iloc[-2]) / data['Close'].iloc[-2] * 100 if len(data) > 1 else 0
+     }
+
+     prompt = f"""
+ Analyze the following stock data for {ticker} and provide a buy/sell recommendation:
+ Latest Close Price: {summary['latest_close']}
+ SMA 50: {summary['SMA_50']}
+ EMA 50: {summary['EMA_50']}
+ RSI: {summary['RSI']}
+ MACD: {summary['MACD']}
+ MACD Signal: {summary['MACD_Signal']}
+ Bollinger Bands High: {summary['Bollinger_High']}
+ Bollinger Bands Low: {summary['Bollinger_Low']}
+ ATR: {summary['ATR']}
+ OBV: {summary['OBV']}
+ Support Level: {summary['Support']}
+ Resistance Level: {summary['Resistance']}
+ Highest Close Price: {summary['Highest_Close']}
+ Lowest Close Price: {summary['Lowest_Close']}
+ Average Close Price: {summary['Average_Close']}
+ Average Volume: {summary['Average_Volume']}
+ Highest Volume: {summary['Highest_Volume']}
+ Lowest Volume: {summary['Lowest_Volume']}
+ Volatility: {summary['Volatility']}
+ Recent Trend: {summary['Recent_Trend']}
+ Percentage Change: {summary['Percentage_Change']}%
+ """
+
+     return prompt, summary
+
+ def get_recommendation(prompt):
+     """Get stock recommendation from Hugging Face API."""
+     headers = {
+         "Authorization": f"Bearer {API_TOKEN}",
+         "Content-Type": "application/json"
+     }
+     payload = {"inputs": prompt}
+
+     # A timeout keeps the app from hanging indefinitely if the API does not respond
+     response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+     response.raise_for_status()  # Raise an error for HTTP issues
+     result = response.json()
+
+     return result[0]['generated_text'].strip()
+
+ def display_recommendation(ticker, start_date, end_date):
+     """Fetch data, generate prompt, get recommendation, and display it in a nice format."""
+     prompt, summary = generate_prompt(ticker, start_date, end_date)
+
+     if summary is None:
+         st.error(prompt)
+         return
+
+     try:
+         recommendation = get_recommendation(prompt)
+     except Exception as e:
+         st.error(f"An error occurred while getting recommendation: {e}")
+         return
+
+     # Display in a box/table format using Streamlit
+     st.markdown(f"### Stock Analysis & Recommendation for {ticker}")
+
+     st.markdown(f"""
+ <div style='border:2px solid #4CAF50; padding: 15px; border-radius: 10px;'>
+ <table style='width:100%; border-collapse: collapse;'>
+ <tr>
+ <th style='text-align: left;'>Indicator</th>
+ <th style='text-align: left;'>Value</th>
+ </tr>
+ <tr>
+ <td>Latest Close Price</td>
+ <td>{summary['latest_close']}</td>
+ </tr>
+ <tr>
+ <td>SMA 50</td>
+ <td>{summary['SMA_50']}</td>
+ </tr>
+ <tr>
+ <td>EMA 50</td>
+ <td>{summary['EMA_50']}</td>
+ </tr>
+ <tr>
+ <td>RSI</td>
+ <td>{summary['RSI']}</td>
+ </tr>
+ <tr>
+ <td>MACD</td>
+ <td>{summary['MACD']}</td>
+ </tr>
+ <tr>
+ <td>MACD Signal</td>
+ <td>{summary['MACD_Signal']}</td>
+ </tr>
+ <tr>
+ <td>Bollinger Bands High</td>
+ <td>{summary['Bollinger_High']}</td>
+ </tr>
+ <tr>
+ <td>Bollinger Bands Low</td>
+ <td>{summary['Bollinger_Low']}</td>
+ </tr>
+ <tr>
+ <td>ATR</td>
+ <td>{summary['ATR']}</td>
+ </tr>
+ <tr>
+ <td>OBV</td>
+ <td>{summary['OBV']}</td>
+ </tr>
+ <tr>
+ <td>Support Level</td>
+ <td>{summary['Support']}</td>
+ </tr>
+ <tr>
+ <td>Resistance Level</td>
+ <td>{summary['Resistance']}</td>
+ </tr>
+ <tr>
+ <td>Highest Close Price</td>
+ <td>{summary['Highest_Close']}</td>
+ </tr>
+ <tr>
+ <td>Lowest Close Price</td>
+ <td>{summary['Lowest_Close']}</td>
+ </tr>
+ <tr>
+ <td>Average Close Price</td>
+ <td>{summary['Average_Close']}</td>
+ </tr>
+ <tr>
+ <td>Average Volume</td>
+ <td>{summary['Average_Volume']}</td>
+ </tr>
+ <tr>
+ <td>Highest Volume</td>
+ <td>{summary['Highest_Volume']}</td>
+ </tr>
+ <tr>
+ <td>Lowest Volume</td>
+ <td>{summary['Lowest_Volume']}</td>
+ </tr>
+ <tr>
+ <td>Volatility</td>
+ <td>{summary['Volatility']}</td>
+ </tr>
+ <tr>
+ <td>Recent Trend</td>
+ <td>{summary['Recent_Trend']}</td>
+ </tr>
+ <tr>
+ <td>Percentage Change</td>
+ <td>{summary['Percentage_Change']}%</td>
+ </tr>
+ </table>
+ </div>
+ """, unsafe_allow_html=True)
+
+     st.write('AI Recommendation:')
+     st.write(recommendation)
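
llm.py builds the request headers from a module-level token and indexes the response as `result[0]['generated_text']`. The sketch below shows one way to keep the token out of the source and to fail with a clear message when the response is not the expected list. It rests on assumptions: `build_request` and `extract_text` are hypothetical helpers, the variable name `HUGGING_FACE_TOKEN` is taken from the commented-out line in the file, and the `{"error": ...}` shape is an assumption about how the hosted API reports failures.

```python
import os
from typing import Any, Dict, Mapping, Tuple


def build_request(prompt: str, env: Mapping[str, str] = os.environ) -> Tuple[Dict[str, str], Dict[str, Any]]:
    """Build headers and payload, reading the token from the environment."""
    token = env.get("HUGGING_FACE_TOKEN", "").strip()
    if not token:
        raise RuntimeError("HUGGING_FACE_TOKEN is not set")
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    return headers, {"inputs": prompt}


def extract_text(result: Any) -> str:
    """Pull the generated text out of a decoded JSON response."""
    # Assumed failure shape: a dict carrying an "error" key
    if isinstance(result, dict) and "error" in result:
        raise RuntimeError(f"Inference API error: {result['error']}")
    if isinstance(result, list) and result and isinstance(result[0], dict) \
            and "generated_text" in result[0]:
        return result[0]["generated_text"].strip()
    raise ValueError("Unexpected response shape from the inference API")
```

Passing the environment as a parameter keeps the helper easy to exercise without touching real credentials.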
logger.py ADDED
@@ -0,0 +1,15 @@
+ # logger.py
+ import logging
+
+ def get_logger(name):
+     """
+     Create and return a logger instance with a specified name.
+     """
+     logger = logging.getLogger(name)
+     if not logger.hasHandlers():
+         handler = logging.StreamHandler()
+         formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
+         logger.setLevel(logging.INFO)
+     return logger
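
`Logger.hasHandlers()` also returns true when an ancestor logger, including the root logger, already has a handler, so `get_logger` may skip its own setup whenever the host application has configured logging. A variant that checks only the logger's own `handlers` list avoids that dependency. This is a sketch, not the committed code: `get_logger_safe` is a hypothetical name, and applying the level on every call is a design choice rather than something the commit does.

```python
import logging


def get_logger_safe(name: str, level: int = logging.INFO) -> logging.Logger:
    """Return a named logger with exactly one stream handler attached."""
    logger = logging.getLogger(name)
    # Check this logger's own handlers, not those inherited from ancestors
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(handler)
    # Apply the level on every call so it does not depend on handler setup
    logger.setLevel(level)
    return logger
```

Calling the function twice with the same name returns the same logger object and does not add a second handler.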
model.py ADDED
@@ -0,0 +1,710 @@
+
+ import yfinance as yf
+ import pandas as pd
+ import ta
+ from sklearn.model_selection import train_test_split
+ from sklearn.linear_model import LinearRegression
+ from sklearn.tree import DecisionTreeRegressor
+ from sklearn.ensemble import RandomForestRegressor
+ from sklearn.metrics import mean_absolute_error, r2_score
+ import xgboost as xgb
+ from catboost import CatBoostRegressor
+ import numpy as np
+ from tensorflow.keras.models import Sequential
+ from tensorflow.keras.layers import LSTM, Dense
+ from sklearn.preprocessing import MinMaxScaler
+ from statsmodels.tsa.arima.model import ARIMA
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
+
+
+ from logger import get_logger
+
+ logger = get_logger(__name__)
+ # logger.setLevel(logging.DEBUG)
+ # handler = logging.StreamHandler()
+ # handler.setLevel(logging.DEBUG)
+ # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ # handler.setFormatter(formatter)
+ # logger.addHandler(handler)
+
+ # # Example usage of logger
+ # logger.info("This is an info message")
+
+ # Fetch historical data
+ def fetch_data(ticker, start_date, end_date):
+     logger.info(f"Fetching data for {ticker} from {start_date} to {end_date}")
+     data = yf.download(ticker, start=start_date, end=end_date)
+     if data.empty:
+         logger.warning(f"No data returned for {ticker}.")
+         return None
+
+     # Reset index to ensure Date is a column
+     data.reset_index(inplace=True)
+     logger.info(f"Data fetched successfully for {ticker}.")
+     return data
+
46
+ def calculate_indicators(data: pd.DataFrame) -> pd.DataFrame:
47
+ logger.info("Calculating indicators with fixed parameters.")
48
+
49
+ # Check if required columns are present
50
+ required_columns = ['Close', 'High', 'Low', 'Volume']
51
+ missing_columns = [col for col in required_columns if col not in data.columns]
52
+ if missing_columns:
53
+ logger.error(f"Missing columns in data: {', '.join(missing_columns)}")
54
+ raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
55
+
56
+ # Calculate fixed moving averages
57
+ ma_period = 50 # Fixed period for moving averages
58
+ try:
59
+ data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
60
+ data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
61
+ except Exception as e:
62
+ logger.error(f"Error calculating moving averages: {e}")
63
+ raise
64
+
65
+ # Calculate other indicators
66
+ try:
67
+ data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
68
+ macd = ta.trend.MACD(data['Close'])
69
+ data['MACD'] = macd.macd()
70
+ data['MACD_Signal'] = macd.macd_signal()
71
+ bollinger = ta.volatility.BollingerBands(data['Close'])
72
+ data['Bollinger_High'] = bollinger.bollinger_hband()
73
+ data['Bollinger_Low'] = bollinger.bollinger_lband()
74
+ data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
75
+ data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
76
+ except Exception as e:
77
+ logger.error(f"Error calculating other indicators: {e}")
78
+ raise
79
+
80
+ # Debugging line to check the columns
81
+ logger.debug("Columns after calculating indicators: %s", data.columns)
82
+
83
+ data = data.dropna()
84
+ logger.info("Indicators calculated successfully.")
85
+ return data
86
+
87
+
88
+ # def calculate_indicators(data: pd.DataFrame, ma_type='SMA', ma_period=50) -> pd.DataFrame:
89
+ # logger.info(f"Calculating indicators with {ma_type} of period {ma_period}.")
90
+
91
+ # # Check if required columns are present
92
+ # required_columns = ['Close', 'High', 'Low', 'Volume']
93
+ # missing_columns = [col for col in required_columns if col not in data.columns]
94
+ # if missing_columns:
95
+ # logger.error(f"Missing columns in data: {', '.join(missing_columns)}")
96
+ # raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
97
+
98
+ # # Calculate moving averages
99
+ # if ma_type == 'SMA':
100
+ # data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
101
+ # elif ma_type == 'EMA':
102
+ # data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
103
+ # else:
104
+ # logger.error(f"Unknown moving average type: {ma_type}")
105
+ # raise ValueError(f"Unknown moving average type: {ma_type}")
106
+
107
+ # # Calculate other indicators
108
+ # try:
109
+ # data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
110
+ # macd = ta.trend.MACD(data['Close'])
111
+ # data['MACD'] = macd.macd()
112
+ # data['MACD_Signal'] = macd.macd_signal()
113
+ # bollinger = ta.volatility.BollingerBands(data['Close'])
114
+ # data['Bollinger_High'] = bollinger.bollinger_hband()
115
+ # data['Bollinger_Low'] = bollinger.bollinger_lband()
116
+ # data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
117
+ # data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
118
+ # except Exception as e:
119
+ # logger.error(f"Error calculating indicators: {e}")
120
+ # raise
121
+
122
+ # # Debugging line to check the columns
123
+ # logger.debug("Columns after calculating indicators: %s", data.columns)
124
+
125
+ # data = data.dropna()
126
+ # logger.info("Indicators calculated successfully.")
127
+ # return data
128
+
129
+
130
+ # # Calculate technical indicators
131
+ # def calculate_indicators(data, ma_type='SMA', ma_period=50):
132
+ # logger.info(f"Calculating indicators with {ma_type} of period {ma_period}.")
133
+
134
+ # if ma_type == 'SMA':
135
+ # data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
136
+ # elif ma_type == 'EMA':
137
+ # data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
138
+
139
+ # data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
140
+ # macd = ta.trend.MACD(data['Close'])
141
+ # data['MACD'] = macd.macd()
142
+ # data['MACD_Signal'] = macd.macd_signal()
143
+ # bollinger = ta.volatility.BollingerBands(data['Close'])
144
+ # data['Bollinger_High'] = bollinger.bollinger_hband()
145
+ # data['Bollinger_Low'] = bollinger.bollinger_lband()
146
+ # data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
147
+ # data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
148
+
149
+ # # Debugging line to check the columns
150
+ # logger.debug("Columns after calculating indicators: %s", data.columns)
151
+
152
+ # data = data.dropna()
153
+ # logger.info("Indicators calculated successfully.")
154
+ # return data
155
+
156
+ # Calculate support and resistance levels
157
+ def calculate_support_resistance(data, window=30):
158
+ logger.info(f"Calculating support and resistance with a window of {window}.")
159
+
160
+ recent_data = data.tail(window)
161
+ rolling_max = data['Close'].rolling(window=window).max()
162
+ rolling_min = data['Close'].rolling(window=window).min()
163
+ recent_max = recent_data['Close'].max()
164
+ recent_min = recent_data['Close'].min()
165
+
166
+ support = min(rolling_min.iloc[-1], recent_min)
167
+ resistance = max(rolling_max.iloc[-1], recent_max)
168
+
169
+ logger.debug("Support: %f, Resistance: %f", support, resistance)
170
+ return support, resistance
171
+
172
+ # Prepare data for LSTM model
173
+ def prepare_lstm_data(data):
174
+ logger.info("Preparing data for LSTM model.")
175
+
176
+ features = data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']].values
177
+ target = data['Close'].values
178
+ scaler = MinMaxScaler()
179
+ features = scaler.fit_transform(features)
180
+
181
+ X, y = [], []
182
+ for i in range(len(features) - 60):
183
+ X.append(features[i:i+60])
184
+ y.append(target[i+60])
185
+
186
+ logger.info("Data preparation for LSTM completed.")
187
+ return np.array(X), np.array(y)
188
+
189
+
190
+ def predict_future_prices(data, algorithm, days=10):
191
+ logger.info(f"Predicting future prices using {algorithm}.")
192
+
193
+ # Check if required columns are present
194
+ required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
195
+ missing_columns = [col for col in required_columns if col not in data.columns]
196
+
197
+ if missing_columns:
198
+ logger.error("Missing columns in data: %s", ', '.join(missing_columns))
199
+ raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
200
+
201
+ features = data[required_columns]
202
+ target = data['Close']
203
+
204
+ X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, shuffle=False) # Keep time order so the test set is the most recent 20% of rows
205
+
206
+ mae, r2 = None, None # Initialize variables for metrics
207
+
208
+ if algorithm == 'Linear Regression':
209
+ model = LinearRegression()
210
+
211
+ elif algorithm == 'Decision Tree':
212
+ model = DecisionTreeRegressor()
213
+
214
+ elif algorithm == 'Random Forest':
215
+ model = RandomForestRegressor(n_estimators=100)
216
+
217
+ elif algorithm == 'XGBoost':
218
+ model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')
219
+
220
+ elif algorithm == 'CatBoost':
221
+ model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)
222
+
223
+ elif algorithm == 'LSTM':
224
+ X, y = prepare_lstm_data(data)
225
+ model = Sequential()
226
+ model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
227
+ model.add(LSTM(50))
228
+ model.add(Dense(1))
229
+ model.compile(optimizer='adam', loss='mean_squared_error')
230
+ model.fit(X, y, epochs=10, batch_size=32, verbose=0)
231
+ last_data_point = np.expand_dims(X[-1], axis=0)
232
+ future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)] # NOTE: the input window never changes, so all `days` values are the same one-step forecast
233
+ logger.info("Future prices predicted using LSTM model.")
234
+ return future_prices, None, None, None, None
235
+
236
+ elif algorithm == 'ARIMA':
237
+ model = ARIMA(data['Close'], order=(5, 1, 0))
238
+ model_fit = model.fit()
239
+ future_prices = model_fit.forecast(steps=days)
240
+
241
+ elif algorithm == 'SARIMA':
242
+ model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
243
+ model_fit = model.fit()
244
+ future_prices = model_fit.forecast(steps=days)
245
+
246
+ else:
247
+ logger.error("Algorithm not recognized: %s", algorithm)
248
+ return None, None, None, None, None
249
+
250
+ if algorithm in ['Linear Regression', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost']:
251
+ model.fit(X_train, y_train)
252
+ predictions = model.predict(X_test)
253
+ mae = mean_absolute_error(y_test, predictions)
254
+ r2 = r2_score(y_test, predictions)
255
+
256
+ future_prices = []
257
+ last_data_point = features.iloc[-1].values.reshape(1, -1) # Ensure it's 2D
258
+
259
+ for _ in range(days):
260
+ future_price = model.predict(last_data_point)[0]
261
+ future_prices.append(future_price)
262
+ last_data_point = last_data_point + 1 # Update last data point (simplified, better methods should be used)
263
+
264
+ logger.info("Future prices predicted using %s model.", algorithm)
265
+ return future_prices, mae, r2, None, None
266
+
267
+
268
+
269
+ # def predict_future_prices(data, algorithm, days=10):
270
+ # logger.info(f"Predicting future prices using {algorithm}.")
271
+
272
+ # # Check if required columns are present
273
+ # required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
274
+ # missing_columns = [col for col in required_columns if col not in data.columns]
275
+
276
+ # if missing_columns:
277
+ # logger.error("Missing columns in data: %s", ', '.join(missing_columns))
278
+ # raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
279
+
280
+ # features = data[required_columns]
281
+ # target = data['Close']
282
+
283
+ # X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)
284
+
285
+ # if algorithm == 'Linear Regression':
286
+ # model = LinearRegression()
287
+
288
+ # elif algorithm == 'Decision Tree':
289
+ # model = DecisionTreeRegressor()
290
+
291
+ # elif algorithm == 'Random Forest':
292
+ # model = RandomForestRegressor(n_estimators=100)
293
+
294
+ # elif algorithm == 'XGBoost':
295
+ # model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')
296
+
297
+ # elif algorithm == 'CatBoost':
298
+ # model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)
299
+
300
+ # elif algorithm == 'LSTM':
301
+ # X, y = prepare_lstm_data(data)
302
+ # model = Sequential()
303
+ # model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
304
+ # model.add(LSTM(50))
305
+ # model.add(Dense(1))
306
+ # model.compile(optimizer='adam', loss='mean_squared_error')
307
+ # model.fit(X, y, epochs=10, batch_size=32, verbose=0)
308
+ # last_data_point = np.expand_dims(X[-1], axis=0)
309
+ # future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]
310
+ # logger.info("Future prices predicted using LSTM model.")
311
+ # return future_prices, None, None, None, None
312
+
313
+ # elif algorithm == 'ARIMA':
314
+ # model = ARIMA(data['Close'], order=(5, 1, 0))
315
+ # model_fit = model.fit()
316
+ # future_prices = model_fit.forecast(steps=days)
317
+
318
+ # elif algorithm == 'SARIMA':
319
+ # model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
320
+ # model_fit = model.fit()
321
+ # future_prices = model_fit.forecast(steps=days)
322
+
323
+ # else:
324
+ # logger.error("Algorithm not recognized: %s", algorithm)
325
+ # return None, None, None, None, None
326
+
327
+ # if algorithm in ['Linear Regression', 'Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost']:
328
+ # model.fit(X_train, y_train)
329
+ # predictions = model.predict(X_test)
330
+ # mae = mean_absolute_error(y_test, predictions)
331
+ # r2 = r2_score(y_test, predictions)
332
+
333
+ # future_prices = []
334
+ # last_data_point = features.iloc[-1].values.reshape(1, -1) # Ensure it's 2D
335
+
336
+ # for _ in range(days):
337
+ # future_price = model.predict(last_data_point)[0]
338
+ # future_prices.append(future_price)
339
+ # last_data_point = last_data_point + 1 # Update last data point (simplified, better methods should be used)
340
+
341
+ # logger.info("Future prices predicted using %s model.", algorithm)
342
+ # return future_prices, mae, r2, None, None
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+
352
+
353
+
354
+
355
+ # # Predict future prices using the selected algorithm
356
+ # def predict_future_prices(data, algorithm, days=10):
357
+ # logger.info(f"Predicting future prices using {algorithm}.")
358
+
359
+ # # Check if required columns are present
360
+ # required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
361
+ # missing_columns = [col for col in required_columns if col not in data.columns]
362
+
363
+ # if missing_columns:
364
+ # logger.error("Missing columns in data: %s", ', '.join(missing_columns))
365
+ # raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
366
+
367
+ # features = data[required_columns]
368
+ # target = data['Close']
369
+
370
+ # X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)
371
+
372
+ # if algorithm == 'Linear Regression':
373
+ # model = LinearRegression()
374
+ # elif algorithm == 'Decision Tree':
375
+ # model = DecisionTreeRegressor()
376
+ # elif algorithm == 'Random Forest':
377
+ # model = RandomForestRegressor(n_estimators=100)
378
+ # elif algorithm == 'XGBoost':
379
+ # model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')
380
+ # elif algorithm == 'CatBoost':
381
+ # model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)
382
+ # elif algorithm == 'LSTM':
383
+
384
+ # X, y = prepare_lstm_data(data)
385
+ # model = Sequential()
386
+ # model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
387
+ # model.add(LSTM(50))
388
+ # model.add(Dense(1))
389
+ # model.compile(optimizer='adam', loss='mean_squared_error')
390
+ # model.fit(X, y, epochs=10, batch_size=32, verbose=0)
391
+ # last_data_point = np.expand_dims(X[-1], axis=0)
392
+ # future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]
393
+
394
+ # elif algorithm == 'ARIMA':
395
+ # model = ARIMA(data['Close'], order=(5, 1, 0))
396
+ # model_fit = model.fit()
397
+ # future_prices = model_fit.forecast(steps=10)
398
+
399
+ # elif algorithm == 'SARIMA':
400
+ # model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
401
+ # model_fit = model.fit()
402
+ # forecast = model_fit.forecast(steps=10)
403
+
404
+ # logger.info("Future prices predicted using LSTM model.")
405
+ # return future_prices, None, None, None, None
406
+ # else:
407
+ # logger.error("Algorithm not recognized: %s", algorithm)
408
+ # return None, None, None, None, None
409
+
410
+ # model.fit(X_train, y_train)
411
+
412
+ # predictions = model.predict(X_test)
413
+ # mae = mean_absolute_error(y_test, predictions)
414
+ # r2 = r2_score(y_test, predictions)
415
+
416
+ # future_prices = []
417
+ # last_data_point = features.iloc[-1].values.reshape(1, -1) # Ensure it's 2D
418
+
419
+ # for _ in range(days):
420
+ # future_price = model.predict(last_data_point)[0]
421
+ # future_prices.append(future_price)
422
+ # last_data_point = last_data_point + 1 # Update last data point (simplified, better methods should be used)
423
+
424
+ # logger.info("Future prices predicted using %s model.", algorithm)
425
+ # return future_prices, mae, r2, None, None
426
+
427
+ # import pandas as pd
428
+ # import numpy as np
429
+ # import yfinance as yf
430
+ # import ta
431
+ # from sklearn.model_selection import train_test_split
432
+ # from sklearn.linear_model import LinearRegression
433
+ # from sklearn.tree import DecisionTreeRegressor
434
+ # from sklearn.ensemble import RandomForestRegressor
435
+ # from sklearn.metrics import mean_absolute_error, r2_score
436
+ # import xgboost as xgb
437
+ # from catboost import CatBoostRegressor
438
+ # from tensorflow.keras.models import Sequential
439
+ # from tensorflow.keras.layers import LSTM, Dense
440
+ # from sklearn.preprocessing import MinMaxScaler
441
+ # from statsmodels.tsa.arima_model import ARIMA
442
+ # from statsmodels.tsa.statespace.sarimax import SARIMAX
443
+
444
+ # from logger import get_logger
445
+
446
+ # logger = get_logger(__name__)
447
+
448
+ # # Fetch historical data
449
+ # def fetch_data(ticker, start_date, end_date):
450
+ # logger.info(f"Fetching data for {ticker} from {start_date} to {end_date}")
451
+ # data = yf.download(ticker, start=start_date, end=end_date)
452
+ # if data.empty:
453
+ # logger.warning(f"No data returned for {ticker}.")
454
+ # return None
455
+
456
+ # # Reset index to ensure Date is a column
457
+ # data.reset_index(inplace=True)
458
+ # logger.info(f"Data fetched successfully for {ticker}.")
459
+ # return data
460
+
461
+ # def calculate_indicators(data: pd.DataFrame) -> pd.DataFrame:
462
+ # logger.info("Calculating indicators with fixed parameters.")
463
+
464
+ # # Check if required columns are present
465
+ # required_columns = ['Close', 'High', 'Low', 'Volume']
466
+ # missing_columns = [col for col in required_columns if col not in data.columns]
467
+ # if missing_columns:
468
+ # logger.error(f"Missing columns in data: {', '.join(missing_columns)}")
469
+ # raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
470
+
471
+ # # Calculate fixed moving averages
472
+ # ma_period = 50 # Fixed period for moving averages
473
+ # try:
474
+ # data[f'SMA_{ma_period}'] = data['Close'].rolling(window=ma_period).mean()
475
+ # data[f'EMA_{ma_period}'] = data['Close'].ewm(span=ma_period, adjust=False).mean()
476
+ # except Exception as e:
477
+ # logger.error(f"Error calculating moving averages: {e}")
478
+ # raise
479
+
480
+ # # Calculate other indicators
481
+ # try:
482
+ # data['RSI'] = ta.momentum.RSIIndicator(data['Close']).rsi()
483
+ # macd = ta.trend.MACD(data['Close'])
484
+ # data['MACD'] = macd.macd()
485
+ # data['MACD_Signal'] = macd.macd_signal()
486
+ # bollinger = ta.volatility.BollingerBands(data['Close'])
487
+ # data['Bollinger_High'] = bollinger.bollinger_hband()
488
+ # data['Bollinger_Low'] = bollinger.bollinger_lband()
489
+ # data['ATR'] = ta.volatility.AverageTrueRange(data['High'], data['Low'], data['Close']).average_true_range()
490
+ # data['OBV'] = ta.volume.OnBalanceVolumeIndicator(data['Close'], data['Volume']).on_balance_volume()
491
+ # except Exception as e:
492
+ # logger.error(f"Error calculating other indicators: {e}")
493
+ # raise
494
+
495
+ # # Debugging line to check the columns
496
+ # logger.debug("Columns after calculating indicators: %s", data.columns)
497
+
498
+ # data = data.dropna()
499
+ # logger.info("Indicators calculated successfully.")
500
+ # return data
501
+
502
+ # # Calculate support and resistance levels
503
+ # def calculate_support_resistance(data, window=30):
504
+ # logger.info(f"Calculating support and resistance with a window of {window}.")
505
+
506
+ # recent_data = data.tail(window)
507
+ # rolling_max = data['Close'].rolling(window=window).max()
508
+ # rolling_min = data['Close'].rolling(window=window).min()
509
+ # recent_max = recent_data['Close'].max()
510
+ # recent_min = recent_data['Close'].min()
511
+
512
+ # support = min(rolling_min.iloc[-1], recent_min)
513
+ # resistance = max(rolling_max.iloc[-1], recent_max)
514
+
515
+ # logger.debug("Support: %f, Resistance: %f", support, resistance)
516
+ # return support, resistance
517
+
518
+ # # Prepare data for LSTM model
519
+ # def prepare_lstm_data(data):
520
+ # logger.info("Preparing data for LSTM model.")
521
+
522
+ # features = data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']].values
523
+ # target = data['Close'].values
524
+ # scaler = MinMaxScaler()
525
+ # features = scaler.fit_transform(features)
526
+
527
+ # X, y = [], []
528
+ # for i in range(len(features) - 60):
529
+ # X.append(features[i:i+60])
530
+ # y.append(target[i+60])
531
+
532
+ # logger.info("Data preparation for LSTM completed.")
533
+ # return np.array(X), np.array(y)
534
+
535
+ # # Predict future prices using the selected algorithm
536
+ # def predict_future_prices(data, algorithm, days=10):
537
+ # logger.info(f"Predicting future prices using {algorithm}.")
538
+
539
+ # # Check if required columns are present
540
+ # required_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']
541
+ # missing_columns = [col for col in required_columns if col not in data.columns]
542
+
543
+ # if missing_columns:
544
+ # logger.error("Missing columns in data: %s", ', '.join(missing_columns))
545
+ # raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
546
+
547
+ # features = data[required_columns]
548
+ # target = data['Close']
549
+
550
+ # X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=0.2, random_state=42)
551
+
552
+ # if algorithm == 'Linear Regression':
553
+ # model = LinearRegression()
554
+ # elif algorithm == 'Decision Tree':
555
+ # model = DecisionTreeRegressor()
556
+ # elif algorithm == 'Random Forest':
557
+ # model = RandomForestRegressor(n_estimators=100)
558
+ # elif algorithm == 'XGBoost':
559
+ # model = xgb.XGBRegressor(objective='reg:squarederror', eval_metric='rmse')
560
+ # elif algorithm == 'CatBoost':
561
+ # model = CatBoostRegressor(learning_rate=0.1, depth=6, iterations=500, verbose=0)
562
+ # elif algorithm == 'LSTM':
563
+ # X, y = prepare_lstm_data(data)
564
+ # model = Sequential()
565
+ # model.add(LSTM(50, return_sequences=True, input_shape=(X.shape[1], X.shape[2])))
566
+ # model.add(LSTM(50))
567
+ # model.add(Dense(1))
568
+ # model.compile(optimizer='adam', loss='mean_squared_error')
569
+ # model.fit(X, y, epochs=10, batch_size=32, verbose=0)
570
+ # last_data_point = np.expand_dims(X[-1], axis=0)
571
+ # future_prices = [model.predict(last_data_point)[0][0] for _ in range(days)]
572
+
573
+ # logger.info("Future prices predicted using LSTM model.")
574
+ # return future_prices, None, None, None, None
575
+ # elif algorithm == 'ARIMA':
576
+ # model = ARIMA(data['Close'], order=(5, 1, 0))
577
+ # model_fit = model.fit(disp=0)
578
+ # forecast = model_fit.forecast(steps=days)[0]
579
+
580
+ # mae = mean_absolute_error(target[-days:], forecast[:days])
581
+ # r2 = r2_score(target[-days:], forecast[:days])
582
+
583
+ # logger.info("Future prices predicted using ARIMA model.")
584
+ # return forecast.tolist(), mae, r2, None, None
585
+ # elif algorithm == 'SARIMA':
586
+ # model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
587
+ # model_fit = model.fit(disp=0)
588
+ # forecast = model_fit.forecast(steps=days)
589
+
590
+ # mae = mean_absolute_error(target[-days:], forecast[:days])
591
+ # r2 = r2_score(target[-days:], forecast[:days])
592
+
593
+ # logger.info("Future prices predicted using SARIMA model.")
594
+ # return forecast.tolist(), mae, r2, None, None
595
+ # else:
596
+ # logger.error("Algorithm not recognized: %s", algorithm)
597
+ # return None, None, None, None, None
598
+
599
+ # model.fit(X_train, y_train)
600
+
601
+ # predictions = model.predict(X_test)
602
+ # mae = mean_absolute_error(y_test, predictions)
603
+ # r2 = r2_score(y_test, predictions)
604
+
605
+ # future_prices = []
606
+ # last_data_point = features.iloc[-1].values.reshape(1, -1) # Ensure it's 2D
607
+
608
+ # for _ in range(days):
609
+ # future_price = model.predict(last_data_point)[0]
610
+ # future_prices.append(future_price)
611
+ # last_data_point = last_data_point + 1 # Update last data point (simplified, better methods should be used)
612
+
613
+ # logger.info("Future prices predicted using %s model.", algorithm)
614
+ # return future_prices, mae, r2, predictions, y_test
615
+
616
+
617
+
618
+
619
+ # # model.py
620
+
621
+ # import pandas as pd
622
+ # import numpy as np
623
+ # import yfinance as yf
624
+ # import statsmodels.api as sm
625
+ # from statsmodels.tsa.arima.model import ARIMA
626
+ # from statsmodels.tsa.statespace.sarimax import SARIMAX
627
+ # from sklearn.metrics import mean_absolute_error, r2_score
628
+
629
+ # def fetch_data(ticker, start_date, end_date):
630
+ # try:
631
+ # df = yf.download(ticker, start=start_date, end=end_date)
632
+ # return df
633
+ # except Exception as e:
634
+ # print(f"An error occurred while fetching data: {e}")
635
+ # return None
636
+
637
+ # def calculate_indicators(data):
638
+ # # Example indicators - these should be tailored to your requirements
639
+ # data['SMA_50'] = data['Close'].rolling(window=50).mean()
640
+ # data['EMA_50'] = data['Close'].ewm(span=50, adjust=False).mean()
641
+ # data['RSI'] = calculate_rsi(data['Close'])
642
+ # data['MACD'], data['MACD_Signal'] = calculate_macd(data['Close'])
643
+ # data['Bollinger_High'], data['Bollinger_Low'] = calculate_bollinger_bands(data['Close'])
644
+ # data['ATR'] = calculate_atr(data)
645
+ # data['OBV'] = calculate_obv(data)
646
+ # return data
647
+
648
+ # def calculate_rsi(series, period=14):
649
+ # delta = series.diff()
650
+ # gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
651
+ # loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
652
+ # rs = gain / loss
653
+ # return 100 - (100 / (1 + rs))
654
+
655
+ # def calculate_macd(series):
656
+ # macd = series.ewm(span=12, adjust=False).mean() - series.ewm(span=26, adjust=False).mean()
657
+ # macd_signal = macd.ewm(span=9, adjust=False).mean()
658
+ # return macd, macd_signal
659
+
660
+ # def calculate_bollinger_bands(series, window=20):
661
+ # rolling_mean = series.rolling(window=window).mean()
662
+ # rolling_std = series.rolling(window=window).std()
663
+ # high = rolling_mean + (rolling_std * 2)
664
+ # low = rolling_mean - (rolling_std * 2)
665
+ # return high, low
666
+
667
+ # def calculate_atr(data, window=14):
668
+ # high_low = data['High'] - data['Low']
669
+ # high_close = np.abs(data['High'] - data['Close'].shift())
670
+ # low_close = np.abs(data['Low'] - data['Close'].shift())
671
+ # tr = np.max(np.array([high_low, high_close, low_close]), axis=0)
672
+ # atr = tr.rolling(window=window).mean()
673
+ # return atr
674
+
675
+ # def calculate_obv(data):
676
+ # obv = (data['Volume'] * np.sign(data['Close'].diff())).fillna(0).cumsum()
677
+ # return obv
678
+
679
+ # def calculate_support_resistance(data):
680
+ # # Example calculation - you may need to refine this based on your requirements
681
+ # support = data['Close'].min()
682
+ # resistance = data['Close'].max()
683
+ # return support, resistance
684
+
685
+ # def predict_future_prices(data, model_type='ARIMA'):
686
+ # try:
687
+ # # Use ARIMA
688
+ # if model_type == 'ARIMA':
689
+ # model = ARIMA(data['Close'], order=(5, 1, 0))
690
+ # model_fit = model.fit()
691
+ # forecast = model_fit.forecast(steps=10)
692
+ # # Use SARIMA
693
+ # elif model_type == 'SARIMA':
694
+ # model = SARIMAX(data['Close'], order=(5, 1, 0), seasonal_order=(1, 1, 0, 12))
695
+ # model_fit = model.fit()
696
+ # forecast = model_fit.forecast(steps=10)
697
+ # else:
698
+ # raise ValueError("Unsupported model type. Use 'ARIMA' or 'SARIMA'.")
699
+
700
+ # # Calculate MAE and R2 for evaluation
701
+ # y_true = data['Close'][-10:] # last 10 days as true values for comparison
702
+ # mae = mean_absolute_error(y_true, forecast[:len(y_true)])
703
+ # r2 = r2_score(y_true, forecast[:len(y_true)])
704
+
705
+ # # Return results
706
+ # return forecast, mae, r2
707
+ # except Exception as e:
708
+ # print(f"An error occurred while predicting future prices: {e}")
709
+ # return None, None, None
710
+
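A note on the multi-day loops in `predict_future_prices`: the LSTM branch calls `model.predict` on the same window `days` times, and the regressor branch adds 1 to every feature between steps, so neither feeds a forecast back into the next input. One common alternative is recursive forecasting over a window of past closing prices. The sketch below is illustrative only and makes several assumptions: `recursive_forecast`, `predict_one` and `mean_of_window` are hypothetical names that are not part of this commit, and the stand-in model takes the last `window` closes as its only input, whereas the committed models use indicator features.

```python
from collections import deque
from typing import Callable, List, Sequence


def recursive_forecast(
    predict_one: Callable[[Sequence[float]], float],
    history: Sequence[float],
    window: int,
    days: int,
) -> List[float]:
    """Forecast `days` steps ahead, feeding each prediction back as input."""
    if len(history) < window:
        raise ValueError("history must contain at least `window` values")
    # Keep only the most recent `window` observations; the deque drops the oldest.
    recent = deque(history[-window:], maxlen=window)
    forecasts = []
    for _ in range(days):
        next_value = predict_one(list(recent))
        forecasts.append(next_value)
        recent.append(next_value)  # the window now ends with the new forecast
    return forecasts


def mean_of_window(values: Sequence[float]) -> float:
    """Stand-in model (hypothetical): predicts the average of the window."""
    return sum(values) / len(values)


if __name__ == "__main__":
    closes = [100.0, 102.0, 104.0]
    print(recursive_forecast(mean_of_window, closes, window=3, days=2))
```

With a real model, `predict_one` would wrap the fitted estimator. Indicator features such as RSI or MACD would have to be recomputed from the extended price series at each step, which this sketch does not attempt.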
requirements.txt ADDED
@@ -0,0 +1,46 @@
1
+ # yfinance==0.2.28
2
+ # pandas==2.0.3
3
+ # ta==0.9.0
4
+ # scikit-learn==1.3.0
5
+ # xgboost==2.0.0
6
+ # catboost==1.2.0
7
+ # numpy==1.24.2
8
+ # tensorflow==2.13.0
9
+ # streamlit==1.23.0
10
+ # mplfinance
11
+
12
+
13
+ matplotlib
14
+ pandas
15
+ seaborn
16
+ numpy
17
+ mplfinance
18
+ scipy
19
+ yfinance
20
+ scikit-learn
21
+ streamlit
22
+ plotly
23
+ ta
24
+ xgboost
25
+ catboost
26
+ tensorflow
27
+ statsmodels
28
+
29
+
30
+ # streamlit
31
+ # mplfinance
32
+
33
+ # matplotlib==3.7.1
34
+ # pandas==2.0.3
35
+ # seaborn==0.12.2
36
+ # numpy==1.25.2
37
+ # mplfinance==0.12.9b7
38
+ # scipy==1.11.1
39
+ # yfinance==0.2.30
40
+ # scikit-learn==1.3.0
41
+ # streamlit==1.25.0
42
+ # plotly==5.17.0
43
+ # pytorch==2.0.1
44
+
45
+ # transformers==4.30.2
46
+ # chroma==1.2.0
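The active requirements above are unpinned, while the commented-out lists carry older pins. If reproducible installs matter for the Space, one option is to record the versions that are actually installed and pin those. The sketch below uses only the Python standard library; `installed_versions` and the package list are illustrative assumptions, not part of this commit.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Hypothetical subset of the requirements list, used only for illustration.
    names = ["pandas", "numpy", "yfinance", "scikit-learn", "streamlit", "ta"]
    for name, ver in installed_versions(names).items():
        # Print a pin for installed packages and a comment for missing ones.
        print(f"{name}=={ver}" if ver else f"# {name} is not installed")
```

The printed lines can be pasted into `requirements.txt` after checking that the environment they came from is the one that actually runs the app.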
ui.py ADDED
@@ -0,0 +1,132 @@
1
+ import streamlit as st
2
+ import pandas as pd
+ from sklearn.metrics import ConfusionMatrixDisplay
3
+ from model import fetch_data, calculate_indicators, calculate_support_resistance, predict_future_prices
4
+ from visualizations import (
5
+ plot_stock_price, plot_predictions, plot_technical_indicators, plot_risk_levels,
6
+ plot_feature_importance, plot_candlestick, plot_volume, plot_moving_averages,
7
+ plot_feature_correlations
8
+ )
9
+
10
+ def sidebar():
11
+ st.sidebar.title("Stock Analysis Dashboard")
12
+ ticker = st.sidebar.text_input("Enter Stock Ticker Symbol:", value='SBILIFE.NS')
13
+ start_date = st.sidebar.date_input("Start Date", value=pd.to_datetime('2021-01-01'))
14
+ end_date = st.sidebar.date_input("End Date", value=pd.to_datetime('2024-09-01'))
15
+ algorithm = st.sidebar.selectbox("Select Prediction Algorithm", ['Linear Regression', 'ARIMA','Decision Tree', 'Random Forest', 'XGBoost', 'CatBoost', 'LSTM', 'SARIMA'])
16
+ return ticker, start_date, end_date, algorithm
17
+
+ def display_analysis(data, algorithm):
+     if data is not None:
+         try:
+             support_price, resistance_price = calculate_support_resistance(data)
+             future_prices, mae, r2, accuracy, conf_matrix = predict_future_prices(data, algorithm)
+
+             if future_prices is not None:
+                 st.write("### Technical Indicators")
+                 # Latest value of each indicator column
+                 indicators = {
+                     'SMA_50': data['SMA_50'].iloc[-1],
+                     'EMA_50': data['EMA_50'].iloc[-1],
+                     'RSI': data['RSI'].iloc[-1],
+                     'MACD': data['MACD'].iloc[-1],
+                     'MACD_Signal': data['MACD_Signal'].iloc[-1],
+                     'Bollinger_High': data['Bollinger_High'].iloc[-1],
+                     'Bollinger_Low': data['Bollinger_Low'].iloc[-1],
+                     'ATR': data['ATR'].iloc[-1],
+                     'OBV': data['OBV'].iloc[-1]
+                 }
+                 for key, value in indicators.items():
+                     with st.expander(f"{key} Description"):
+                         st.write(f"{key}: {value:.2f}")
+                         st.write(get_indicator_description(key))
+
+                 st.write("### Support and Resistance Levels")
+                 st.write(f"Support Price: {support_price:.2f}")
+                 st.write(f"Resistance Price: {resistance_price:.2f}")
+
+                 st.write("### Future Price Predictions")
+                 st.write(pd.DataFrame({'Day': range(1, len(future_prices) + 1), 'Predicted Price': future_prices}))
+
+                 if accuracy is not None and conf_matrix is not None:
+                     st.write(f"**Model Accuracy:** {accuracy:.2f}")
+                     st.write("**Confusion Matrix:**")
+                     # plot() returns the display object; st.pyplot needs its Matplotlib figure.
+                     # ConfusionMatrixDisplay comes from sklearn.metrics.
+                     display = ConfusionMatrixDisplay(conf_matrix).plot()
+                     st.pyplot(display.figure_)
+
+                 if mae is not None and r2 is not None:
+                     st.write(f"**Mean Absolute Error (MAE):** {mae:.2f}")
+                     st.write(f"**R-squared (R2):** {r2:.2f}")
+             else:
+                 st.error("Model selection or prediction failed. Please check your inputs and try again.")
+         except Exception as e:
+             st.error(f"An error occurred during analysis: {e}")
+     else:
+         st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
+
+ def get_indicator_description(indicator):
+     descriptions = {
+         'SMA_50': "SMA_50 (50-day Simple Moving Average): The average closing price over the last 50 days, which indicates the stock's longer-term trend. When the price is above this line, the trend is generally considered upward.",
+         'EMA_50': "EMA_50 (50-day Exponential Moving Average): Also an average, but it gives more weight to recent prices, so it reflects the more recent trend.",
+         'RSI': "RSI (Relative Strength Index): Shows whether the stock is overbought or oversold. Above 70 is considered overbought, and below 30 oversold.",
+         'MACD': "MACD: Shows the difference between a short-term and a long-term moving average.",
+         'MACD_Signal': "MACD Signal: The signal line for the MACD. When the MACD line crosses it, the trend may be changing.",
+         'Bollinger_High': "Bollinger High: The upper boundary of the stock price. If the price is above it, the stock may be overbought.",
+         'Bollinger_Low': "Bollinger Low: The lower boundary of the stock price. If the price is below it, the stock may be oversold.",
+         'ATR': "ATR (Average True Range): Shows the stock's volatility. A higher ATR means larger price fluctuations.",
+         'OBV': "OBV (On-Balance Volume): Shows the relationship between volume and price. A rising OBV suggests that demand for the stock is increasing."
+     }
+     return descriptions.get(indicator, "Description not available")
+
+ def display_visualizations(data, algorithm, ticker=''):
+     # `ticker` is optional and only used in chart titles.
+     if data is not None:
+         choice = st.sidebar.selectbox(
+             "Choose a type of visualization",
+             [
+                 "Stock Price",
+                 "Predictions vs Actual",
+                 "Technical Indicators",
+                 "Risk Levels",
+                 "Feature Importance",
+                 "Candlestick",
+                 "Volume",
+                 "Moving Averages",
+                 "Feature Correlations"
+             ]
+         )
+
+         try:
+             if choice == "Stock Price":
+                 plot_stock_price(data, ticker)
+             elif choice == "Predictions vs Actual":
+                 future_prices, _, _, _, _ = predict_future_prices(data, algorithm)
+                 if future_prices is not None:
+                     # Place the forecast after the last actual price so the two
+                     # series may have different lengths.
+                     actual = data['Close'].reset_index(drop=True)
+                     predicted = pd.Series(future_prices).reset_index(drop=True)
+                     predicted.index = predicted.index + len(actual)
+                     st.line_chart(pd.DataFrame({'Actual Prices': actual, 'Predicted Prices': predicted}))
+                 else:
+                     st.error("Failed to fetch predictions.")
+             elif choice == "Technical Indicators":
+                 indicators = {
+                     'SMA_50': data['SMA_50'],
+                     'EMA_50': data['EMA_50'],
+                     'RSI': data['RSI'],
+                     'MACD': data['MACD'],
+                     'MACD_Signal': data['MACD_Signal'],
+                     'Bollinger_High': data['Bollinger_High'],
+                     'Bollinger_Low': data['Bollinger_Low'],
+                     'ATR': data['ATR'],
+                     'OBV': data['OBV']
+                 }
+                 # The third argument is forwarded to predict_future_prices.
+                 plot_technical_indicators(data, indicators, algorithm)
+             elif choice == "Risk Levels":
+                 # NOTE: plot_risk_levels() in visualizations.py also requires a
+                 # `risk_levels` series, which is not computed here.
+                 plot_risk_levels(data)
+             elif choice == "Feature Importance":
+                 # NOTE: plot_feature_importance() in visualizations.py requires
+                 # `importances` and `feature_names`, which are not available here.
+                 plot_feature_importance()
+             elif choice == "Candlestick":
+                 plot_candlestick(data, ticker)
+             elif choice == "Volume":
+                 plot_volume(data)
+             elif choice == "Moving Averages":
+                 plot_moving_averages(data)
+             elif choice == "Feature Correlations":
+                 plot_feature_correlations(data)
+         except Exception as e:
+             st.error(f"An error occurred during visualization: {e}")
+     else:
+         st.error("Failed to fetch data. Please check the stock ticker symbol and date range.")
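The RSI description in `get_indicator_description` cites the 70/30 thresholds, but the calculation itself lives in `model.py`, which is not shown in this part of the diff. As a minimal sketch of the idea, assuming a plain-average RSI over the last 14 price changes (not necessarily the smoothing `calculate_indicators` uses), the value and its label could be derived as below. The names `simple_rsi` and `classify_rsi` are hypothetical and do not appear in the commit.

```python
from typing import Sequence


def simple_rsi(closes: Sequence[float], period: int = 14) -> float:
    """RSI over the last `period` price changes, using plain averages.

    Assumption: simple averaging of gains and losses, not Wilder's smoothing.
    """
    if len(closes) < period + 1:
        raise ValueError("need at least period + 1 closing prices")
    window = closes[-(period + 1):]
    changes = [b - a for a, b in zip(window, window[1:])]
    avg_gain = sum(c for c in changes if c > 0) / period
    avg_loss = sum(-c for c in changes if c < 0) / period
    if avg_loss == 0:
        # No losses in the window: treat as 100, or 50 if prices were flat
        # (the flat-price value is a convention chosen for this sketch).
        return 100.0 if avg_gain > 0 else 50.0
    rs = avg_gain / avg_loss
    return 100.0 - 100.0 / (1.0 + rs)


def classify_rsi(value: float, overbought: float = 70.0, oversold: float = 30.0) -> str:
    """Label an RSI value with the 70/30 thresholds from the description."""
    if value > overbought:
        return "overbought"
    if value < oversold:
        return "oversold"
    return "neutral"
```

A caller could pass `list(data['Close'])` to `simple_rsi` and show `classify_rsi(...)` next to the number in the expander, though the figure may differ slightly from the `RSI` column if `model.py` uses a smoothed variant.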
visualizations.py ADDED
@@ -0,0 +1,360 @@
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import seaborn as sns
+ import numpy as np
+ import matplotlib.dates as mdates
+ from mplfinance.original_flavor import candlestick_ohlc
+ import logging
+ import plotly.express as px
+ import streamlit as st
+
+ from model import predict_future_prices
+
+ from logger import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def plot_stock_price(data: pd.DataFrame, ticker: str, indicators: dict = None,
+                      color='blue', line_style='-', title=None):
+     """
+     Plot the stock price with optional indicators and customization.
+     """
+     required_columns = ['Date', 'Close']
+     missing_columns = [col for col in required_columns if col not in data.columns]
+     if missing_columns:
+         logger.error(f"Missing columns in data for plot_stock_price: {', '.join(missing_columns)}")
+         raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
+
+     logger.info(f"Plotting stock price for {ticker}.")
+
+     # Matplotlib Plot
+     plt.figure(figsize=(14, 7))
+     plt.plot(data['Date'], data['Close'], label='Close Price', color=color, linestyle=line_style)
+
+     if indicators:
+         for name, values in indicators.items():
+             plt.plot(data['Date'], values, label=name)
+
+     plt.title(title if title else f'{ticker} Stock Price')
+     plt.xlabel('Date')
+     plt.ylabel('Price')
+     plt.legend()
+     plt.grid(True)
+     plt.tight_layout()
+     plt.xticks(rotation=45)
+
+     # Render the plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.line(data, x='Date', y='Close', title=title if title else f'{ticker} Stock Price')
+     if indicators:
+         for name, values in indicators.items():
+             fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)
+
+     # Render the interactive plot using Streamlit
+     st.plotly_chart(fig)
+
+ def plot_predictions(data: pd.DataFrame, predictions: pd.Series, ticker: str,
+                      actual_color='blue', predicted_color='red', line_style_actual='-', line_style_predicted='--'):
+     """
+     Plot actual vs predicted stock prices with customization.
+     """
+     logger.info(f"Plotting actual vs predicted prices for {ticker}.")
+
+     # Matplotlib Plot
+     plt.figure(figsize=(14, 7))
+     plt.plot(data['Date'], data['Close'], label='Actual Prices', color=actual_color, linestyle=line_style_actual)
+     plt.plot(data['Date'], predictions, label='Predicted Prices', color=predicted_color, linestyle=line_style_predicted)
+
+     plt.title(f'{ticker} Actual vs Predicted Prices')
+     plt.xlabel('Date')
+     plt.ylabel('Price')
+     plt.legend()
+     plt.grid(True)
+     plt.tight_layout()
+     plt.xticks(rotation=45)
+
+     # Render the plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.line(data, x='Date', y='Close', title=f'{ticker} Actual vs Predicted Prices')
+     fig.add_scatter(x=data['Date'], y=predictions, mode='lines', name='Predicted Prices', line=dict(color=predicted_color))
+
+     # Render the interactive plot using Streamlit
+     st.plotly_chart(fig)
+
+ def generate_predictions(model, test_data):
+     """
+     Generate predictions using the model for the given test data.
+
+     Returns None (after reporting the error) if prediction fails.
+     """
+     try:
+         # Extract relevant features for the model (adjust features based on your model)
+         features = test_data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal',
+                               'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']]
+         predictions = model.predict(features)
+         return predictions
+     except KeyError as e:
+         logger.error(f"Feature key error: {e}")
+         st.error(f"Feature key error: {e}")
+     except Exception as e:
+         logger.error(f"An error occurred during prediction: {e}")
+         st.error(f"An error occurred during prediction: {e}")
+     return None
+
+ def plot_technical_indicators(data: pd.DataFrame, indicators: dict, model, days=10):
+     """
+     Plot technical indicators along with the stock price and predictions.
+
+     `model` is forwarded unchanged to predict_future_prices, and `days` is the
+     number of future days to predict.
+     """
+     logger.info("Plotting stock price with technical indicators and predictions.")
+
+     # Ensure all indicators have the same length as the data
+     for name, values in indicators.items():
+         if len(values) != len(data):
+             logger.error(f"Indicator '{name}' length {len(values)} does not match data length {len(data)}.")
+             st.error(f"Indicator '{name}' length {len(values)} does not match data length {len(data)}.")
+             return
+
+     # Last date in the data; predictions start the day after
+     end_date = data['Date'].max()
+
+     # Generate future predictions
+     future_prices, _, _, _, _ = predict_future_prices(data, model, days)
+
+     if future_prices is not None:
+         # Generate future dates (calendar days, so weekends are included)
+         future_dates = pd.date_range(start=end_date + pd.Timedelta(days=1), periods=days, freq='D')
+
+         # Create a DataFrame for future predictions
+         future_df = pd.DataFrame({
+             'Date': future_dates,
+             'Predicted_Close': future_prices
+         })
+
+         # Matplotlib Plot
+         plt.figure(figsize=(14, 7))
+         plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
+         plt.plot(future_df['Date'], future_df['Predicted_Close'], label='Predicted Price', color='orange', linestyle='--')
+
+         for name, values in indicators.items():
+             plt.plot(data['Date'], values, label=name)
+
+         plt.title('Stock Price with Technical Indicators and Predictions')
+         plt.xlabel('Date')
+         plt.ylabel('Price')
+         plt.legend()
+         plt.grid(True)
+         plt.tight_layout()
+         plt.xticks(rotation=45)
+
+         # Render the plot using Streamlit
+         st.pyplot(plt)
+
+         # Plotly Plot (interactive)
+         fig = px.line(data, x='Date', y='Close', title='Stock Price with Technical Indicators and Predictions')
+         fig.add_scatter(x=future_df['Date'], y=future_df['Predicted_Close'], mode='lines', name='Predicted Price', line=dict(color='orange', dash='dash'))
+
+         for name, values in indicators.items():
+             fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)
+
+         # Render the interactive plot using Streamlit
+         st.plotly_chart(fig)
+     else:
+         st.error("No predictions available.")
+
+ def plot_risk_levels(data: pd.DataFrame, risk_levels: pd.Series, cmap='coolwarm', plotly_scale='RdBu_r'):
+     """
+     Plot risk levels with stock prices and customization.
+
+     `cmap` is a Matplotlib colormap name. `plotly_scale` is the Plotly
+     colorscale used for the interactive chart, since 'coolwarm' is not a
+     Plotly colorscale name.
+     """
+     logger.info("Plotting stock prices with risk levels.")
+     plt.figure(figsize=(14, 7))
+     plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
+     plt.scatter(data['Date'], data['Close'], c=risk_levels, cmap=cmap, label='Risk Levels', alpha=0.7)
+
+     plt.title('Stock Prices with Risk Levels')
+     plt.xlabel('Date')
+     plt.ylabel('Price')
+     plt.colorbar(label='Risk Level')
+     plt.legend()
+     plt.grid(True)
+     plt.tight_layout()
+     plt.xticks(rotation=45)
+
+     # Render Matplotlib plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.scatter(data, x='Date', y='Close', color=risk_levels, color_continuous_scale=plotly_scale,
+                      title='Stock Prices with Risk Levels', labels={'color': 'Risk Level'})
+
+     # Render the interactive Plotly plot using Streamlit
+     st.plotly_chart(fig)
+
+ def plot_feature_importance(importances: pd.Series, feature_names: list):
+     """
+     Plot feature importance for machine learning models.
+     """
+     logger.info("Plotting feature importance.")
+     plt.figure(figsize=(10, 6))
+     sns.barplot(x=importances, y=feature_names, palette='viridis')
+
+     plt.title('Feature Importances')
+     plt.xlabel('Importance')
+     plt.ylabel('Feature')
+     plt.grid(True)
+     plt.tight_layout()
+
+     # Render Matplotlib plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.bar(x=importances, y=feature_names, orientation='h',
+                  title='Feature Importances', labels={'x': 'Importance', 'y': 'Feature'})
+     fig.update_layout(yaxis={'categoryorder': 'total ascending'})
+
+     # Render the interactive Plotly plot using Streamlit
+     st.plotly_chart(fig)
+
+ def plot_candlestick(data: pd.DataFrame, ticker: str):
+     """
+     Plot candlestick chart for stock prices.
+     """
+     required_columns = ['Date', 'Open', 'High', 'Low', 'Close']
+
+     # Check if all required columns are present
+     missing_columns = [col for col in required_columns if col not in data.columns]
+     if missing_columns:
+         logger.error(f"Missing columns in data for plot_candlestick: {', '.join(missing_columns)}")
+         raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")
+
+     logger.info(f"Plotting candlestick chart for {ticker}.")
+     # Work on a copy so the caller's DataFrame is not modified
+     ohlc = data[required_columns].copy()
+     ohlc['Date'] = pd.to_datetime(ohlc['Date'])
+     # Keep datetime values for Plotly; candlestick_ohlc needs Matplotlib date numbers
+     plotly_data = ohlc.copy()
+     ohlc['Date'] = mdates.date2num(ohlc['Date'])
+
+     fig, ax = plt.subplots(figsize=(14, 7))
+     candlestick_ohlc(ax, ohlc.values, width=0.6, colorup='green', colordown='red')
+
+     ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
+     plt.title(f'{ticker} Candlestick Chart')
+     plt.xlabel('Date')
+     plt.ylabel('Price')
+     plt.grid(True)
+     plt.xticks(rotation=45)
+     plt.tight_layout()
+
+     # Render Matplotlib plot using Streamlit
+     st.pyplot(fig)
+
+     # Plotly Plot (interactive): OHLC values drawn as lines, not candles
+     fig = px.line(plotly_data, x='Date', y=['Open', 'High', 'Low', 'Close'],
+                   title=f'{ticker} OHLC Prices')
+
+     # Render the interactive Plotly plot using Streamlit
+     st.plotly_chart(fig)
+
+ def plot_volume(data: pd.DataFrame):
+     """
+     Plot trading volume alongside stock price.
+     """
+     logger.info("Plotting stock price and trading volume.")
+
+     plt.figure(figsize=(14, 7))
+     plt.subplot(2, 1, 1)
+     plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
+     plt.title('Stock Price and Trading Volume')
+     plt.xlabel('Date')
+     plt.ylabel('Price')
+     plt.legend()
+     plt.grid(True)
+
+     plt.subplot(2, 1, 2)
+     plt.bar(data['Date'], data['Volume'], color='grey', alpha=0.5)
+     plt.xlabel('Date')
+     plt.ylabel('Volume')
+
+     plt.tight_layout()
+     plt.xticks(rotation=45)
+
+     # Render Matplotlib plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.bar(data, x='Date', y='Volume', title='Trading Volume',
+                  labels={'Volume': 'Volume', 'Date': 'Date'})
+
+     # Render the interactive Plotly plot using Streamlit
+     st.plotly_chart(fig)
+
+ def plot_moving_averages(data: pd.DataFrame, short_window: int = 20, long_window: int = 50):
+     """
+     Plot moving averages along with the stock price.
+     """
+     logger.info("Calculating and plotting moving averages.")
+     # Work on a copy so the helper columns are not added to the caller's DataFrame
+     data = data.copy()
+     data['Short_MA'] = data['Close'].rolling(window=short_window).mean()
+     data['Long_MA'] = data['Close'].rolling(window=long_window).mean()
+
+     plt.figure(figsize=(14, 7))
+     plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
+     plt.plot(data['Date'], data['Short_MA'], label=f'Short {short_window}-day MA', color='orange')
+     plt.plot(data['Date'], data['Long_MA'], label=f'Long {long_window}-day MA', color='purple')
+
+     plt.title('Stock Price with Moving Averages')
+     plt.xlabel('Date')
+     plt.ylabel('Price')
+     plt.legend()
+     plt.grid(True)
+     plt.tight_layout()
+     plt.xticks(rotation=45)
+
+     # Render Matplotlib plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.line(data, x='Date', y=['Close', 'Short_MA', 'Long_MA'],
+                   title='Stock Price with Moving Averages')
+
+     # Render the interactive Plotly plot using Streamlit
+     st.plotly_chart(fig)
+
+ def plot_feature_correlations(data: pd.DataFrame):
+     """
+     Plot correlation heatmap of features.
+     """
+     logger.info("Plotting feature correlations heatmap.")
+     plt.figure(figsize=(12, 10))
+     # Correlate numeric columns only; a non-numeric column such as 'Date'
+     # makes DataFrame.corr() fail on recent pandas versions.
+     correlation_matrix = data.select_dtypes(include='number').corr()
+     sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f')
+
+     plt.title('Feature Correlations')
+     plt.tight_layout()
+
+     # Render Matplotlib plot using Streamlit
+     st.pyplot(plt)
+
+     # Plotly Plot (interactive)
+     fig = px.imshow(correlation_matrix, text_auto=True,
+                     title='Feature Correlations', labels={'color': 'Correlation'})
+
+     # Render the interactive Plotly plot using Streamlit
+     st.plotly_chart(fig)
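`plot_moving_averages` relies on `rolling(window=...).mean()`, which leaves the first `window - 1` rows empty because a full window is not yet available. The sketch below reproduces that behaviour in plain Python as an illustration, with `None` standing in for the missing values pandas reports as NaN. The helper name `rolling_mean` is hypothetical and is not part of this commit.

```python
from collections import deque
from typing import List, Optional, Sequence


def rolling_mean(values: Sequence[float], window: int) -> List[Optional[float]]:
    """Trailing moving average; None until a full window has been seen."""
    if window <= 0:
        raise ValueError("window must be positive")
    out: List[Optional[float]] = []
    buf: deque = deque(maxlen=window)  # holds at most the last `window` values
    running = 0.0
    for v in values:
        if len(buf) == window:
            # The oldest value is about to be dropped from the window.
            running -= buf[0]
        buf.append(v)
        running += v
        out.append(running / window if len(buf) == window else None)
    return out
```

With the defaults in `plot_moving_averages`, this means the 20-day line starts at the 20th row and the 50-day line at the 50th, which is why both curves begin later than the close price on the chart.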