Update app.py
app.py CHANGED
@@ -60,7 +60,7 @@ def compute_metrics(p):
 
     return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
 
-def compute_loss(model, inputs, class_weights): #compute_loss(model, inputs): a
+def compute_loss(model, inputs):
     """Custom compute_loss function."""
     logits = model(**inputs).logits
     labels = inputs["labels"]
@@ -76,11 +76,12 @@ def compute_loss(model, inputs, class_weights): #compute_loss(model, inputs): a
 # Define Custom Trainer Class
 # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
 class WeightedTrainer(Trainer):
-    def compute_loss(self, model, inputs,
+    def compute_loss(self, model, inputs, return_outputs=False):
         outputs = model(**inputs)
-        loss = compute_loss(model, inputs
+        loss = compute_loss(model, inputs)
         return (loss, outputs) if return_outputs else loss
-
+
+#
 # fine-tuning function
 def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
 
@@ -196,8 +197,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
     eval_dataset=test_dataset,
     tokenizer=tokenizer,
     data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
-    compute_metrics=compute_metrics,
-    class_weights=class_weights, #add class_weights as input, jw 20240628
+    compute_metrics=compute_metrics
 )
 
 # Train and Save Model
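
The commit drops the `class_weights` argument throughout: the standalone `compute_loss` helper loses it, `WeightedTrainer.compute_loss` reverts to the stock `(self, model, inputs, return_outputs=False)` signature, and the `class_weights=class_weights` keyword is removed from the `Trainer(...)` constructor call, which `transformers.Trainer.__init__` does not accept. The comments in the file still call for class weighting to counter the imbalance between non-binding and binding residues; a common way to keep the weights without passing an unsupported keyword to `Trainer(...)` is to store them on the trainer subclass itself. Below is a minimal sketch of that pattern, not the Space's actual code; the `class_weights` tensor and the `model.config.num_labels` lookup are assumptions for illustration.

import torch
from transformers import Trainer

class WeightedTrainer(Trainer):
    """Trainer that applies per-class weights in the token-classification loss."""

    def __init__(self, *args, class_weights=None, **kwargs):
        # class_weights: hypothetical 1-D tensor with one weight per label,
        # e.g. torch.tensor([1.0, 10.0]) to upweight the rare binding class.
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs["labels"]
        outputs = model(**inputs)
        logits = outputs.logits
        weight = (
            self.class_weights.to(logits.device)
            if self.class_weights is not None
            else None
        )
        # ignore_index=-100 skips the padded/special-token positions that
        # DataCollatorForTokenClassification marks with label -100.
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
        loss = loss_fct(
            logits.view(-1, model.config.num_labels), labels.view(-1)
        )
        return (loss, outputs) if return_outputs else loss

With this version the weights travel through the subclass constructor, e.g. WeightedTrainer(..., class_weights=class_weights), rather than through Trainer's own keyword arguments, so the Trainer(...) call inside train_function_no_sweeps can stay exactly as the diff leaves it.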