# Indian_Stock_analysis / visualizations.py
# Uploaded by NandanData ("Upload 9 files", commit 42c2fbe, verified).
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.dates as mdates
from mplfinance.original_flavor import candlestick_ohlc
import logging
import plotly.express as px
import streamlit as st
from model import predict_future_prices
from logger import get_logger
# Module-level logger from the project's `logger.get_logger` helper; every plotting function below logs through it.
logger = get_logger(__name__)
def plot_stock_price(data: pd.DataFrame, ticker: str, indicators: dict = None,
                     color='blue', line_style='-', title=None):
    """
    Plot the stock price with optional indicators and customization.

    Renders the same data twice in the Streamlit page: a static Matplotlib
    chart followed by an interactive Plotly chart.

    Parameters
    ----------
    data : pd.DataFrame
        Price history; must contain ``Date`` and ``Close`` columns.
    ticker : str
        Symbol used in the default title and in log messages.
    indicators : dict, optional
        Mapping of legend label -> values plotted against ``data['Date']``
        (so each value sequence needs one entry per row of ``data``).
    color, line_style : str
        Matplotlib colour and line style of the close-price line.
    title : str, optional
        Chart title; defaults to ``"<ticker> Stock Price"``.

    Raises
    ------
    KeyError
        If ``Date`` or ``Close`` is missing from ``data``.
    """
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_stock_price: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting stock price for {ticker}.")
    chart_title = title if title else f'{ticker} Stock Price'
    indicators = indicators or {}

    # Matplotlib plot. Use an explicit Figure instead of pyplot's global one and
    # close it once rendered, so figures do not accumulate across Streamlit reruns.
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Close Price', color=color, linestyle=line_style)
        for name, values in indicators.items():
            ax.plot(data['Date'], values, label=name)
        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        # Rotate before tight_layout so the layout accounts for the rotated labels.
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive).
    plotly_fig = px.line(data, x='Date', y='Close', title=chart_title)
    for name, values in indicators.items():
        plotly_fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)
    st.plotly_chart(plotly_fig)
def plot_predictions(data: pd.DataFrame, predictions: pd.Series, ticker: str,
                     actual_color='blue', predicted_color='red', line_style_actual='-', line_style_predicted='--'):
    """
    Plot actual vs predicted stock prices with customization.

    Draws a static Matplotlib chart and an interactive Plotly chart in the
    Streamlit page, each showing ``data['Close']`` and ``predictions`` against
    ``data['Date']``.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``Date`` and ``Close`` columns.
    predictions : pd.Series
        Predicted prices, one per row of ``data``.
    ticker : str
        Symbol shown in the chart titles and log messages.
    actual_color, predicted_color : str
        Line colours of the actual / predicted series. Both charts use them, so
        they should be colour names that Matplotlib and Plotly both understand.
    line_style_actual, line_style_predicted : str
        Matplotlib line styles for the static chart.

    Raises
    ------
    KeyError
        If ``Date`` or ``Close`` is missing from ``data``.
    """
    required_columns = ['Date', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_predictions: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting actual vs predicted prices for {ticker}.")
    chart_title = f'{ticker} Actual vs Predicted Prices'

    # Matplotlib plot on an explicit Figure, closed after rendering to avoid leaking figures.
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Actual Prices', color=actual_color, linestyle=line_style_actual)
        ax.plot(data['Date'], predictions, label='Predicted Prices', color=predicted_color,
                linestyle=line_style_predicted)
        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive). Name and colour the px trace so the legend lists
    # both series, mirroring the Matplotlib chart.
    plotly_fig = px.line(data, x='Date', y='Close', title=chart_title)
    plotly_fig.update_traces(name='Actual Prices', showlegend=True, line=dict(color=actual_color))
    plotly_fig.add_scatter(x=data['Date'], y=predictions, mode='lines', name='Predicted Prices',
                           line=dict(color=predicted_color))
    st.plotly_chart(plotly_fig)
