wangjin2000 committed
Commit 02849fc · verified · 1 Parent(s): 4713da5

Update app.py

Files changed (1):
  1. app.py +87 -0

app.py CHANGED
@@ -73,6 +73,93 @@ def compute_loss(model, inputs):
     loss = loss_fct(active_logits, active_labels)
     return loss
 
+# fine-tuning function
+def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
+
+    # Set the LoRA config
+    config = {
+        "lora_alpha": 1, # try 0.5, 1, 2, ..., 16
+        "lora_dropout": 0.2,
+        "lr": 5.701568055793089e-04,
+        "lr_scheduler_type": "cosine",
+        "max_grad_norm": 0.5,
+        "num_train_epochs": 3,
+        "per_device_train_batch_size": 12,
+        "r": 2,
+        "weight_decay": 0.2,
+        # Add other hyperparameters as needed
+    }
+    # The base model you will train a LoRA on top of
+    base_model_path = "facebook/esm2_t12_35M_UR50D"
+
+    # Define labels and model
+    id2label = {0: "No binding site", 1: "Binding site"}
+    label2id = {v: k for k, v in id2label.items()}
+    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)
+
+    # Convert the model into a PeftModel
+    peft_config = LoraConfig(
+        task_type=TaskType.TOKEN_CLS,
+        inference_mode=False,
+        r=config["r"],
+        lora_alpha=config["lora_alpha"],
+        target_modules=["query", "key", "value"], # also try "dense_h_to_4h" and "dense_4h_to_h"
+        lora_dropout=config["lora_dropout"],
+        bias="none" # or "all" or "lora_only"
+    )
+    base_model = get_peft_model(base_model, peft_config)
+
+    # Use the accelerator
+    base_model = accelerator.prepare(base_model)
+    train_dataset = accelerator.prepare(train_dataset)
+    test_dataset = accelerator.prepare(test_dataset)
+
+    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+    # Training setup
+    training_args = TrainingArguments(
+        output_dir=f"esm2_t12_35M-lora-binding-sites_{timestamp}",
+        learning_rate=config["lr"],
+        lr_scheduler_type=config["lr_scheduler_type"],
+        gradient_accumulation_steps=1,
+        max_grad_norm=config["max_grad_norm"],
+        per_device_train_batch_size=config["per_device_train_batch_size"],
+        per_device_eval_batch_size=config["per_device_train_batch_size"],
+        num_train_epochs=config["num_train_epochs"],
+        weight_decay=config["weight_decay"],
+        evaluation_strategy="epoch",
+        save_strategy="epoch",
+        load_best_model_at_end=True,
+        metric_for_best_model="f1",
+        greater_is_better=True,
+        push_to_hub=False,
+        logging_dir=None,
+        logging_first_step=False,
+        logging_steps=200,
+        save_total_limit=7,
+        no_cuda=False,
+        seed=8893,
+        fp16=True,
+        report_to='wandb'
+    )
+
+    # Initialize Trainer
+    trainer = WeightedTrainer(
+        model=base_model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=test_dataset,
+        tokenizer=tokenizer,
+        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
+        compute_metrics=compute_metrics
+    )
+
+    # Train and Save Model
+    trainer.train()
+    save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
+    trainer.save_model(save_path)
+    tokenizer.save_pretrained(save_path)
+
 # Load the data from pickle files (replace with your local paths)
 with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
     train_sequences = pickle.load(f)
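
The added `train_function_no_sweeps` references two names that are defined elsewhere in app.py and do not appear in this hunk: `WeightedTrainer` and `compute_metrics`. For orientation only, here is a minimal sketch of what they might look like, assuming the trainer simply delegates to the module-level `compute_loss` helper visible in the hunk header and that the metric matches `metric_for_best_model="f1"`; the actual definitions in app.py may differ.

```python
# Hypothetical sketch -- the real WeightedTrainer and compute_metrics live
# elsewhere in app.py and are not part of this diff.
import numpy as np
from sklearn.metrics import f1_score
from transformers import Trainer

class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Delegate to the module-level compute_loss() helper (see the hunk header),
        # which applies the class-weighted token-classification loss.
        loss = compute_loss(model, inputs)
        if return_outputs:
            return loss, model(**inputs)
        return loss

def compute_metrics(eval_pred):
    # Token-level F1 over positions whose label is not the ignore index (-100).
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    mask = labels != -100
    return {"f1": f1_score(labels[mask], preds[mask])}
```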
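
The function also assumes that `tokenizer`, `accelerator`, and the tokenized train/test datasets already exist when it is called. A hedged sketch of that driver code follows; the label-pickle filename, `max_length`, and dataset layout are assumptions, and note that the `base_model_path` argument is overwritten inside the function, so the value passed in is effectively ignored.

```python
# Hypothetical driver code -- only the train-sequence pickle path is taken from the
# diff; the label file name, max_length, and column layout are assumptions.
import pickle
from accelerate import Accelerator
from datasets import Dataset
from transformers import AutoTokenizer

accelerator = Accelerator()  # used inside train_function_no_sweeps via accelerator.prepare(...)

base_model_path = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:  # assumed file name
    train_labels = pickle.load(f)

# Tokenize the protein sequences and attach per-residue binding-site labels;
# DataCollatorForTokenClassification pads both tokens and labels per batch.
train_encodings = tokenizer(train_sequences, truncation=True, max_length=1000)
train_dataset = Dataset.from_dict(dict(train_encodings)).add_column("labels", train_labels)

# test_dataset would be built the same way from the test pickles, after which:
# train_function_no_sweeps(base_model_path, train_dataset, test_dataset)
```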
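
Since `trainer.save_model(save_path)` on a PEFT-wrapped model typically writes only the LoRA adapter (with the tokenizer saved alongside it), loading the result back for inference follows the usual PEFT pattern. A sketch, with a placeholder timestamp in the path:

```python
# Hypothetical inference-time loading -- replace the <timestamp> placeholder with
# the folder name actually produced by the training run.
from peft import PeftModel
from transformers import AutoModelForTokenClassification, AutoTokenizer

adapter_path = "lora_binding_sites/best_model_esm2_t12_35M_lora_<timestamp>"

id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}
base = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    num_labels=len(id2label), id2label=id2label, label2id=label2id,
)

model = PeftModel.from_pretrained(base, adapter_path)  # attach the trained LoRA weights
tokenizer = AutoTokenizer.from_pretrained(adapter_path)
model.eval()
```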