wangjin2000 committed on
Commit 0ce7882 · verified · 1 Parent(s): 227527b

Update app.py

Files changed (1): app.py +76 -1
app.py CHANGED
@@ -1,4 +1,5 @@
 #ref: https://huggingface.co/blog/AmelieSchreiber/esmbind
+import gradio as gr
 
 import os
 # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -28,4 +29,78 @@ from transformers import (
 from datasets import Dataset
 from accelerate import Accelerator
 # Imports specific to the custom peft lora model
-from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
+from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
+
+
+# Helper Functions and Data Preparation
+def truncate_labels(labels, max_length):
+    """Truncate labels to the specified max_length."""
+    return [label[:max_length] for label in labels]
+
+def compute_metrics(p):
+    """Compute metrics for evaluation."""
+    predictions, labels = p
+    predictions = np.argmax(predictions, axis=2)
+
+    # Remove padding (-100 labels)
+    predictions = predictions[labels != -100].flatten()
+    labels = labels[labels != -100].flatten()
+
+    # Compute accuracy
+    accuracy = accuracy_score(labels, predictions)
+
+    # Compute precision, recall, F1 score, and AUC
+    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+    auc = roc_auc_score(labels, predictions)
+
+    # Compute MCC
+    mcc = matthews_corrcoef(labels, predictions)
+
+    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
+
+def compute_loss(model, inputs):
+    """Custom compute_loss function."""
+    logits = model(**inputs).logits
+    labels = inputs["labels"]
+    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
+    active_loss = inputs["attention_mask"].view(-1) == 1
+    active_logits = logits.view(-1, model.config.num_labels)
+    active_labels = torch.where(
+        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+    )
+    loss = loss_fct(active_logits, active_labels)
+    return loss
+
+# Load the data from pickle files (replace with your local paths)
+with open("train_sequences_chunked_by_family.pkl", "rb") as f:
+    train_sequences = pickle.load(f)
+
+with open("test_sequences_chunked_by_family.pkl", "rb") as f:
+    test_sequences = pickle.load(f)
+
+with open("train_labels_chunked_by_family.pkl", "rb") as f:
+    train_labels = pickle.load(f)
+
+with open("test_labels_chunked_by_family.pkl", "rb") as f:
+    test_labels = pickle.load(f)
+
+# Tokenization
+tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
+max_sequence_length = 1000
+
+train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+
+# Directly truncate the entire list of labels
+train_labels = truncate_labels(train_labels, max_sequence_length)
+test_labels = truncate_labels(test_labels, max_sequence_length)
+
+train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+
+# Compute Class Weights
+classes = [0, 1]
+flat_train_labels = [label for sublist in train_labels for label in sublist]
+class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
+accelerator = Accelerator()
+class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
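
Note: the diff stops at the class-weight computation, so the training wiring that presumably follows in app.py is not shown here. As a rough sketch of how the pieces added in this commit are typically combined (following the referenced esmbind blog post), the helpers above can be plugged into a Hugging Face Trainer roughly as below. The WeightedTrainer name, the LoRA hyperparameters, and the TrainingArguments values are illustrative assumptions, not part of this commit.

# Illustrative sketch only -- not part of this commit. Assumes the objects
# defined in the diff above (compute_loss, compute_metrics, train_dataset,
# test_dataset) plus torch and the peft imports already present in app.py.
from transformers import AutoModelForTokenClassification, Trainer, TrainingArguments

class WeightedTrainer(Trainer):
    # Route Trainer's loss computation through the class-weighted
    # compute_loss helper defined in the diff above.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        loss = compute_loss(model, inputs)
        return (loss, model(**inputs)) if return_outputs else loss

base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)

# Hypothetical LoRA settings; the values actually used are not visible in this diff.
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],
    bias="none",
)
model = get_peft_model(base_model, lora_config)

training_args = TrainingArguments(
    output_dir="esm2_lora_binding_sites",  # placeholder path
    num_train_epochs=1,
    per_device_train_batch_size=4,
)
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

Subclassing Trainer is one common way to apply the class weights; the commit may wire things differently in the portion of app.py not included in this hunk.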