wangjin2000 committed on
Commit
dfde78e
·
verified ·
1 Parent(s): 02849fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -140,7 +140,8 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
140
  no_cuda=False,
141
  seed=8893,
142
  fp16=True,
143
- report_to='wandb'
 
144
  )
145
 
146
  # Initialize Trainer
@@ -160,6 +161,8 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
160
  trainer.save_model(save_path)
161
  tokenizer.save_pretrained(save_path)
162
 
 
 
163
  # Load the data from pickle files (replace with your local paths)
164
  with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
165
  train_sequences = pickle.load(f)
@@ -217,6 +220,9 @@ inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_l
217
  with torch.no_grad():
218
  logits = loaded_model(**inputs).logits
219
 
 
 
 
220
  # Get predictions
221
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
222
  predictions = torch.argmax(logits, dim=2)
@@ -233,7 +239,7 @@ for token, prediction in zip(tokens, predictions[0].numpy()):
233
  print((token, id2label[prediction]))
234
 
235
  # debug result
236
- dubug_result = predictions #class_weights
237
 
238
  demo = gr.Blocks(title="DEMO FOR ESM2Bind")
239
 
 
140
  no_cuda=False,
141
  seed=8893,
142
  fp16=True,
143
+ #report_to='wandb'
144
+ report_to=None
145
  )
146
 
147
  # Initialize Trainer
 
161
  trainer.save_model(save_path)
162
  tokenizer.save_pretrained(save_path)
163
 
164
+ return save_path
165
+
166
  # Load the data from pickle files (replace with your local paths)
167
  with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
168
  train_sequences = pickle.load(f)
 
220
  with torch.no_grad():
221
  logits = loaded_model(**inputs).logits
222
 
223
+ # train
224
+ saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
225
+
226
  # Get predictions
227
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
228
  predictions = torch.argmax(logits, dim=2)
 
239
  print((token, id2label[prediction]))
240
 
241
  # debug result
242
+ dubug_result = saved_path #predictions #class_weights
243
 
244
  demo = gr.Blocks(title="DEMO FOR ESM2Bind")
245