wangjin2000 commited on
Commit
d03eed6
·
verified ·
1 Parent(s): 8aefe80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -73,6 +73,14 @@ def compute_loss(model, inputs):
73
  loss = loss_fct(active_logits, active_labels)
74
  return loss
75
 
 
 
 
 
 
 
 
 
76
  # fine-tuning function
77
  def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
78
 
 
73
  loss = loss_fct(active_logits, active_labels)
74
  return loss
75
 
76
+ # Define Custom Trainer Class
77
+ # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
78
+ class WeightedTrainer(Trainer):
79
+ def compute_loss(self, model, inputs, return_outputs=False):
80
+ outputs = model(**inputs)
81
+ loss = compute_loss(model, inputs)
82
+ return (loss, outputs) if return_outputs else loss
83
+
84
  # fine-tuning function
85
  def train_function_no_sweeps(base_model_path, train_dataset, test_dataset):
86