danielle2003 commited on
Commit
d396b0d
·
verified ·
1 Parent(s): a0ce4fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -54
app.py CHANGED
@@ -110,55 +110,41 @@ st.markdown("""
110
  """, unsafe_allow_html=True)
111
 
112
  # --- Python Backend Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  @st.cache_resource(ttl=3600)
115
- def load_pytorch_model(path, model_type='Bi-Directional LSTM', input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
116
- import torch.nn as nn
117
- import torch
118
-
119
- class GRUModel(nn.Module):
120
- def __init__(self):
121
- super(GRUModel, self).__init__()
122
- self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
123
- self.fc = nn.Linear(hidden_dim, output_dim)
124
-
125
- def forward(self, x):
126
- h0 = torch.zeros(num_layers, x.size(0), hidden_dim).to(x.device)
127
- out, _ = self.gru(x, h0)
128
- return self.fc(out[:, -1, :])
129
-
130
- class BiLSTMModel(nn.Module):
131
- def __init__(self):
132
- super(BiLSTMModel, self).__init__()
133
- self.lstm = nn.LSTM(
134
- input_size=1,
135
- hidden_size=100,
136
- num_layers=1, # <- match saved model
137
- batch_first=True,
138
- dropout=0.2,
139
- bidirectional=True
140
- )
141
- self.fc = nn.Linear(200, 1) # 2 * hidden_size because of bidirectional
142
-
143
- def forward(self, x):
144
- h0 = torch.zeros(2 * 1, x.size(0), 100)
145
- c0 = torch.zeros(2 * 1, x.size(0), 100)
146
- out, _ = self.lstm(x, (h0, c0))
147
- return self.fc(out[:, -1, :])
148
- model_class = BiLSTMModel if model_type == 'Bi-Directional LSTM' else GRUModel
149
- model = model_class()
150
-
151
- checkpoint = torch.load(path, map_location=torch.device('cpu'))
152
-
153
- # If full checkpoint was saved with keys like 'model_state_dict'
154
- if 'model_state_dict' in checkpoint:
155
- model.load_state_dict(checkpoint['model_state_dict'])
156
- else:
157
- model.load_state_dict(checkpoint) # Just raw state_dict
158
-
159
  model.eval()
160
  return model
161
-
162
  @st.cache_data(ttl=900) # Cache data for 15 minutes
163
  def get_stock_data(ticker):
164
  """Fetches historical stock data from Yahoo Finance for the last 4 years."""
@@ -175,14 +161,8 @@ def get_stock_data(ticker):
175
  print(f"Successfully fetched {len(data)} rows for {ticker}")
176
  return data
177
 
178
- def predict_with_model(data: pd.DataFrame, n_days: int, model_path: str, model_type: str) -> pd.DataFrame:
179
- import torch
180
 
181
- try:
182
- model = load_pytorch_model(model_path, model_type=model_type)
183
- except FileNotFoundError as e:
184
- raise e
185
- print("model:",model)
186
  close_prices = data['Close'].values.reshape(-1, 1)
187
  scaler = MinMaxScaler(feature_range=(0, 1))
188
  scaled_prices = scaler.fit_transform(close_prices)
@@ -253,9 +233,15 @@ with st.sidebar:
253
  st.session_state.last_ticker = ticker
254
  st.session_state.last_model_type = model_type
255
  st.session_state.last_prediction_days = prediction_days
256
- st.rerun()
257
  # --- Main Application Logic ---
258
  if st.session_state.run_button_clicked:
 
 
 
 
 
 
 
259
  print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
260
  try:
261
  if os.path.exists("AMZN_data.csv"):
@@ -268,7 +254,7 @@ if st.session_state.run_button_clicked:
268
  st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
269
  else:
270
  model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
271
- st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days, model_path,model_type)
272
  st.session_state.error = None
273
 
274
  except FileNotFoundError as e:
 
110
  """, unsafe_allow_html=True)
111
 
112
  # --- Python Backend Functions ---
113
+ # Outside of any function
114
+ import torch.nn as nn
115
+ import torch
116
+
117
+ class GRUModel(nn.Module):
118
+ def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
119
+ super(GRUModel, self).__init__()
120
+ self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
121
+ self.fc = nn.Linear(hidden_dim, output_dim)
122
+
123
+ def forward(self, x):
124
+ h0 = torch.zeros(2, x.size(0), 100).to(x.device)
125
+ out, _ = self.gru(x, h0)
126
+ return self.fc(out[:, -1, :])
127
+
128
+ class BiLSTMModel(nn.Module):
129
+ def __init__(self):
130
+ super(BiLSTMModel, self).__init__()
131
+ self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, dropout=0.2, bidirectional=True)
132
+ self.fc = nn.Linear(200, 1)
133
+
134
+ def forward(self, x):
135
+ h0 = torch.zeros(2, x.size(0), 100)
136
+ c0 = torch.zeros(2, x.size(0), 100)
137
+ out, _ = self.lstm(x, (h0, c0))
138
+ return self.fc(out[:, -1, :])
139
 
140
  @st.cache_resource(ttl=3600)
141
+ def load_model_from_disk(path, model_type):
142
+ model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
143
+ state = torch.load(path, map_location=torch.device("cpu"))
144
+ model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  model.eval()
146
  return model
147
+
148
  @st.cache_data(ttl=900) # Cache data for 15 minutes
149
  def get_stock_data(ticker):
150
  """Fetches historical stock data from Yahoo Finance for the last 4 years."""
 
161
  print(f"Successfully fetched {len(data)} rows for {ticker}")
162
  return data
163
 
164
+ def predict_with_model(data: pd.DataFrame, n_days: int, model, model_type: str) -> pd.DataFrame:
 
165
 
 
 
 
 
 
166
  close_prices = data['Close'].values.reshape(-1, 1)
167
  scaler = MinMaxScaler(feature_range=(0, 1))
168
  scaled_prices = scaler.fit_transform(close_prices)
 
233
  st.session_state.last_ticker = ticker
234
  st.session_state.last_model_type = model_type
235
  st.session_state.last_prediction_days = prediction_days
 
236
  # --- Main Application Logic ---
237
  if st.session_state.run_button_clicked:
238
+ model_key = "bilstm_model" if model_type == "Bi-Directional LSTM" else "gru_model"
239
+
240
+ if model_key not in st.session_state:
241
+ model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
242
+ st.session_state[model_key] = load_model_from_disk(model_path, model_type)
243
+
244
+ model = st.session_state[model_key]
245
  print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
246
  try:
247
  if os.path.exists("AMZN_data.csv"):
 
254
  st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
255
  else:
256
  model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
257
+ st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days, model, model_type)
258
  st.session_state.error = None
259
 
260
  except FileNotFoundError as e: