|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import numpy as np
|
|
import matplotlib.dates as mdates
|
|
from mplfinance.original_flavor import candlestick_ohlc
|
|
import logging
|
|
import plotly.express as px
|
|
import streamlit as st
|
|
|
|
from model import predict_future_prices
|
|
|
|
from logger import get_logger
|
|
|
|
# Module-level logger, obtained from the project's `logger.get_logger` helper.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
def plot_stock_price(data: pd.DataFrame, ticker: str, indicators: dict = None,
                     color='blue', line_style='-', title=None):
    """
    Plot the stock price with optional indicators and customization.

    Two charts are rendered into the Streamlit app: a static Matplotlib
    figure followed by an interactive Plotly line chart.

    Args:
        data: DataFrame that must contain 'Date' and 'Close' columns.
        ticker: Symbol used in the log message and in the default title.
        indicators: Optional mapping of label -> values; each entry is
            plotted against ``data['Date']`` on both charts.
        color: Matplotlib colour of the close-price line.
        line_style: Matplotlib line style of the close-price line.
        title: Chart title; defaults to ``'<ticker> Stock Price'``.

    Raises:
        KeyError: If 'Date' or 'Close' is missing from ``data``.
    """
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_stock_price: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting stock price for {ticker}.")

    chart_title = title if title else f'{ticker} Stock Price'
    overlays = indicators or {}

    # Static chart. An explicit Figure is used (rather than handing the
    # pyplot module to Streamlit) so that it can be closed afterwards;
    # pyplot keeps every figure alive until it is closed explicitly.
    mpl_fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Close Price', color=color, linestyle=line_style)
        for name, values in overlays.items():
            ax.plot(data['Date'], values, label=name)

        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        # Rotate before tight_layout so the layout accounts for the labels.
        ax.tick_params(axis='x', labelrotation=45)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.line(data, x='Date', y='Close', title=chart_title)
    for name, values in overlays.items():
        fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)

    st.plotly_chart(fig)
|
|
|
|
def plot_predictions(data: pd.DataFrame, predictions: pd.Series, ticker: str,
                     actual_color='blue', predicted_color='red', line_style_actual='-', line_style_predicted='--'):
    """
    Plot actual vs predicted stock prices with customization.

    Renders a static Matplotlib figure and then an interactive Plotly chart
    into the Streamlit app.

    Args:
        data: DataFrame that must contain 'Date' and 'Close' columns.
        predictions: Predicted prices, plotted against ``data['Date']``.
        ticker: Symbol used in the log message and chart titles.
        actual_color: Matplotlib colour of the actual-price line.
        predicted_color: Colour of the predicted-price line (both charts).
        line_style_actual: Matplotlib line style of the actual-price line.
        line_style_predicted: Matplotlib line style of the predicted line.

    Raises:
        KeyError: If 'Date' or 'Close' is missing from ``data``.
    """
    # Validate up front (same convention as plot_stock_price) so a bad frame
    # fails before any figure is created.
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_predictions: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting actual vs predicted prices for {ticker}.")

    chart_title = f'{ticker} Actual vs Predicted Prices'

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Actual Prices',
                color=actual_color, linestyle=line_style_actual)
        ax.plot(data['Date'], predictions, label='Predicted Prices',
                color=predicted_color, linestyle=line_style_predicted)

        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        # Rotate before tight_layout so the layout accounts for the labels.
        ax.tick_params(axis='x', labelrotation=45)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.line(data, x='Date', y='Close', title=chart_title)
    fig.add_scatter(x=data['Date'], y=predictions, mode='lines', name='Predicted Prices',
                    line=dict(color=predicted_color))

    st.plotly_chart(fig)
