HemanM commited on
Commit
bd0af7a
·
verified ·
1 Parent(s): 314d724

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +1 -1
inference.py CHANGED
@@ -166,7 +166,7 @@ def retrain_from_feedback_csv():
166
  encoded = tokenizer(input_text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
167
 
168
  logits = model(encoded["input_ids"])
169
- loss = F.binary_cross_entropy_with_logits(logits.squeeze(), label)
170
  loss.backward()
171
  optimizer.step()
172
  optimizer.zero_grad()
 
166
  encoded = tokenizer(input_text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
167
 
168
  logits = model(encoded["input_ids"])
169
+ loss = F.binary_cross_entropy_with_logits(logits.squeeze(dim=-1), label)
170
  loss.backward()
171
  optimizer.step()
172
  optimizer.zero_grad()