Terry Zhang commited on
Commit
286033a
·
1 Parent(s): 6737c70

remove logits

Browse files
Files changed (1) hide show
  1. tasks/text.py +1 -1
tasks/text.py CHANGED
@@ -109,7 +109,7 @@ def bert_classifier(test_dataset: dict, model: str):
109
  test_input_ids = batch["input_ids"].to(device)
110
  test_attention_mask = batch["attention_mask"].to(device)
111
  outputs = model(test_input_ids, test_attention_mask)
112
- p = torch.argmax(outputs.logits, dim=1)
113
  predictions = np.append(predictions, p.cpu().numpy())
114
 
115
  print("Finished BERT model run")
 
109
  test_input_ids = batch["input_ids"].to(device)
110
  test_attention_mask = batch["attention_mask"].to(device)
111
  outputs = model(test_input_ids, test_attention_mask)
112
+ p = torch.argmax(outputs, dim=1)
113
  predictions = np.append(predictions, p.cpu().numpy())
114
 
115
  print("Finished BERT model run")