import matplotlib.pyplot as plt import pandas as pd import seaborn as sns import numpy as np import matplotlib.dates as mdates from mplfinance.original_flavor import candlestick_ohlc import logging import plotly.express as px import streamlit as st from model import predict_future_prices from logger import get_logger logger = get_logger(__name__) def plot_stock_price(data: pd.DataFrame, ticker: str, indicators: dict = None, color='blue', line_style='-', title=None): """ Plot the stock price with optional indicators and customization. """ required_columns = ['Date', 'Close'] missing_columns = [col for col in required_columns if col not in data.columns] if missing_columns: logger.error(f"Missing columns in data for plot_stock_price: {', '.join(missing_columns)}") raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}") logger.info(f"Plotting stock price for {ticker}.") # Matplotlib Plot plt.figure(figsize=(14, 7)) plt.plot(data['Date'], data['Close'], label='Close Price', color=color, linestyle=line_style) if indicators: for name, values in indicators.items(): plt.plot(data['Date'], values, label=name) plt.title(title if title else f'{ticker} Stock Price') plt.xlabel('Date') plt.ylabel('Price') plt.legend() plt.grid(True) plt.tight_layout() plt.xticks(rotation=45) # Render the plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.line(data, x='Date', y='Close', title=title if title else f'{ticker} Stock Price') if indicators: for name, values in indicators.items(): fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name) # Render the interactive plot using Streamlit st.plotly_chart(fig) def plot_predictions(data: pd.DataFrame, predictions: pd.Series, ticker: str, actual_color='blue', predicted_color='red', line_style_actual='-', line_style_predicted='--'): """ Plot actual vs predicted stock prices with customization. """ logger.info(f"Plotting actual vs predicted prices for {ticker}.") # Matplotlib Plot plt.figure(figsize=(14, 7)) plt.plot(data['Date'], data['Close'], label='Actual Prices', color=actual_color, linestyle=line_style_actual) plt.plot(data['Date'], predictions, label='Predicted Prices', color=predicted_color, linestyle=line_style_predicted) plt.title(f'{ticker} Actual vs Predicted Prices') plt.xlabel('Date') plt.ylabel('Price') plt.legend() plt.grid(True) plt.tight_layout() plt.xticks(rotation=45) # Render the plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.line(data, x='Date', y='Close', title=f'{ticker} Actual vs Predicted Prices') fig.add_scatter(x=data['Date'], y=predictions, mode='lines', name='Predicted Prices', line=dict(color=predicted_color)) # Render the interactive plot using Streamlit st.plotly_chart(fig) def generate_predictions(model, test_data): """ Generate predictions using the model for the given test data. """ try: # Extract relevant features for the model features = test_data[['Open', 'SMA_50', 'EMA_50', 'RSI', 'MACD', 'MACD_Signal', 'Bollinger_High', 'Bollinger_Low', 'ATR', 'OBV']] # Adjust features based on your model predictions = model.predict(features) return predictions except KeyError as e: logger.error(f"Feature key error: {e}") st.error(f"Feature key error: {e}") except Exception as e: logger.error(f"An error occurred during prediction: {e}") st.error(f"An error occurred during prediction: {e}") def plot_technical_indicators(data: pd.DataFrame, indicators: dict, model, days=10): """ Plot technical indicators along with the stock price and predictions. """ logger.info("Plotting stock price with technical indicators and predictions.") # Ensure all indicators have the same length as the data for name, values in indicators.items(): if len(values) != len(data): logger.error(f"Indicator '{name}' length {len(values)} does not match data length {len(data)}.") st.error(f"Indicator '{name}' length {len(values)} does not match data length {len(data)}.") return # Generate the last 30 days' dates end_date = data['Date'].max() start_date = end_date - pd.Timedelta(days=30) date_range = pd.date_range(start=start_date, end=end_date, freq='D') # Filter data for the last 30 days last_30_days_data = data[data['Date'].isin(date_range)] # Prepare test data for predictions test_data = last_30_days_data.copy() # Generate future predictions future_prices, _, _, _, _ = predict_future_prices(data, model, days) if future_prices is not None: # Generate future dates future_dates = pd.date_range(start=end_date + pd.Timedelta(days=1), periods=days, freq='D') # Create a DataFrame for future predictions future_df = pd.DataFrame({ 'Date': future_dates, 'Predicted_Close': future_prices }) # Matplotlib Plot plt.figure(figsize=(14, 7)) plt.plot(data['Date'], data['Close'], label='Close Price', color='blue') plt.plot(future_df['Date'], future_df['Predicted_Close'], label='Predicted Price', color='orange', linestyle='--') for name, values in indicators.items(): plt.plot(data['Date'], values, label=name) plt.title('Stock Price with Technical Indicators and Predictions') plt.xlabel('Date') plt.ylabel('Price') plt.legend() plt.grid(True) plt.tight_layout() plt.xticks(rotation=45) # Render the plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.line(data, x='Date', y='Close', title='Stock Price with Technical Indicators and Predictions') fig.add_scatter(x=future_df['Date'], y=future_df['Predicted_Close'], mode='lines', name='Predicted Price', line=dict(color='orange', dash='dash')) for name, values in indicators.items(): fig.add_scatter(x=data['Date'], y=values, mode='lines', name=name) # Render the interactive plot using Streamlit st.plotly_chart(fig) else: st.error("No predictions available.") def plot_risk_levels(data: pd.DataFrame, risk_levels: pd.Series, cmap='coolwarm'): """ Plot risk levels with stock prices and customization. """ logger.info("Plotting stock prices with risk levels.") plt.figure(figsize=(14, 7)) plt.plot(data['Date'], data['Close'], label='Close Price', color='blue') plt.scatter(data['Date'], data['Close'], c=risk_levels, cmap=cmap, label='Risk Levels', alpha=0.7) plt.title('Stock Prices with Risk Levels') plt.xlabel('Date') plt.ylabel('Price') plt.colorbar(label='Risk Level') plt.legend() plt.grid(True) plt.tight_layout() plt.xticks(rotation=45) # Render Matplotlib plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.scatter(data, x='Date', y='Close', color=risk_levels, color_continuous_scale=cmap, title='Stock Prices with Risk Levels', labels={'color': 'Risk Level'}) # Render the interactive Plotly plot using Streamlit st.plotly_chart(fig) def plot_feature_importance(importances: pd.Series, feature_names: list): """ Plot feature importance for machine learning models. """ logger.info("Plotting feature importance.") plt.figure(figsize=(10, 6)) sns.barplot(x=importances, y=feature_names, palette='viridis') plt.title('Feature Importances') plt.xlabel('Importance') plt.ylabel('Feature') plt.grid(True) plt.tight_layout() # Render Matplotlib plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.bar(x=importances, y=feature_names, orientation='h', title='Feature Importances', labels={'x': 'Importance', 'y': 'Feature'}) fig.update_layout(yaxis={'categoryorder':'total ascending'}) # Render the interactive Plotly plot using Streamlit st.plotly_chart(fig) def plot_candlestick(data: pd.DataFrame, ticker: str): """ Plot candlestick chart for stock prices. """ required_columns = ['Date', 'Open', 'High', 'Low', 'Close'] # Check if all required columns are present missing_columns = [col for col in required_columns if col not in data.columns] if missing_columns: logger.error(f"Missing columns in data for plot_candlestick: {', '.join(missing_columns)}") raise KeyError(f"Missing columns in data: {', '.join(missing_columns)}") logger.info(f"Plotting candlestick chart for {ticker}.") data = data[required_columns] data['Date'] = pd.to_datetime(data['Date']) data['Date'] = mdates.date2num(data['Date']) fig, ax = plt.subplots(figsize=(14, 7)) candlestick_ohlc(ax, data.values, width=0.6, colorup='green', colordown='red') ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) plt.title(f'{ticker} Candlestick Chart') plt.xlabel('Date') plt.ylabel('Price') plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() # Render Matplotlib plot using Streamlit st.pyplot(fig) # Plotly Plot (interactive) fig = px.line(data, x='Date', y=['Open', 'High', 'Low', 'Close'], title=f'{ticker} Candlestick Chart') # Render the interactive Plotly plot using Streamlit st.plotly_chart(fig) def plot_volume(data: pd.DataFrame): """ Plot trading volume alongside stock price. """ logger.info("Plotting stock price and trading volume.") plt.figure(figsize=(14, 7)) plt.subplot(2, 1, 1) plt.plot(data['Date'], data['Close'], label='Close Price', color='blue') plt.title('Stock Price and Trading Volume') plt.xlabel('Date') plt.ylabel('Price') plt.legend() plt.grid(True) plt.subplot(2, 1, 2) plt.bar(data['Date'], data['Volume'], color='grey', alpha=0.5) plt.xlabel('Date') plt.ylabel('Volume') plt.tight_layout() plt.xticks(rotation=45) # Render Matplotlib plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.bar(data, x='Date', y='Volume', title='Trading Volume', labels={'Volume': 'Volume', 'Date': 'Date'}) # Render the interactive Plotly plot using Streamlit st.plotly_chart(fig) def plot_moving_averages(data: pd.DataFrame, short_window: int = 20, long_window: int = 50): """ Plot moving averages along with the stock price. """ logger.info("Calculating and plotting moving averages.") data['Short_MA'] = data['Close'].rolling(window=short_window).mean() data['Long_MA'] = data['Close'].rolling(window=long_window).mean() plt.figure(figsize=(14, 7)) plt.plot(data['Date'], data['Close'], label='Close Price', color='blue') plt.plot(data['Date'], data['Short_MA'], label=f'Short {short_window}-day MA', color='orange') plt.plot(data['Date'], data['Long_MA'], label=f'Long {long_window}-day MA', color='purple') plt.title('Stock Price with Moving Averages') plt.xlabel('Date') plt.ylabel('Price') plt.legend() plt.grid(True) plt.tight_layout() plt.xticks(rotation=45) # Render Matplotlib plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.line(data, x='Date', y=['Close', 'Short_MA', 'Long_MA'], title='Stock Price with Moving Averages') # Render the interactive Plotly plot using Streamlit st.plotly_chart(fig) def plot_feature_correlations(data: pd.DataFrame): """ Plot correlation heatmap of features. """ logger.info("Plotting feature correlations heatmap.") plt.figure(figsize=(12, 10)) correlation_matrix = data.corr() sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f') plt.title('Feature Correlations') plt.tight_layout() # Render Matplotlib plot using Streamlit st.pyplot(plt) # Plotly Plot (interactive) fig = px.imshow(correlation_matrix, text_auto=True, title='Feature Correlations', labels={'color': 'Correlation'}) # Render the interactive Plotly plot using Streamlit st.plotly_chart(fig)