|
|
|
|
def generate_predictions(model, test_data, feature_columns=None):
    """
    Generate predictions using the model for the given test data.

    Args:
        model: Object exposing ``predict(features)``.
        test_data: DataFrame containing the feature columns.
        feature_columns: Optional sequence of column names fed to the model.
            Defaults to the original fixed set of price/indicator features.

    Returns:
        Whatever ``model.predict`` returns, or ``None`` when a feature column
        is missing or prediction fails (the error is logged and shown in the
        Streamlit app).
    """
    if feature_columns is None:
        feature_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal',
                           'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']

    try:
        features = test_data[list(feature_columns)]
        return model.predict(features)
    except KeyError as e:
        logger.error(f"Feature key error: {e}")
        st.error(f"Feature key error: {e}")
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        logger.error(f"An error occurred during prediction: {e}")
        st.error(f"An error occurred during prediction: {e}")

    # Both failure paths fall through to here; make the result explicit.
    return None
|
|
|
|
|
|
|
|
|
|
|
|
def plot_technical_indicators(data: pd.DataFrame, indicators: dict, model, days=10):
    """
    Plot technical indicators along with the stock price and predictions.

    Future prices are obtained from ``predict_future_prices(data, model, days)``
    and drawn on daily dates starting the day after the latest 'Date' in
    ``data``. A static Matplotlib figure and an interactive Plotly chart are
    rendered into the Streamlit app.

    Args:
        data: DataFrame with 'Date' and 'Close' columns.
        indicators: Mapping of label -> values; every entry must have the
            same length as ``data``.
        model: Model object forwarded to ``predict_future_prices``.
        days: Number of future days to predict and plot.

    Returns:
        None. If an indicator has the wrong length, or no predictions are
        available, an error is shown in the app and nothing is plotted.
    """
    logger.info("Plotting stock price with technical indicators and predictions.")

    # Every indicator is plotted against data['Date'], so lengths must match.
    for name, values in indicators.items():
        if len(values) != len(data):
            message = f"Indicator '{name}' length {len(values)} does not match data length {len(data)}."
            logger.error(message)
            st.error(message)
            return

    end_date = data['Date'].max()

    # Only the first element of the 5-tuple (the predicted prices) is used.
    future_prices, _, _, _, _ = predict_future_prices(data, model, days)
    if future_prices is None:
        st.error("No predictions available.")
        return

    future_dates = pd.date_range(start=end_date + pd.Timedelta(days=1), periods=days, freq='D')
    future_df = pd.DataFrame({
        'Date': future_dates,
        'Predicted_Close': future_prices,
    })

    chart_title = 'Stock Price with Technical Indicators and Predictions'

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Close Price', color='blue')
        ax.plot(future_df['Date'], future_df['Predicted_Close'], label='Predicted Price',
                color='orange', linestyle='--')
        for name, values in indicators.items():
            ax.plot(data['Date'], values, label=name)

        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        # Rotate before tight_layout so the layout accounts for the labels.
        ax.tick_params(axis='x', labelrotation=45)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.line(data, x='Date', y='Close', title=chart_title)
    fig.add_scatter(x=future_df['Date'], y=future_df['Predicted_Close'], mode='lines',
                    name='Predicted Price', line=dict(color='orange', dash='dash'))
    for name, values in indicators.items():
        fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)

    st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_risk_levels(data: pd.DataFrame, risk_levels: pd.Series, cmap='coolwarm'):
    """
    Plot risk levels with stock prices and customization.

    Renders a Matplotlib line + coloured scatter figure and then a Plotly
    scatter chart into the Streamlit app.

    Args:
        data: DataFrame that must contain 'Date' and 'Close' columns.
        risk_levels: Per-row values used to colour the scatter points.
        cmap: Colour map name. It is handed to Matplotlib as ``cmap`` and to
            Plotly as ``color_continuous_scale``; if Plotly rejects the name,
            the Plotly chart falls back to the 'RdBu_r' scale.

    Raises:
        KeyError: If 'Date' or 'Close' is missing from ``data``.
    """
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_risk_levels: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info("Plotting stock prices with risk levels.")

    chart_title = 'Stock Prices with Risk Levels'

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Close Price', color='blue')
        points = ax.scatter(data['Date'], data['Close'], c=risk_levels, cmap=cmap,
                            label='Risk Levels', alpha=0.7)

        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        mpl_fig.colorbar(points, ax=ax, label='Risk Level')
        ax.legend()
        ax.grid(True)
        # Rotate before tight_layout so the layout accounts for the labels.
        ax.tick_params(axis='x', labelrotation=45)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart. Matplotlib and Plotly do not share colour-map names
    # (NOTE(review): 'coolwarm' appears to be Matplotlib-only - confirm), so
    # retry with a Plotly diverging scale when the name is rejected.
    scatter_kwargs = dict(x='Date', y='Close', color=risk_levels,
                          title=chart_title, labels={'color': 'Risk Level'})
    try:
        fig = px.scatter(data, color_continuous_scale=cmap, **scatter_kwargs)
    except ValueError:
        logger.warning(f"Colour scale '{cmap}' was rejected by Plotly; falling back to 'RdBu_r'.")
        fig = px.scatter(data, color_continuous_scale='RdBu_r', **scatter_kwargs)

    st.plotly_chart(fig)
|
|
|
|
def plot_feature_importance(importances: pd.Series, feature_names: list):
    """
    Plot feature importance for machine learning models.

    Renders a horizontal seaborn bar plot and then a Plotly bar chart
    (sorted by total, ascending) into the Streamlit app.

    Args:
        importances: Importance value per feature.
        feature_names: Feature labels, in the same order as ``importances``.
    """
    logger.info("Plotting feature importance.")

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.barplot(x=importances, y=feature_names, palette='viridis', ax=ax)

        ax.set_title('Feature Importances')
        ax.set_xlabel('Importance')
        ax.set_ylabel('Feature')
        ax.grid(True)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.bar(x=importances, y=feature_names, orientation='h',
                 title='Feature Importances', labels={'x': 'Importance', 'y': 'Feature'})
    fig.update_layout(yaxis={'categoryorder': 'total ascending'})

    st.plotly_chart(fig)
|
|
|
|
def plot_candlestick(data: pd.DataFrame, ticker: str):
    """
    Plot candlestick chart for stock prices.

    Renders a Matplotlib candlestick figure and then a Plotly line chart of
    the Open/High/Low/Close series into the Streamlit app. The caller's
    DataFrame is not modified.

    Args:
        data: DataFrame that must contain 'Date', 'Open', 'High', 'Low' and
            'Close' columns. 'Date' is parsed with ``pd.to_datetime``.
        ticker: Symbol used in the log message and chart titles.

    Raises:
        KeyError: If any required column is missing from ``data``.
    """
    required_columns = ['Date', 'Open', 'High', 'Low', 'Close']

    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_candlestick: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting candlestick chart for {ticker}.")

    # Copy before assigning: writing into a column selection of the caller's
    # frame triggers pandas' chained-assignment warning and may not stick.
    ohlc = data[required_columns].copy()
    ohlc['Date'] = pd.to_datetime(ohlc['Date'])

    # candlestick_ohlc expects Matplotlib float date numbers; keep them in a
    # separate frame so the Plotly chart below still receives real dates.
    ohlc_numeric = ohlc.assign(Date=mdates.date2num(ohlc['Date']))

    chart_title = f'{ticker} Candlestick Chart'

    mpl_fig, ax = plt.subplots(figsize=(14, 7))
    try:
        candlestick_ohlc(ax, ohlc_numeric.values, width=0.6, colorup='green', colordown='red')

        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        # pyplot keeps figures alive until they are closed explicitly.
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.line(ohlc, x='Date', y=['Open', 'High', 'Low', 'Close'],
                  title=chart_title)

    st.plotly_chart(fig)
