Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -73,6 +73,14 @@ def compute_loss(model, inputs):
|
|
73 |
loss = loss_fct(active_logits, active_labels)
|
74 |
return loss
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
# fine-tuning function
|
77 |
def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
78 |
|
|
|
73 |
loss = loss_fct(active_logits, active_labels)
|
74 |
return loss
|
75 |
|
76 |
+
# Define Custom Trainer Class
|
77 |
+
# Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
|
78 |
+
class WeightedTrainer(Trainer):
|
79 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
80 |
+
outputs = model(**inputs)
|
81 |
+
loss = compute_loss(model, inputs)
|
82 |
+
return (loss, outputs) if return_outputs else loss
|
83 |
+
|
84 |
# fine-tuning function
|
85 |
def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
|
86 |
|