wangjin2000 committed on
Commit
53cd821
·
verified ·
1 Parent(s): 6296772

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -141,6 +141,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
141
  class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
142
  accelerator = Accelerator()
143
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
 
144
 
145
  # Convert the model into a PeftModel
146
  peft_config = LoraConfig(
@@ -215,7 +216,7 @@ MODEL_OPTIONS = [
215
  "facebook/esm2_t33_650M_UR50D",
216
  ] # models users can choose from
217
 
218
- '''
219
  # Load the data from pickle files (replace with your local paths)
220
  with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
221
  train_sequences = pickle.load(f)
@@ -251,6 +252,7 @@ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y
251
  accelerator = Accelerator()
252
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
253
 
 
254
  # inference
255
  # Path to the saved LoRA model
256
  model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"
 
141
  class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
142
  accelerator = Accelerator()
143
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
144
+ print(" class_weights:", class_weights)
145
 
146
  # Convert the model into a PeftModel
147
  peft_config = LoraConfig(
 
216
  "facebook/esm2_t33_650M_UR50D",
217
  ] # models users can choose from
218
 
219
+
220
  # Load the data from pickle files (replace with your local paths)
221
  with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
222
  train_sequences = pickle.load(f)
 
252
  accelerator = Accelerator()
253
  class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
254
 
255
+ '''
256
  # inference
257
  # Path to the saved LoRA model
258
  model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3"