danielle2003 commited on
Commit
f7ff14c
·
verified ·
1 Parent(s): 3971804

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -144,11 +144,15 @@ def load_model_from_disk(path, model_type):
144
  model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state)
145
  model.eval()
146
  return model
147
-
 
 
 
 
148
  @st.cache_resource
149
  def preload_models():
150
  return {
151
- "Bi-Directional LSTM": load_model_from_disk("best_bilstm_model.pth", model_type="Bi-Directional LSTM"),
152
  "Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", model_type="GRU")
153
  }
154
 
 
144
  model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state)
145
  model.eval()
146
  return model
147
+ @st.cache_resource(ttl=3600)
148
+ def load_scripted_model(path):
149
+ model = torch.jit.load(path, map_location=torch.device("cpu"))
150
+ model.eval()
151
+ return model
152
  @st.cache_resource
153
  def preload_models():
154
  return {
155
+ "Bi-Directional LSTM": load_scripted_model("bilstm_scripted.pt"),
156
  "Gated Recurrent Unit (GRU)": load_model_from_disk("best_gru_model.pth", model_type="GRU")
157
  }
158