Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -118,9 +118,9 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
|
|
118 |
base_model = get_peft_model(base_model, peft_config)
|
119 |
|
120 |
# Use the accelerator
|
121 |
-
base_model =
|
122 |
-
train_dataset =
|
123 |
-
test_dataset =
|
124 |
|
125 |
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
126 |
|
@@ -205,7 +205,7 @@ test_labels = truncate_labels(test_labels, max_sequence_length)
|
|
205 |
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
|
206 |
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
|
207 |
|
208 |
-
|
209 |
# Compute Class Weights
|
210 |
classes = [0, 1]
|
211 |
flat_train_labels = [label for sublist in train_labels for label in sublist]
|
@@ -213,6 +213,7 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
|
|
213 |
accelerator = Accelerator()
|
214 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
215 |
|
|
|
216 |
# inference
|
217 |
# Path to the saved LoRA model
|
218 |
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|
|
|
118 |
base_model = get_peft_model(base_model, peft_config)
|
119 |
|
120 |
# Use the accelerator
|
121 |
+
base_model = accelerator.prepare(base_model)
|
122 |
+
train_dataset = accelerator.prepare(train_dataset)
|
123 |
+
test_dataset = accelerator.prepare(test_dataset)
|
124 |
|
125 |
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
126 |
|
|
|
205 |
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
|
206 |
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
|
207 |
|
208 |
+
|
209 |
# Compute Class Weights
|
210 |
classes = [0, 1]
|
211 |
flat_train_labels = [label for sublist in train_labels for label in sublist]
|
|
|
213 |
accelerator = Accelerator()
|
214 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
215 |
|
216 |
+
'''
|
217 |
# inference
|
218 |
# Path to the saved LoRA model
|
219 |
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|