wangjin2000 committed
Commit 5bbc76e · verified · 1 Parent(s): d0d7573

Update app.py

Files changed (1)
  1. app.py +5 -4
app.py CHANGED
@@ -118,9 +118,9 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     base_model = get_peft_model(base_model, peft_config)
 
     # Use the accelerator
-    base_model = Accelerator.prepare(base_model)
-    train_dataset = Accelerator.prepare(train_dataset)
-    test_dataset = Accelerator.prepare(test_dataset)
+    base_model = accelerator.prepare(base_model)
+    train_dataset = accelerator.prepare(train_dataset)
+    test_dataset = accelerator.prepare(test_dataset)
 
     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
 
@@ -205,7 +205,7 @@ test_labels = truncate_labels(test_labels, max_sequence_length)
 train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
 test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
 
-'''
+
 # Compute Class Weights
 classes = [0, 1]
 flat_train_labels = [label for sublist in train_labels for label in sublist]
@@ -213,6 +213,7 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
 accelerator = Accelerator()
 class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
 
+'''
 # inference
 # Path to the saved LoRA model
 model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
 
 
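The substantive fix in the first hunk: `prepare` is an instance method of `accelerate.Accelerator`, so the old code called it on the class and bound the model as `self` (in app.py, the `accelerator` the patched lines now reference is the module-level instance created at line 213 in the last hunk). A minimal sketch of the corrected pattern, assuming a standard Hugging Face Accelerate setup; the model and dataset below are placeholders, not code from app.py:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # instantiate once, then call methods on the instance

# Placeholders standing in for the PEFT-wrapped model and tokenized datasets.
model = nn.Linear(4, 2)
dataset = TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,)))
loader = DataLoader(dataset, batch_size=4)

# Correct: an instance-method call. `Accelerator.prepare(model)` would bind
# `model` as `self` and fail, which is the bug this commit fixes.
model, loader = accelerator.prepare(model, loader)
```

Note that `prepare` is designed for models, optimizers, schedulers, and `DataLoader`s; raw `Dataset` objects, which the patched lines still pass, are generally returned unchanged, so wrapping them in `DataLoader`s first (as above) is the usual pattern.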
 
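The second and third hunks move the opening `'''` from just before the class-weight block to just before the inference section, so the balanced class weights are now actually computed. A self-contained sketch of what that block does, with placeholder labels; the weighted loss at the end is one conventional use of such weights, not code shown in this diff:

```python
import torch
from torch import nn
from sklearn.utils.class_weight import compute_class_weight

flat_train_labels = [0, 0, 0, 0, 1, 1]  # placeholder for the flattened labels
classes = [0, 1]

# 'balanced' weights are n_samples / (n_classes * bincount(y)), so the
# rarer positive class gets the larger weight ([0.75, 1.5] here).
weights = compute_class_weight(class_weight='balanced', classes=classes,
                               y=flat_train_labels)
class_weights = torch.tensor(weights, dtype=torch.float32)

loss_fn = nn.CrossEntropyLoss(weight=class_weights)  # assumed downstream use
```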
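The now-commented inference section leaves only the checkpoint path visible in the diff. For reference, a hedged sketch of how a LoRA adapter like this is conventionally loaded with `peft`; the base checkpoint and the two-label head are assumptions inferred from the adapter's name, since the commented-out code itself is not shown:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

base_checkpoint = "facebook/esm2_t12_35M_UR50D"  # assumed ESM-2 base model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_checkpoint, num_labels=2)  # binding vs. non-binding residue (assumed)
model = PeftModel.from_pretrained(base_model, model_path)  # attach LoRA weights
model.eval()
```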