Update app.py
app.py CHANGED
@@ -81,7 +81,39 @@ class WeightedTrainer(Trainer):
         loss = compute_loss(model, inputs)
         return (loss, outputs) if return_outputs else loss
 
-#
+# Predict binding site with finetuned PEFT model
+def predict_bind(base_model_path,PEFT_model_path,input_seq):
+    # Load the model
+    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
+
+    # Ensure the model is in evaluation mode
+    loaded_model.eval()
+
+    # Tokenization
+    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+    # Tokenize the sequence
+    inputs = tokenizer(input_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+    # Run the model
+    with torch.no_grad():
+        logits = loaded_model(**inputs).logits
+
+    # Get predictions
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
+    predictions = torch.argmax(logits, dim=2)
+
+    binding_site=[]
+    # Print the predicted labels for each token
+    for n, token, prediction in enumerate(zip(tokens, predictions[0].numpy())):
+        if token not in ['<pad>', '<cls>', '<eos>']:
+            print((token, id2label[prediction]))
+            if prediction == 1:
+                print((n+1,token, id2label[prediction]))
+                binding_site.append(n+1,token, id2label[prediction])
+    return binding_site
+
 # fine-tuning function
 def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
 
@@ -102,8 +134,10 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     #base_model_path = "facebook/esm2_t12_35M_UR50D"
 
     # Define labels and model
-    id2label = {0: "No binding site", 1: "Binding site"}
-    label2id = {v: k for k, v in id2label.items()}
+    #id2label = {0: "No binding site", 1: "Binding site"}
+    #label2id = {v: k for k, v in id2label.items()}
+
+
     base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
 
     '''
@@ -289,12 +323,14 @@ with torch.no_grad():
 tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
 predictions = torch.argmax(logits, dim=2)
 
+'''
 # Define labels
 id2label = {
     0: "No binding site",
     1: "Binding site"
 }
 
+'''
 # Print the predicted labels for each token
 for token, prediction in zip(tokens, predictions[0].numpy()):
     if token not in ['<pad>', '<cls>', '<eos>']:
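A note on the new predict_bind: as committed, it raises before returning anything. The parameter is named input_seq but the tokenizer is called on the undefined name input_sequence (NameError); enumerate(zip(tokens, predictions[0].numpy())) yields (n, (token, prediction)) pairs, so the three-name unpacking in the for loop fails (ValueError); and list.append takes exactly one argument, so binding_site.append(n+1, token, id2label[prediction]) raises TypeError. A minimal corrected sketch, with the per-token prints dropped and the label map assumed from the rest of app.py:

import torch
from peft import PeftModel
from transformers import AutoModelForTokenClassification, AutoTokenizer

def predict_bind(base_model_path, PEFT_model_path, input_seq):
    # Label map assumed from the rest of app.py
    id2label = {0: "No binding site", 1: "Binding site"}

    # Load the base model and wrap it with the PEFT adapter
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
    loaded_model.eval()

    # Tokenize the input sequence (the commit tokenizes the undefined
    # name input_sequence instead of the input_seq parameter)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    inputs = tokenizer(input_seq, return_tensors="pt", truncation=True,
                       max_length=1024, padding='max_length')

    # Run the model without tracking gradients
    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    # Per-token argmax over the two classes
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    predictions = torch.argmax(logits, dim=2)

    binding_site = []
    # enumerate(zip(...)) yields (n, (token, prediction)): unpack the pair
    for n, (token, prediction) in enumerate(zip(tokens, predictions[0].numpy())):
        if token not in ['<pad>', '<cls>', '<eos>']:
            if prediction == 1:
                # list.append takes one argument, so collect a tuple;
                # n + 1 keeps the commit's position convention
                binding_site.append((n + 1, token, id2label[prediction]))
    return binding_site

Hypothetical usage, with the base checkpoint mentioned in the diff and an illustrative adapter path and sequence:

sites = predict_bind("facebook/esm2_t12_35M_UR50D", "path/to/peft_adapter", "MKTAYIAKQR")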