danielle2003 commited on
Commit
6537369
·
verified ·
1 Parent(s): d396b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -8,6 +8,8 @@ from keras.models import load_model
8
  from sklearn.preprocessing import MinMaxScaler
9
  import time
10
  import os
 
 
11
 
12
  # --- Page Configuration ---
13
  st.set_page_config(layout="wide")
@@ -111,8 +113,6 @@ st.markdown("""
111
 
112
  # --- Python Backend Functions ---
113
  # Outside of any function
114
- import torch.nn as nn
115
- import torch
116
 
117
  class GRUModel(nn.Module):
118
  def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
@@ -145,6 +145,13 @@ def load_model_from_disk(path, model_type):
145
  model.eval()
146
  return model
147
 
 
 
 
 
 
 
 
148
  @st.cache_data(ttl=900) # Cache data for 15 minutes
149
  def get_stock_data(ticker):
150
  """Fetches historical stock data from Yahoo Finance for the last 4 years."""
@@ -161,8 +168,9 @@ def get_stock_data(ticker):
161
  print(f"Successfully fetched {len(data)} rows for {ticker}")
162
  return data
163
 
164
- def predict_with_model(data: pd.DataFrame, n_days: int, model, model_type: str) -> pd.DataFrame:
165
-
 
166
  close_prices = data['Close'].values.reshape(-1, 1)
167
  scaler = MinMaxScaler(feature_range=(0, 1))
168
  scaled_prices = scaler.fit_transform(close_prices)
@@ -235,13 +243,9 @@ with st.sidebar:
235
  st.session_state.last_prediction_days = prediction_days
236
  # --- Main Application Logic ---
237
  if st.session_state.run_button_clicked:
238
- model_key = "bilstm_model" if model_type == "Bi-Directional LSTM" else "gru_model"
239
-
240
- if model_key not in st.session_state:
241
- model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
242
- st.session_state[model_key] = load_model_from_disk(model_path, model_type)
243
-
244
- model = st.session_state[model_key]
245
  print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
246
  try:
247
  if os.path.exists("AMZN_data.csv"):
@@ -254,7 +258,11 @@ if st.session_state.run_button_clicked:
254
  st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
255
  else:
256
  model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
257
- st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days, model, model_type)
 
 
 
 
258
  st.session_state.error = None
259
 
260
  except FileNotFoundError as e:
 
8
  from sklearn.preprocessing import MinMaxScaler
9
  import time
10
  import os
11
+ import torch.nn as nn
12
+ import torch
13
 
14
  # --- Page Configuration ---
15
  st.set_page_config(layout="wide")
 
113
 
114
  # --- Python Backend Functions ---
115
  # Outside of any function
 
 
116
 
117
  class GRUModel(nn.Module):
118
  def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
 
145
  model.eval()
146
  return model
147
 
148
+
149
+ if "models" not in st.session_state:
150
+ st.session_state.models = {
151
+ "Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", "Bi-Directional LSTM"),
152
+ "Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", "GRU")
153
+ }
154
+
155
  @st.cache_data(ttl=900) # Cache data for 15 minutes
156
  def get_stock_data(ticker):
157
  """Fetches historical stock data from Yahoo Finance for the last 4 years."""
 
168
  print(f"Successfully fetched {len(data)} rows for {ticker}")
169
  return data
170
 
171
+ def predict_with_model(data, n_days, model_path, model_type, model=None)-> pd.DataFrame:
172
+ if model is None:
173
+ model = load_model_from_disk(model_path, model_type=model_type)
174
  close_prices = data['Close'].values.reshape(-1, 1)
175
  scaler = MinMaxScaler(feature_range=(0, 1))
176
  scaled_prices = scaler.fit_transform(close_prices)
 
243
  st.session_state.last_prediction_days = prediction_days
244
  # --- Main Application Logic ---
245
  if st.session_state.run_button_clicked:
246
+ model = st.session_state.models[model_type]
247
+
248
+
 
 
 
 
249
  print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
250
  try:
251
  if os.path.exists("AMZN_data.csv"):
 
258
  st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
259
  else:
260
  model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
261
+ st.session_state.predictions = predict_with_model(
262
+ st.session_state.data, prediction_days, model_path=None, model_type=model_type, model=model
263
+ )
264
+ print("model",model)
265
+ print("data", st.session_state.data)
266
  st.session_state.error = None
267
 
268
  except FileNotFoundError as e: