import streamlit as st import streamlit.components.v1 as components import yfinance as yf import pandas as pd import numpy as np from datetime import datetime, timedelta from keras.models import load_model from sklearn.preprocessing import MinMaxScaler import time import os import torch.nn as nn import torch # --- Page Configuration --- st.set_page_config(layout="wide") # --- Streamlit Session State Initialization --- if 'run_button_clicked' not in st.session_state: st.session_state.run_button_clicked = False if 'loading' not in st.session_state: st.session_state.loading = False if 'data' not in st.session_state: st.session_state.data = None if 'predictions' not in st.session_state: st.session_state.predictions = None if 'error' not in st.session_state: st.session_state.error = None if 'last_ticker' not in st.session_state: st.session_state['last_ticker'] = 'AMZN' # --- Custom CSS --- st.markdown(""" """, unsafe_allow_html=True) # --- Python Backend Functions --- # Outside of any function class GRUModel(nn.Module): def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2): super(GRUModel, self).__init__() self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob) self.fc = nn.Linear(hidden_dim, output_dim) def forward(self, x): h0 = torch.zeros(2, x.size(0), 100).to(x.device) out, _ = self.gru(x, h0) return self.fc(out[:, -1, :]) class BiLSTMModel(nn.Module): def __init__(self): super(BiLSTMModel, self).__init__() self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, dropout=0.2, bidirectional=True) self.fc = nn.Linear(200, 1) def forward(self, x): h0 = torch.zeros(2, x.size(0), 100) c0 = torch.zeros(2, x.size(0), 100) out, _ = self.lstm(x, (h0, c0)) return self.fc(out[:, -1, :]) @st.cache_resource(ttl=3600) def load_model_from_disk(path, model_type): model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel() state = torch.load(path, map_location=torch.device("cpu")) model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state) model.eval() return model @st.cache_resource(ttl=3600) def load_scripted_model(path): model = torch.jit.load(path, map_location=torch.device("cpu")) model.eval() return model @st.cache_resource def preload_models(): return { "Bi-Directional LSTM": load_scripted_model("bilstm_scriptes.pt"), "Gated Recurrent Unit": load_model_from_disk("best_gru_model.pth", model_type="GRU") } MODELS = preload_models() @st.cache_data(ttl=900) # Cache data for 15 minutes def get_stock_data(ticker): """Fetches historical stock data from Yahoo Finance for the last 4 years.""" end_date = datetime.now() start_date = end_date - timedelta(days=4 * 365) print(f"Fetching data for ticker: {ticker} from {start_date.date()} to {end_date.date()}") data = yf.download(ticker, period="4y", multi_level_index=False) data.to_csv("AMZN_data.csv") if data.empty: print(f"No data found for ticker: {ticker}") return None data.reset_index(inplace=True) print(f"Successfully fetched {len(data)} rows for {ticker}") return data def predict_with_model(data, n_days, model_path, model_type, model=None)-> pd.DataFrame: if model is None: model = load_model_from_disk(model_path, model_type=model_type) close_prices = data['Close'].values.reshape(-1, 1) scaler = MinMaxScaler(feature_range=(0, 1)) scaled_prices = scaler.fit_transform(close_prices) sequence_length = 90 if len(scaled_prices) < sequence_length: raise ValueError(f"Not enough historical data ({len(scaled_prices)} points) to create a sequence of {sequence_length} for prediction.") last_sequence = scaled_prices[-sequence_length:] current_seq = torch.tensor(last_sequence.reshape(1, sequence_length, 1), dtype=torch.float32) predictions_scaled = [] with torch.no_grad(): for _ in range(n_days): pred = model(current_seq) predictions_scaled.append(pred.item()) next_input = pred.view(1, 1, 1) current_seq = torch.cat((current_seq[:, 1:, :], next_input), dim=1) predictions = scaler.inverse_transform(np.array(predictions_scaled).reshape(-1, 1)).flatten() print("predictions",predictions) last_date = pd.to_datetime(data['Date'].iloc[-1]) future_dates = [last_date + timedelta(days=i) for i in range(1, n_days + 1)] prediction_df = pd.DataFrame({'Date': future_dates, 'Predicted Price': predictions}) historical_returns = data['Close'].pct_change().dropna() volatility = historical_returns.std() if not historical_returns.empty else 0.01 error_std_growth = volatility * np.sqrt(np.arange(1, n_days + 1)) prediction_df['Upper CI'] = predictions * (1 + 1.96 * error_std_growth) prediction_df['Lower CI'] = predictions * (1 - 1.96 * error_std_growth) return prediction_df # --- Streamlit Sidebar Controls --- with st.sidebar: st.image("logo2.png", use_container_width=True) st.markdown("Dashboard Controls") ticker = st.text_input("Stock Ticker", st.session_state.get('last_ticker', "AMZN"), disabled=True).upper() model_type = st.selectbox( "Prediction Model", ("Bi-Directional LSTM", "Gated Recurrent Unit"), key="model_choice", help="Select the neural network architecture for prediction." ) prediction_days = st.slider("Prediction Horizon (Days)", 7, 21, st.session_state.get('last_prediction_days', 7)) if st.button("21 days ahead of the market", use_container_width=True, disabled=True): st.session_state.run_button_clicked = True st.session_state.loading = True st.session_state.last_ticker = ticker st.session_state.last_prediction_days = prediction_days st.session_state.error = None print("Generate Dashboard button clicked. Loading state set to True.") st.rerun() # Check if model or prediction days have changed if ( ticker != st.session_state.get('last_ticker', '') or model_type != st.session_state.get('last_model_type', '') or prediction_days != st.session_state.get('last_prediction_days', 7) ): st.session_state.run_button_clicked = True st.session_state.loading = True st.session_state.last_ticker = ticker st.session_state.last_model_type = model_type st.session_state.last_prediction_days = prediction_days # --- Main Application Logic --- if st.session_state.run_button_clicked: model = MODELS[model_type] print(f"Inside main logic block. Current loading state: {st.session_state.loading}") try: if os.path.exists("AMZN_data.csv"): st.session_state.data = pd.read_csv("AMZN_data.csv") else: st.session_state.data = get_stock_data(ticker) if st.session_state.data is None: st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue." else: model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth" st.session_state.predictions = predict_with_model( st.session_state.data, prediction_days, model_path=None, model_type=model_type, model=model ) print("model",model) print("data", st.session_state.data) st.session_state.error = None except FileNotFoundError as e: st.session_state.error = str(e) print(f"Caught FileNotFoundError: {e}") except ValueError as e: st.session_state.error = str(e) print(f"Caught ValueError: {e}") except Exception as e: st.session_state.error = f"An unexpected error occurred: {str(e)}" print(f"Caught general Exception: {e}") st.session_state.loading = False st.session_state.run_button_clicked = False print(f"Processing complete. Loading state set to False. Error: {st.session_state.error}") st.rerun() # --- Data Preparation for Front-End --- historical_data_json = 'null' prediction_data_json = 'null' is_loading_js = str(st.session_state.get('loading', False)).lower() error_message_js = 'null' if st.session_state.get('error'): error_message_js = f"'{st.session_state.error}'" # Pass error to JS if st.session_state.data is not None and st.session_state.get('error') is None: historical_data_json = st.session_state.data.to_json(orient='split', date_format='iso') prediction_data_json = st.session_state.predictions.to_json(orient='split', date_format='iso') # --- HTML Front-End --- html_code = f"""