Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -140,7 +140,8 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
|
140 |
no_cuda=False,
|
141 |
seed=8893,
|
142 |
fp16=True,
|
143 |
-
report_to='wandb'
|
|
|
144 |
)
|
145 |
|
146 |
# Initialize Trainer
|
@@ -160,6 +161,8 @@ def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
|
160 |
trainer.save_model(save_path)
|
161 |
tokenizer.save_pretrained(save_path)
|
162 |
|
|
|
|
|
163 |
# Load the data from pickle files (replace with your local paths)
|
164 |
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
|
165 |
train_sequences = pickle.load(f)
|
@@ -217,6 +220,9 @@ inputs = tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_l
|
|
217 |
with torch.no_grad():
|
218 |
logits = loaded_model(**inputs).logits
|
219 |
|
|
|
|
|
|
|
220 |
# Get predictions
|
221 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
222 |
predictions = torch.argmax(logits, dim=2)
|
@@ -233,7 +239,7 @@ for token, prediction in zip(tokens, predictions[0].numpy()):
|
|
233 |
print((token, id2label[prediction]))
|
234 |
|
235 |
# debug result
|
236 |
-
dubug_result = predictions #class_weights
|
237 |
|
238 |
demo = gr.Blocks(title="DEMO FOR ESM2Bind")
|
239 |
|
|
|
140 |
no_cuda=False,
|
141 |
seed=8893,
|
142 |
fp16=True,
|
143 |
+
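# report_to='wandb'  # wandb disabled. NOTE(review): in transformers 4.x, report_to=None falls back to all installed integrations; the string "none" turns reporting off — confirm against the installed version.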
|
144 |
+
report_to=None
|
145 |
)
|
146 |
|
147 |
# Initialize Trainer
|
|
|
161 |
trainer.save_model(save_path)
|
162 |
tokenizer.save_pretrained(save_path)
|
163 |
|
164 |
+
return save_path
|
165 |
+
|
166 |
# Load the data from pickle files (replace with your local paths)
|
167 |
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
|
168 |
train_sequences = pickle.load(f)
|
|
|
220 |
with torch.no_grad():
|
221 |
logits = loaded_model(**inputs).logits
|
222 |
|
223 |
+
# Train the model (no sweeps) and keep the path the model and tokenizer were saved to
|
224 |
+
saved_path = train_function_no_sweeps(base_model_path,train_dataset, test_dataset)
|
225 |
+
|
226 |
# Get predictions
|
227 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
228 |
predictions = torch.argmax(logits, dim=2)
|
|
|
239 |
print((token, id2label[prediction]))
|
240 |
|
241 |
# Debug result: currently the save path returned by training
|
242 |
+
dubug_result = saved_path #predictions #class_weights
|
243 |
|
244 |
demo = gr.Blocks(title="DEMO FOR ESM2Bind")
|
245 |
|