Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -110,55 +110,41 @@ st.markdown("""
|
|
110 |
""", unsafe_allow_html=True)
|
111 |
|
112 |
# --- Python Backend Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
@st.cache_resource(ttl=3600)
|
115 |
-
def
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
class GRUModel(nn.Module):
|
120 |
-
def __init__(self):
|
121 |
-
super(GRUModel, self).__init__()
|
122 |
-
self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
|
123 |
-
self.fc = nn.Linear(hidden_dim, output_dim)
|
124 |
-
|
125 |
-
def forward(self, x):
|
126 |
-
h0 = torch.zeros(num_layers, x.size(0), hidden_dim).to(x.device)
|
127 |
-
out, _ = self.gru(x, h0)
|
128 |
-
return self.fc(out[:, -1, :])
|
129 |
-
|
130 |
-
class BiLSTMModel(nn.Module):
|
131 |
-
def __init__(self):
|
132 |
-
super(BiLSTMModel, self).__init__()
|
133 |
-
self.lstm = nn.LSTM(
|
134 |
-
input_size=1,
|
135 |
-
hidden_size=100,
|
136 |
-
num_layers=1, # <- match saved model
|
137 |
-
batch_first=True,
|
138 |
-
dropout=0.2,
|
139 |
-
bidirectional=True
|
140 |
-
)
|
141 |
-
self.fc = nn.Linear(200, 1) # 2 * hidden_size because of bidirectional
|
142 |
-
|
143 |
-
def forward(self, x):
|
144 |
-
h0 = torch.zeros(2 * 1, x.size(0), 100)
|
145 |
-
c0 = torch.zeros(2 * 1, x.size(0), 100)
|
146 |
-
out, _ = self.lstm(x, (h0, c0))
|
147 |
-
return self.fc(out[:, -1, :])
|
148 |
-
model_class = BiLSTMModel if model_type == 'Bi-Directional LSTM' else GRUModel
|
149 |
-
model = model_class()
|
150 |
-
|
151 |
-
checkpoint = torch.load(path, map_location=torch.device('cpu'))
|
152 |
-
|
153 |
-
# If full checkpoint was saved with keys like 'model_state_dict'
|
154 |
-
if 'model_state_dict' in checkpoint:
|
155 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
156 |
-
else:
|
157 |
-
model.load_state_dict(checkpoint) # Just raw state_dict
|
158 |
-
|
159 |
model.eval()
|
160 |
return model
|
161 |
-
|
162 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
163 |
def get_stock_data(ticker):
|
164 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
@@ -175,14 +161,8 @@ def get_stock_data(ticker):
|
|
175 |
print(f"Successfully fetched {len(data)} rows for {ticker}")
|
176 |
return data
|
177 |
|
178 |
-
def predict_with_model(data: pd.DataFrame, n_days: int,
|
179 |
-
import torch
|
180 |
|
181 |
-
try:
|
182 |
-
model = load_pytorch_model(model_path, model_type=model_type)
|
183 |
-
except FileNotFoundError as e:
|
184 |
-
raise e
|
185 |
-
print("model:",model)
|
186 |
close_prices = data['Close'].values.reshape(-1, 1)
|
187 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
188 |
scaled_prices = scaler.fit_transform(close_prices)
|
@@ -253,9 +233,15 @@ with st.sidebar:
|
|
253 |
st.session_state.last_ticker = ticker
|
254 |
st.session_state.last_model_type = model_type
|
255 |
st.session_state.last_prediction_days = prediction_days
|
256 |
-
st.rerun()
|
257 |
# --- Main Application Logic ---
|
258 |
if st.session_state.run_button_clicked:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
260 |
try:
|
261 |
if os.path.exists("AMZN_data.csv"):
|
@@ -268,7 +254,7 @@ if st.session_state.run_button_clicked:
|
|
268 |
st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
|
269 |
else:
|
270 |
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
271 |
-
st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days,
|
272 |
st.session_state.error = None
|
273 |
|
274 |
except FileNotFoundError as e:
|
|
|
110 |
""", unsafe_allow_html=True)
|
111 |
|
112 |
# --- Python Backend Functions ---
|
113 |
+
# Outside of any function
|
114 |
+
import torch.nn as nn
|
115 |
+
import torch
|
116 |
+
|
117 |
+
class GRUModel(nn.Module):
|
118 |
+
def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
|
119 |
+
super(GRUModel, self).__init__()
|
120 |
+
self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
|
121 |
+
self.fc = nn.Linear(hidden_dim, output_dim)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
h0 = torch.zeros(2, x.size(0), 100).to(x.device)
|
125 |
+
out, _ = self.gru(x, h0)
|
126 |
+
return self.fc(out[:, -1, :])
|
127 |
+
|
128 |
+
class BiLSTMModel(nn.Module):
|
129 |
+
def __init__(self):
|
130 |
+
super(BiLSTMModel, self).__init__()
|
131 |
+
self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, dropout=0.2, bidirectional=True)
|
132 |
+
self.fc = nn.Linear(200, 1)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
h0 = torch.zeros(2, x.size(0), 100)
|
136 |
+
c0 = torch.zeros(2, x.size(0), 100)
|
137 |
+
out, _ = self.lstm(x, (h0, c0))
|
138 |
+
return self.fc(out[:, -1, :])
|
139 |
|
140 |
@st.cache_resource(ttl=3600)
|
141 |
+
def load_model_from_disk(path, model_type):
|
142 |
+
model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
|
143 |
+
state = torch.load(path, map_location=torch.device("cpu"))
|
144 |
+
model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
model.eval()
|
146 |
return model
|
147 |
+
|
148 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
149 |
def get_stock_data(ticker):
|
150 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
|
|
161 |
print(f"Successfully fetched {len(data)} rows for {ticker}")
|
162 |
return data
|
163 |
|
164 |
+
def predict_with_model(data: pd.DataFrame, n_days: int, model, model_type: str) -> pd.DataFrame:
|
|
|
165 |
|
|
|
|
|
|
|
|
|
|
|
166 |
close_prices = data['Close'].values.reshape(-1, 1)
|
167 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
168 |
scaled_prices = scaler.fit_transform(close_prices)
|
|
|
233 |
st.session_state.last_ticker = ticker
|
234 |
st.session_state.last_model_type = model_type
|
235 |
st.session_state.last_prediction_days = prediction_days
|
|
|
236 |
# --- Main Application Logic ---
|
237 |
if st.session_state.run_button_clicked:
|
238 |
+
model_key = "bilstm_model" if model_type == "Bi-Directional LSTM" else "gru_model"
|
239 |
+
|
240 |
+
if model_key not in st.session_state:
|
241 |
+
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
242 |
+
st.session_state[model_key] = load_model_from_disk(model_path, model_type)
|
243 |
+
|
244 |
+
model = st.session_state[model_key]
|
245 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
246 |
try:
|
247 |
if os.path.exists("AMZN_data.csv"):
|
|
|
254 |
st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
|
255 |
else:
|
256 |
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
257 |
+
st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days, model, model_type)
|
258 |
st.session_state.error = None
|
259 |
|
260 |
except FileNotFoundError as e:
|