Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
#ref: https://huggingface.co/blog/AmelieSchreiber/esmbind
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
@@ -28,4 +29,78 @@ from transformers import (
|
|
| 28 |
from datasets import Dataset
|
| 29 |
from accelerate import Accelerator
|
| 30 |
# Imports specific to the custom peft lora model
|
| 31 |
-
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#ref: https://huggingface.co/blog/AmelieSchreiber/esmbind
|
| 2 |
+
import gradio as gr
|
| 3 |
|
| 4 |
import os
|
| 5 |
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
|
|
| 29 |
from datasets import Dataset
|
| 30 |
from accelerate import Accelerator
|
| 31 |
# Imports specific to the custom peft lora model
|
| 32 |
+
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Helper Functions and Data Preparation
|
| 36 |
+
def truncate_labels(labels, max_length):
    """Clip every label sequence to at most ``max_length`` leading entries.

    Each element of ``labels`` is sliced as ``element[:max_length]`` and the
    slices are collected, in the original order, into a new list. The input
    is not modified.
    """
    clipped = []
    for per_sequence in labels:
        clipped.append(per_sequence[:max_length])
    return clipped
|
| 39 |
+
|
| 40 |
+
def compute_metrics(p):
    """Compute token-level binary classification metrics for evaluation.

    Args:
        p: A pair ``(predictions, labels)``. ``predictions`` is indexed with
            ``axis=2`` as the class axis, and ``labels`` uses ``-100`` to
            mark positions that must be ignored.
            NOTE(review): assumes ``predictions`` is (batch, seq_len, 2)
            logits and ``labels`` is (batch, seq_len) -- confirm with caller.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``auc`` and ``mcc``. ``auc`` is ``nan`` when the unmasked labels
        contain a single class, because ROC AUC is undefined in that case.
    """
    logits, labels = p
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    # Remove padding (-100 labels); boolean indexing already yields 1-D arrays.
    keep = labels != -100
    y_true = labels[keep]
    y_pred = np.argmax(logits, axis=2)[keep]

    # Continuous score for the positive class. For two classes the logit
    # margin ranks tokens exactly like the softmax probability of class 1,
    # which is what ROC AUC needs; hard 0/1 predictions would collapse the
    # ROC curve to a single operating point.
    y_score = (logits[..., 1] - logits[..., 0])[keep]

    # Compute accuracy
    accuracy = accuracy_score(y_true, y_pred)

    # Compute precision, recall and F1 score. zero_division=0 keeps the
    # value sklearn already falls back to, without emitting a warning when
    # no positive token is predicted or present.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', zero_division=0
    )

    # roc_auc_score raises ValueError when only one class is present, so
    # report nan instead of aborting the whole evaluation run.
    if np.unique(y_true).size > 1:
        auc = roc_auc_score(y_true, y_score)
    else:
        auc = float('nan')

    # Compute MCC
    mcc = matthews_corrcoef(y_true, y_pred)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
|
| 60 |
+
|
| 61 |
+
def compute_loss(model, inputs):
    """Return the class-weighted cross-entropy loss for one batch.

    The model is run on ``inputs``; logits are flattened to
    ``(-1, model.config.num_labels)`` and labels to ``(-1,)``. Every position
    whose ``attention_mask`` entry is not 1 has its label replaced by the
    loss function's ``ignore_index`` so it does not contribute to the loss.
    The loss is weighted by the module-level ``class_weights`` tensor.
    """
    outputs = model(**inputs)
    batch_labels = inputs["labels"]
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # True where the token is attended to (i.e. not padding).
    attended = inputs["attention_mask"].view(-1) == 1
    flat_logits = outputs.logits.view(-1, model.config.num_labels)

    # Scalar fill value with the same dtype as the labels.
    ignore_value = torch.tensor(criterion.ignore_index).type_as(batch_labels)
    masked_labels = torch.where(attended, batch_labels.view(-1), ignore_value)

    return criterion(flat_logits, masked_labels)
|
| 73 |
+
|
| 74 |
+
# Load the data from pickle files (replace with your local paths)
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point these paths at files you produced yourself or otherwise trust.
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
max_sequence_length = 1000

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Directly truncate the entire list of labels to the same cap that was
# passed to the tokenizer.
# NOTE(review): the tokenizer's max_length presumably also counts any special
# tokens it adds -- confirm the label lists stay aligned with input_ids.
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
# compute_class_weight documents `classes` as an ndarray, and recent
# scikit-learn releases reject a plain list during parameter validation,
# so convert at the call site while keeping `classes` itself a list.
class_weights = compute_class_weight(class_weight='balanced', classes=np.array(classes), y=flat_train_labels)
accelerator = Accelerator()
# Move the weights to the accelerator's device as float32 so they can be
# handed to nn.CrossEntropyLoss in compute_loss.
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|