Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,8 @@ from keras.models import load_model
|
|
| 8 |
from sklearn.preprocessing import MinMaxScaler
|
| 9 |
import time
|
| 10 |
import os
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# --- Page Configuration ---
|
| 13 |
st.set_page_config(layout="wide")
|
|
@@ -111,8 +113,6 @@ st.markdown("""
|
|
| 111 |
|
| 112 |
# --- Python Backend Functions ---
|
| 113 |
# Outside of any function
|
| 114 |
-
import torch.nn as nn
|
| 115 |
-
import torch
|
| 116 |
|
| 117 |
class GRUModel(nn.Module):
|
| 118 |
def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
|
|
@@ -145,6 +145,13 @@ def load_model_from_disk(path, model_type):
|
|
| 145 |
model.eval()
|
| 146 |
return model
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
| 149 |
def get_stock_data(ticker):
|
| 150 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
|
@@ -161,8 +168,9 @@ def get_stock_data(ticker):
|
|
| 161 |
print(f"Successfully fetched {len(data)} rows for {ticker}")
|
| 162 |
return data
|
| 163 |
|
| 164 |
-
def predict_with_model(data
|
| 165 |
-
|
|
|
|
| 166 |
close_prices = data['Close'].values.reshape(-1, 1)
|
| 167 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 168 |
scaled_prices = scaler.fit_transform(close_prices)
|
|
@@ -235,13 +243,9 @@ with st.sidebar:
|
|
| 235 |
st.session_state.last_prediction_days = prediction_days
|
| 236 |
# --- Main Application Logic ---
|
| 237 |
if st.session_state.run_button_clicked:
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
| 242 |
-
st.session_state[model_key] = load_model_from_disk(model_path, model_type)
|
| 243 |
-
|
| 244 |
-
model = st.session_state[model_key]
|
| 245 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
| 246 |
try:
|
| 247 |
if os.path.exists("AMZN_data.csv"):
|
|
@@ -254,7 +258,11 @@ if st.session_state.run_button_clicked:
|
|
| 254 |
st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
|
| 255 |
else:
|
| 256 |
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
| 257 |
-
st.session_state.predictions = predict_with_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
st.session_state.error = None
|
| 259 |
|
| 260 |
except FileNotFoundError as e:
|
|
|
|
| 8 |
from sklearn.preprocessing import MinMaxScaler
|
| 9 |
import time
|
| 10 |
import os
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch
|
| 13 |
|
| 14 |
# --- Page Configuration ---
|
| 15 |
st.set_page_config(layout="wide")
|
|
|
|
| 113 |
|
| 114 |
# --- Python Backend Functions ---
|
| 115 |
# Outside of any function
|
|
|
|
|
|
|
| 116 |
|
| 117 |
class GRUModel(nn.Module):
|
| 118 |
def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
|
|
|
|
| 145 |
model.eval()
|
| 146 |
return model
|
| 147 |
|
| 148 |
+
|
| 149 |
+
if "models" not in st.session_state:
|
| 150 |
+
st.session_state.models = {
|
| 151 |
+
"Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", "Bi-Directional LSTM"),
|
| 152 |
+
"Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", "GRU")
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
| 156 |
def get_stock_data(ticker):
|
| 157 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
|
|
|
| 168 |
print(f"Successfully fetched {len(data)} rows for {ticker}")
|
| 169 |
return data
|
| 170 |
|
| 171 |
+
def predict_with_model(data, n_days, model_path, model_type, model=None)-> pd.DataFrame:
|
| 172 |
+
if model is None:
|
| 173 |
+
model = load_model_from_disk(model_path, model_type=model_type)
|
| 174 |
close_prices = data['Close'].values.reshape(-1, 1)
|
| 175 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 176 |
scaled_prices = scaler.fit_transform(close_prices)
|
|
|
|
| 243 |
st.session_state.last_prediction_days = prediction_days
|
| 244 |
# --- Main Application Logic ---
|
| 245 |
if st.session_state.run_button_clicked:
|
| 246 |
+
model = st.session_state.models[model_type]
|
| 247 |
+
|
| 248 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
| 250 |
try:
|
| 251 |
if os.path.exists("AMZN_data.csv"):
|
|
|
|
| 258 |
st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
|
| 259 |
else:
|
| 260 |
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
| 261 |
+
st.session_state.predictions = predict_with_model(
|
| 262 |
+
st.session_state.data, prediction_days, model_path=None, model_type=model_type, model=model
|
| 263 |
+
)
|
| 264 |
+
print("model",model)
|
| 265 |
+
print("data", st.session_state.data)
|
| 266 |
st.session_state.error = None
|
| 267 |
|
| 268 |
except FileNotFoundError as e:
|