gelnesr commited on
Commit
6958775
·
verified ·
1 Parent(s): bbe790e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -141,6 +141,15 @@ def handle_name(name=None, pdb_input=None, model_version="ESM3"):
141
 
142
  @spaces.GPU(duration=300)
143
  def run_model(model, model_version='ESM2', seq_input=None, struct_input=None, sequence_id=None):
 
 
 
 
 
 
 
 
 
144
  if model_version == "ESM3":
145
  logits = model((seq_input, struct_input), sequence_id)
146
  else:
@@ -158,15 +167,6 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
158
 
159
  base_name = handle_name(name, pdb_input, model_version)
160
 
161
- if model_version == "ESM3":
162
- model = ESM_model(method='esm3')
163
- model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt'), strict=False)
164
- else:
165
- model = ESM_model(method='esm2', nheads=8, nlayers=12, layer=30).to(DEVICE)
166
- model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1-esm2.pt'), strict=False)
167
-
168
- model.eval()
169
-
170
  seq_input, struct_input = None, None
171
  sequence = validate_sequence(sequence) if sequence else None
172
  protein = None
@@ -194,7 +194,7 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
194
  if not (sequence or (pdb_input and model_version == "ESM3")):
195
  raise ValueError('Please provide a sequence' + (' or structure input' if model_version == "ESM3" else ''))
196
 
197
- logits = run_model(model, model_version, seq_input, struct_input, sequence_id)
198
 
199
  probabilities = utils.prob_adjusted(logits).cpu().detach().numpy()
200
 
 
141
 
142
  @spaces.GPU(duration=300)
143
  def run_model(model, model_version='ESM2', seq_input=None, struct_input=None, sequence_id=None):
144
+ if model_version == "ESM3":
145
+ model = ESM_model(method='esm3')
146
+ model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt'), strict=False)
147
+ else:
148
+ model = ESM_model(method='esm2', nheads=8, nlayers=12, layer=30).to(DEVICE)
149
+ model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1-esm2.pt'), strict=False)
150
+
151
+ model.eval()
152
+
153
  if model_version == "ESM3":
154
  logits = model((seq_input, struct_input), sequence_id)
155
  else:
 
167
 
168
  base_name = handle_name(name, pdb_input, model_version)
169
 
 
 
 
 
 
 
 
 
 
170
  seq_input, struct_input = None, None
171
  sequence = validate_sequence(sequence) if sequence else None
172
  protein = None
 
194
  if not (sequence or (pdb_input and model_version == "ESM3")):
195
  raise ValueError('Please provide a sequence' + (' or structure input' if model_version == "ESM3" else ''))
196
 
197
+ logits = run_model(model_version, seq_input, struct_input, sequence_id)
198
 
199
  probabilities = utils.prob_adjusted(logits).cpu().detach().numpy()
200