Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -126,17 +126,23 @@ class GRUModel(nn.Module):
|
|
126 |
return self.fc(out[:, -1, :])
|
127 |
|
128 |
class BiLSTMModel(nn.Module):
|
129 |
-
def __init__(self):
|
130 |
super(BiLSTMModel, self).__init__()
|
131 |
-
self.lstm = nn.LSTM(
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
def forward(self, x):
|
135 |
-
h0 = torch.zeros(2, x.size(0),
|
136 |
-
c0 = torch.zeros(2, x.size(0),
|
137 |
out, _ = self.lstm(x, (h0, c0))
|
138 |
return self.fc(out[:, -1, :])
|
139 |
-
|
140 |
@st.cache_resource(ttl=3600)
|
141 |
def load_model_from_disk(path, model_type):
|
142 |
model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
|
@@ -145,13 +151,14 @@ def load_model_from_disk(path, model_type):
|
|
145 |
model.eval()
|
146 |
return model
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
"Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", "Bi-Directional LSTM"),
|
152 |
-
"Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", "GRU")
|
153 |
}
|
154 |
|
|
|
155 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
156 |
def get_stock_data(ticker):
|
157 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
@@ -243,7 +250,7 @@ with st.sidebar:
|
|
243 |
st.session_state.last_prediction_days = prediction_days
|
244 |
# --- Main Application Logic ---
|
245 |
if st.session_state.run_button_clicked:
|
246 |
-
model =
|
247 |
|
248 |
|
249 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
|
|
126 |
return self.fc(out[:, -1, :])
|
127 |
|
128 |
class BiLSTMModel(nn.Module):
|
129 |
+
def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
|
130 |
super(BiLSTMModel, self).__init__()
|
131 |
+
self.lstm = nn.LSTM(
|
132 |
+
input_size=input_dim,
|
133 |
+
hidden_size=hidden_dim,
|
134 |
+
num_layers=num_layers,
|
135 |
+
batch_first=True,
|
136 |
+
dropout=dropout_prob,
|
137 |
+
bidirectional=True
|
138 |
+
)
|
139 |
+
self.fc = nn.Linear(hidden_dim * 2, output_dim) # because bidirectional
|
140 |
|
141 |
def forward(self, x):
|
142 |
+
h0 = torch.zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size).to(x.device)
|
143 |
+
c0 = torch.zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size).to(x.device)
|
144 |
out, _ = self.lstm(x, (h0, c0))
|
145 |
return self.fc(out[:, -1, :])
|
|
|
146 |
@st.cache_resource(ttl=3600)
|
147 |
def load_model_from_disk(path, model_type):
|
148 |
model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
|
|
|
151 |
model.eval()
|
152 |
return model
|
153 |
|
154 |
+
@st.cache_resource
|
155 |
+
def preload_models():
|
156 |
+
return {
|
157 |
+
"Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", model_type="Bi-Directional LSTM"),
|
158 |
+
"Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", model_type="GRU")
|
159 |
}
|
160 |
|
161 |
+
MODELS = preload_models()
|
162 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
163 |
def get_stock_data(ticker):
|
164 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
|
|
250 |
st.session_state.last_prediction_days = prediction_days
|
251 |
# --- Main Application Logic ---
|
252 |
if st.session_state.run_button_clicked:
|
253 |
+
model = MODELS[model_type]
|
254 |
|
255 |
|
256 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|