danielle2003 commited on
Commit
fade590
·
verified ·
1 Parent(s): 81315fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -126,24 +126,23 @@ class GRUModel(nn.Module):
126
  return self.fc(out[:, -1, :])
127
 
128
  class BiLSTMModel(nn.Module):
129
- def __init__(self, input_dim=1, hidden_dim=100, output_dim=1):
130
  super(BiLSTMModel, self).__init__()
131
  self.lstm = nn.LSTM(
132
  input_size=input_dim,
133
  hidden_size=hidden_dim,
134
- num_layers=1, # <-- match this to the training model
135
  batch_first=True,
136
- dropout=0.0, # <-- dropout does nothing with 1 layer, so remove
137
  bidirectional=True
138
  )
139
- self.fc = nn.Linear(hidden_dim * 2, output_dim)
140
 
141
  def forward(self, x):
142
- h0 = torch.zeros(2, x.size(0), 100).to(x.device)
143
- c0 = torch.zeros(2, x.size(0), 100).to(x.device)
144
  out, _ = self.lstm(x, (h0, c0))
145
  return self.fc(out[:, -1, :])
146
-
147
  @st.cache_resource(ttl=3600)
148
  def load_model_from_disk(path, model_type):
149
  model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
 
126
  return self.fc(out[:, -1, :])
127
 
128
  class BiLSTMModel(nn.Module):
129
+ def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
130
  super(BiLSTMModel, self).__init__()
131
  self.lstm = nn.LSTM(
132
  input_size=input_dim,
133
  hidden_size=hidden_dim,
134
+ num_layers=num_layers,
135
  batch_first=True,
136
+ dropout=dropout_prob,
137
  bidirectional=True
138
  )
139
+ self.fc = nn.Linear(hidden_dim * 2, output_dim) # because bidirectional
140
 
141
  def forward(self, x):
142
+ h0 = torch.zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size).to(x.device)
143
+ c0 = torch.zeros(self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size).to(x.device)
144
  out, _ = self.lstm(x, (h0, c0))
145
  return self.fc(out[:, -1, :])
 
146
  @st.cache_resource(ttl=3600)
147
  def load_model_from_disk(path, model_type):
148
  model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()