danielle2003 commited on
Commit
38b16ca
·
verified ·
1 Parent(s): 6537369

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -126,17 +126,23 @@ class GRUModel(nn.Module):
126
  return self.fc(out[:, -1, :])
127
 
128
  class BiLSTMModel(nn.Module):
129
- def __init__(self):
130
  super(BiLSTMModel, self).__init__()
131
- self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, dropout=0.2, bidirectional=True)
132
- self.fc = nn.Linear(200, 1)
 
 
 
 
 
 
 
133
 
134
  def forward(self, x):
135
- h0 = torch.zeros(2, x.size(0), 100)
136
- c0 = torch.zeros(2, x.size(0), 100)
137
  out, _ = self.lstm(x, (h0, c0))
138
  return self.fc(out[:, -1, :])
139
-
140
  @st.cache_resource(ttl=3600)
141
  def load_model_from_disk(path, model_type):
142
  model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
@@ -145,13 +151,14 @@ def load_model_from_disk(path, model_type):
145
  model.eval()
146
  return model
147
 
148
-
149
- if "models" not in st.session_state:
150
- st.session_state.models = {
151
- "Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", "Bi-Directional LSTM"),
152
- "Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", "GRU")
153
  }
154
 
 
155
  @st.cache_data(ttl=900) # Cache data for 15 minutes
156
  def get_stock_data(ticker):
157
  """Fetches historical stock data from Yahoo Finance for the last 4 years."""
@@ -243,7 +250,7 @@ with st.sidebar:
243
  st.session_state.last_prediction_days = prediction_days
244
  # --- Main Application Logic ---
245
  if st.session_state.run_button_clicked:
246
- model = st.session_state.models[model_type]
247
 
248
 
249
  print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
 
126
  return self.fc(out[:, -1, :])
127
 
128
  class BiLSTMModel(nn.Module):
129
+ def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
130
  super(BiLSTMModel, self).__init__()
131
+ self.lstm = nn.LSTM(
132
+ input_size=input_dim,
133
+ hidden_size=hidden_dim,
134
+ num_layers=num_layers,
135
+ batch_first=True,
136
+ dropout=dropout_prob,
137
+ bidirectional=True
138
+ )
139
+ self.fc = nn.Linear(hidden_dim * 2, output_dim) # because bidirectional
140
 
141
  def forward(self, x):
142
+ h0 = torch.zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size).to(x.device)
143
+ c0 = torch.zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size).to(x.device)
144
  out, _ = self.lstm(x, (h0, c0))
145
  return self.fc(out[:, -1, :])
 
146
  @st.cache_resource(ttl=3600)
147
  def load_model_from_disk(path, model_type):
148
  model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
 
151
  model.eval()
152
  return model
153
 
154
+ @st.cache_resource
155
+ def preload_models():
156
+ return {
157
+ "Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", model_type="Bi-Directional LSTM"),
158
+ "Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", model_type="GRU")
159
  }
160
 
161
+ MODELS = preload_models()
162
  @st.cache_data(ttl=900) # Cache data for 15 minutes
163
  def get_stock_data(ticker):
164
  """Fetches historical stock data from Yahoo Finance for the last 4 years."""
 
250
  st.session_state.last_prediction_days = prediction_days
251
  # --- Main Application Logic ---
252
  if st.session_state.run_button_clicked:
253
+ model = MODELS[model_type]
254
 
255
 
256
  print(f"Inside main logic block. Current loading state: {st.session_state.loading}")