def generate_predictions(model, test_data, feature_columns=None):
    """
    Generate predictions using the model for the given test data.

    Parameters
    ----------
    model
        Any object exposing ``predict(features)``.
    test_data : pd.DataFrame
        Frame holding every column named in ``feature_columns``.
    feature_columns : sequence of str, optional
        Columns handed to ``model.predict``, in this order. Defaults to
        ``Open, SMA_50, EMA_50, RSI, MACD, MACD_Signal, Bollinger_High,
        Bollinger_Low, ATR, OBV``; override it to match a model trained on a
        different feature set.

    Returns
    -------
    Whatever ``model.predict`` returns, or ``None`` when the feature columns are
    missing or prediction fails (the error is logged and shown via ``st.error``).
    """
    if feature_columns is None:
        feature_columns = ['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal',
                           'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']

    # Separate the column selection from the model call so a KeyError raised
    # inside model.predict is not misreported as a missing-feature problem.
    try:
        features = test_data[list(feature_columns)]
    except KeyError as e:
        logger.error(f"Feature key error: {e}")
        st.error(f"Feature key error: {e}")
        return None

    try:
        return model.predict(features)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the page,
        # but keep the traceback in the log.
        logger.exception(f"An error occurred during prediction: {e}")
        st.error(f"An error occurred during prediction: {e}")
        return None
def plot_technical_indicators(data: pd.DataFrame, indicators: dict, model, days=10):
    """
    Plot technical indicators along with the stock price and predictions.

    Calls ``predict_future_prices(data, model, days)`` and, when it yields a
    forecast, draws the close price, every indicator and the forecast (dashed,
    on the calendar days following the last date in ``data``) as a static
    Matplotlib chart and an interactive Plotly chart.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``Date`` and ``Close`` columns.
    indicators : dict
        Mapping of legend label -> values; each must have ``len(data)`` entries.
    model
        Passed unchanged to ``predict_future_prices``.
    days : int
        Forecast horizon passed to ``predict_future_prices``.

    Returns
    -------
    None. Indicator length mismatches and a missing forecast are reported with
    ``st.error`` and abort the plot.
    """
    logger.info("Plotting stock price with technical indicators and predictions.")

    # Every indicator is drawn against data['Date'], so the lengths must agree.
    for name, values in indicators.items():
        if len(values) != len(data):
            message = f"Indicator '{name}' length {len(values)} does not match data length {len(data)}."
            logger.error(message)
            st.error(message)
            return

    # Only the first element of the returned 5-tuple (the forecast) is used here.
    future_prices, _, _, _, _ = predict_future_prices(data, model, days)
    if future_prices is None:
        st.error("No predictions available.")
        return

    # Flatten so a column-vector forecast (e.g. shape (days, 1)) still forms a single
    # column, and size the future dates from the forecast itself rather than `days`.
    predicted_close = np.ravel(future_prices)
    dates = pd.to_datetime(data['Date'])
    future_df = pd.DataFrame({
        'Date': pd.date_range(start=dates.max() + pd.Timedelta(days=1),
                              periods=len(predicted_close), freq='D'),
        'Predicted_Close': predicted_close,
    })
    chart_title = 'Stock Price with Technical Indicators and Predictions'

    # Matplotlib plot on an explicit Figure, closed after rendering to avoid leaking figures.
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(dates, data['Close'], label='Close Price', color='blue')
        ax.plot(future_df['Date'], future_df['Predicted_Close'], label='Predicted Price',
                color='orange', linestyle='--')
        for name, values in indicators.items():
            ax.plot(dates, values, label=name)
        ax.set_title(chart_title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive).
    plotly_fig = px.line(data, x='Date', y='Close', title=chart_title)
    plotly_fig.add_scatter(x=future_df['Date'], y=future_df['Predicted_Close'], mode='lines',
                           name='Predicted Price', line=dict(color='orange', dash='dash'))
    for name, values in indicators.items():
        plotly_fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name)
    st.plotly_chart(plotly_fig)
def plot_risk_levels(data: pd.DataFrame, risk_levels: pd.Series, cmap='coolwarm'):
"""
Plot risk levels with stock prices and customization.
"""
logger.info("Plotting stock prices with risk levels.")
plt.figure(figsize=(14, 7))
plt.plot(data['Date'], data['Close'], label='Close Price', color='blue')
plt.scatter(data['Date'], data['Close'], c=risk_levels, cmap=cmap, label='Risk Levels', alpha=0.7)
plt.title('Stock Prices with Risk Levels')
plt.xlabel('Date')
plt.ylabel('Price')
plt.colorbar(label='Risk Level')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.xticks(rotation=45)
# Render Matplotlib plot using Streamlit
st.pyplot(plt)
# Plotly Plot (interactive)
fig = px.scatter(data, x='Date', y='Close', color=risk_levels, color_continuous_scale=cmap,
title='Stock Prices with Risk Levels', labels={'color': 'Risk Level'})
# Render the interactive Plotly plot using Streamlit
st.plotly_chart(fig)
def plot_feature_importance(importances: pd.Series, feature_names: list):
    """
    Plot feature importance for machine learning models.

    Renders a horizontal bar per feature, first with seaborn/Matplotlib and then
    as an interactive Plotly bar chart sorted by importance.

    Parameters
    ----------
    importances : pd.Series
        Importance value per feature, aligned with ``feature_names``.
    feature_names : list
        Feature labels for the y-axis.
    """
    logger.info("Plotting feature importance.")

    # Matplotlib/seaborn plot on an explicit Figure, closed after rendering to avoid leaking figures.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # NOTE(review): seaborn >= 0.13 warns when `palette` is given without `hue`;
        # consider `hue=feature_names, legend=False` once that version is the minimum.
        sns.barplot(x=importances, y=feature_names, palette='viridis', ax=ax)
        ax.set_title('Feature Importances')
        ax.set_xlabel('Importance')
        ax.set_ylabel('Feature')
        ax.grid(True)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive), bars ordered by total importance.
    plotly_fig = px.bar(x=importances, y=feature_names, orientation='h',
                        title='Feature Importances', labels={'x': 'Importance', 'y': 'Feature'})
    plotly_fig.update_layout(yaxis={'categoryorder': 'total ascending'})
    st.plotly_chart(plotly_fig)
def plot_candlestick(data: pd.DataFrame, ticker: str):
    """
    Plot candlestick chart for stock prices.

    Draws an OHLC candlestick chart with mplfinance's ``candlestick_ohlc`` and an
    interactive Plotly candlestick chart (rising candles green, falling red).

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``Date``, ``Open``, ``High``, ``Low`` and ``Close``. It is
        not modified.
    ticker : str
        Symbol shown in the chart titles and log messages.

    Raises
    ------
    KeyError
        If any required column is missing.
    """
    import plotly.graph_objects as go  # part of plotly, already used here via plotly.express

    required_columns = ['Date', 'Open', 'High', 'Low', 'Close']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_candlestick: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info(f"Plotting candlestick chart for {ticker}.")
    # Explicit copy: assigning into a column subset of the caller's frame would
    # raise SettingWithCopyWarning and could leak changes back to the caller.
    ohlc = data[required_columns].copy()
    ohlc['Date'] = pd.to_datetime(ohlc['Date'])

    # candlestick_ohlc wants rows of (Matplotlib date number, open, high, low, close).
    # The numeric dates are built separately so 'Date' stays a real datetime for Plotly.
    quotes = np.column_stack([
        mdates.date2num(ohlc['Date']),
        ohlc[['Open', 'High', 'Low', 'Close']].to_numpy(dtype=float),
    ])

    # Matplotlib plot on an explicit Figure, closed after rendering to avoid leaking figures.
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        candlestick_ohlc(ax, quotes, width=0.6, colorup='green', colordown='red')
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax.set_title(f'{ticker} Candlestick Chart')
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive): a real candlestick trace rather than four OHLC lines.
    plotly_fig = go.Figure(data=[go.Candlestick(
        x=ohlc['Date'],
        open=ohlc['Open'],
        high=ohlc['High'],
        low=ohlc['Low'],
        close=ohlc['Close'],
        increasing_line_color='green',
        decreasing_line_color='red',
        name=ticker,
    )])
    plotly_fig.update_layout(title=f'{ticker} Candlestick Chart', xaxis_title='Date', yaxis_title='Price')
    st.plotly_chart(plotly_fig)