|
|
|
|
def plot_volume(data: pd.DataFrame):
    """
    Plot trading volume alongside stock price.

    Renders a two-panel Matplotlib figure (close price above, volume bars
    below) and then a Plotly volume bar chart into the Streamlit app.

    Args:
        data: DataFrame that must contain 'Date', 'Close' and 'Volume'.

    Raises:
        KeyError: If any required column is missing from ``data``.
    """
    # Validate up front (same convention as plot_stock_price) so a bad frame
    # fails before any figure is created.
    required_columns = ['Date', 'Close', 'Volume']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_volume: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info("Plotting stock price and trading volume.")

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, (price_ax, volume_ax) = plt.subplots(2, 1, figsize=(14, 7))
    try:
        price_ax.plot(data['Date'], data['Close'], label='Close Price', color='blue')
        price_ax.set_title('Stock Price and Trading Volume')
        price_ax.set_xlabel('Date')
        price_ax.set_ylabel('Price')
        price_ax.legend()
        price_ax.grid(True)

        volume_ax.bar(data['Date'], data['Volume'], color='grey', alpha=0.5)
        volume_ax.set_xlabel('Date')
        volume_ax.set_ylabel('Volume')
        # As before, only the lower panel's labels are rotated; do it before
        # tight_layout so the layout accounts for them.
        volume_ax.tick_params(axis='x', labelrotation=45)

        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.bar(data, x='Date', y='Volume', title='Trading Volume',
                 labels={'Volume': 'Volume', 'Date': 'Date'})

    st.plotly_chart(fig)
|
|
|
|
def plot_moving_averages(data: pd.DataFrame, short_window: int = 20, long_window: int = 50):
    """
    Plot moving averages along with the stock price.

    Simple rolling means of 'Close' are computed for both windows and drawn
    with the close price on a Matplotlib figure and a Plotly line chart,
    both rendered into the Streamlit app. The averages are computed on a
    copy, so the caller's DataFrame does not gain 'Short_MA' / 'Long_MA'
    columns.

    Args:
        data: DataFrame that must contain 'Date' and 'Close' columns.
        short_window: Window size (rows) of the short moving average.
        long_window: Window size (rows) of the long moving average.

    Raises:
        KeyError: If 'Date' or 'Close' is missing from ``data``.
    """
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_moving_averages: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info("Calculating and plotting moving averages.")

    # assign() returns a new frame, leaving the caller's data untouched.
    close = data['Close']
    plot_data = data.assign(
        Short_MA=close.rolling(window=short_window).mean(),
        Long_MA=close.rolling(window=long_window).mean(),
    )

    chart_title = 'Stock Price with Moving Averages'

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(plot_data['Date'], plot_data['Close'], label='Close Price', color='blue')
        ax.plot(plot_data['Date'], plot_data['Short_MA'],
                label=f'Short {short_window}-day MA', color='orange')
        ax.plot(plot_data['Date'], plot_data['Long_MA'],
                label=f'Long {long_window}-day MA', color='purple')

        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        # Rotate before tight_layout so the layout accounts for the labels.
        ax.tick_params(axis='x', labelrotation=45)
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.line(plot_data, x='Date', y=['Close', 'Short_MA', 'Long_MA'],
                  title=chart_title)

    st.plotly_chart(fig)
|
|
|
|
def plot_feature_correlations(data: pd.DataFrame):
    """
    Plot correlation heatmap of features.

    Correlations are computed over the numeric (and boolean) columns only;
    other columns such as dates or strings are ignored. A seaborn heatmap
    and a Plotly image chart are rendered into the Streamlit app.

    Args:
        data: DataFrame whose numeric columns are correlated pairwise.

    Returns:
        None. If ``data`` has no numeric columns, an error is shown in the
        app and nothing is plotted.
    """
    logger.info("Plotting feature correlations heatmap.")

    # Restrict to numeric columns explicitly: recent pandas versions no
    # longer drop non-numeric columns silently in DataFrame.corr().
    numeric_data = data.select_dtypes(include=['number', 'bool'])
    if numeric_data.shape[1] == 0:
        logger.error("No numeric columns available for correlation heatmap.")
        st.error("No numeric columns available for correlation heatmap.")
        return

    correlation_matrix = numeric_data.corr()

    # Static chart on an explicit Figure that is closed once rendered;
    # pyplot would otherwise keep it alive for the life of the process.
    mpl_fig, ax = plt.subplots(figsize=(12, 10))
    try:
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f', ax=ax)
        ax.set_title('Feature Correlations')
        mpl_fig.tight_layout()

        st.pyplot(mpl_fig)
    finally:
        plt.close(mpl_fig)

    # Interactive chart.
    fig = px.imshow(correlation_matrix, text_auto=True,
                    title='Feature Correlations', labels={'color': 'Correlation'})

    st.plotly_chart(fig)
|
|
|