danielle2003 commited on
Commit
3971804
·
verified ·
1 Parent(s): fff6ecb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -128,14 +128,15 @@ class GRUModel(nn.Module):
128
  class BiLSTMModel(nn.Module):
129
  def __init__(self):
130
  super(BiLSTMModel, self).__init__()
131
- self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
132
  self.fc = nn.Linear(200, 1)
133
 
134
  def forward(self, x):
135
- h0 = torch.zeros(4, x.size(0), 100) # 2 directions × 2 layers = 4
136
- c0 = torch.zeros(4, x.size(0), 100)
137
  out, _ = self.lstm(x, (h0, c0))
138
  return self.fc(out[:, -1, :])
 
139
  @st.cache_resource(ttl=3600)
140
  def load_model_from_disk(path, model_type):
141
  model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
 
128
  class BiLSTMModel(nn.Module):
129
  def __init__(self):
130
  super(BiLSTMModel, self).__init__()
131
+ self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, dropout=0.2, bidirectional=True)
132
  self.fc = nn.Linear(200, 1)
133
 
134
  def forward(self, x):
135
+ h0 = torch.zeros(2, x.size(0), 100)
136
+ c0 = torch.zeros(2, x.size(0), 100)
137
  out, _ = self.lstm(x, (h0, c0))
138
  return self.fc(out[:, -1, :])
139
+
140
  @st.cache_resource(ttl=3600)
141
  def load_model_from_disk(path, model_type):
142
  model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()