Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -141,6 +141,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
|
|
141 |
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
|
142 |
accelerator = Accelerator()
|
143 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
|
|
144 |
|
145 |
# Convert the model into a PeftModel
|
146 |
peft_config = LoraConfig(
|
@@ -215,7 +216,7 @@ MODEL_OPTIONS = [
|
|
215 |
"facebook/esm2_t33_650M_UR50D",
|
216 |
] # models users can choose from
|
217 |
|
218 |
-
|
219 |
# Load the data from pickle files (replace with your local paths)
|
220 |
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
|
221 |
train_sequences = pickle.load(f)
|
@@ -251,6 +252,7 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
|
|
251 |
accelerator = Accelerator()
|
252 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
253 |
|
|
|
254 |
# inference
|
255 |
# Path to the saved LoRA model
|
256 |
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|
|
|
141 |
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
|
142 |
accelerator = Accelerator()
|
143 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
144 |
+
print(" class_weights:", class_weights)
|
145 |
|
146 |
# Convert the model into a PeftModel
|
147 |
peft_config = LoraConfig(
|
|
|
216 |
"facebook/esm2_t33_650M_UR50D",
|
217 |
] # models users can choose from
|
218 |
|
219 |
+
|
220 |
# Load the data from pickle files (replace with your local paths)
|
221 |
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
|
222 |
train_sequences = pickle.load(f)
|
|
|
252 |
accelerator = Accelerator()
|
253 |
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
|
254 |
|
255 |
+
'''
|
256 |
# inference
|
257 |
# Path to the saved LoRA model
|
258 |
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
|