def plot_volume(data: pd.DataFrame):
    """
    Plot trading volume alongside stock price.

    The static chart stacks two panels (close price on top, volume bars below);
    the interactive Plotly chart shows the volume bars only.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``Date``, ``Close`` and ``Volume`` columns.

    Raises
    ------
    KeyError
        If any required column is missing.
    """
    required_columns = ['Date', 'Close', 'Volume']
    missing_columns = [col for col in required_columns if col not in data.columns]
    if missing_columns:
        logger.error(f"Missing columns in data for plot_volume: {', '.join(missing_columns)}")
        raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}")

    logger.info("Plotting stock price and trading volume.")

    # Matplotlib plot on an explicit Figure, closed after rendering to avoid leaking figures.
    fig, (price_ax, volume_ax) = plt.subplots(2, 1, figsize=(14, 7))
    try:
        price_ax.plot(data['Date'], data['Close'], label='Close Price', color='blue')
        price_ax.set_title('Stock Price and Trading Volume')
        price_ax.set_xlabel('Date')
        price_ax.set_ylabel('Price')
        price_ax.legend()
        price_ax.grid(True)

        volume_ax.bar(data['Date'], data['Volume'], color='grey', alpha=0.5)
        volume_ax.set_xlabel('Date')
        volume_ax.set_ylabel('Volume')
        # Rotate before tight_layout so the layout accounts for the rotated labels.
        volume_ax.tick_params(axis='x', labelrotation=45)

        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive).
    plotly_fig = px.bar(data, x='Date', y='Volume', title='Trading Volume',
                        labels={'Volume': 'Volume', 'Date': 'Date'})
    st.plotly_chart(plotly_fig)
def plot_moving_averages(data: pd.DataFrame, short_window: int = 20, long_window: int = 50):
    """
    Plot moving averages along with the stock price.

    Side effect: adds (or overwrites) ``Short_MA`` and ``Long_MA`` columns on the
    caller's ``data``. These are rolling means of ``Close`` over ``short_window``
    and ``long_window`` rows, and the first ``window - 1`` values of each are NaN.
    The legend's "-day" wording assumes one row per trading day.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``Date`` and ``Close`` columns.
    short_window, long_window : int
        Rolling window sizes, in rows.
    """
    logger.info("Calculating and plotting moving averages.")
    # Written onto `data` itself (not a copy); callers may read these columns afterwards.
    data['Short_MA'] = data['Close'].rolling(window=short_window).mean()
    data['Long_MA'] = data['Close'].rolling(window=long_window).mean()

    # Matplotlib plot on an explicit Figure, closed after rendering to avoid leaking figures.
    fig, ax = plt.subplots(figsize=(14, 7))
    try:
        ax.plot(data['Date'], data['Close'], label='Close Price', color='blue')
        ax.plot(data['Date'], data['Short_MA'], label=f'Short {short_window}-day MA', color='orange')
        ax.plot(data['Date'], data['Long_MA'], label=f'Long {long_window}-day MA', color='purple')
        ax.set_title('Stock Price with Moving Averages')
        ax.set_xlabel('Date')
        ax.set_ylabel('Price')
        ax.legend()
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)

    # Plotly plot (interactive).
    plotly_fig = px.line(data, x='Date', y=['Close', 'Short_MA', 'Long_MA'],
                         title='Stock Price with Moving Averages')
    st.plotly_chart(plotly_fig)
def plot_feature_correlations(data: pd.DataFrame):
"""
Plot correlation heatmap of features.
"""
logger.info("Plotting feature correlations heatmap.")
plt.figure(figsize=(12, 10))
correlation_matrix = data.corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f')
plt.title('Feature Correlations')
plt.tight_layout()
# Render Matplotlib plot using Streamlit
st.pyplot(plt)
# Plotly Plot (interactive)
fig = px.imshow(correlation_matrix, text_auto=True,
title='Feature Correlations', labels={'color': 'Correlation'})
# Render the interactive Plotly plot using Streamlit
st.plotly_chart(fig)