Terry Zhang
commited on
Commit
·
286033a
1
Parent(s):
6737c70
remove logits
Browse files- tasks/text.py +1 -1
tasks/text.py
CHANGED
@@ -109,7 +109,7 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
109 |
test_input_ids = batch["input_ids"].to(device)
|
110 |
test_attention_mask = batch["attention_mask"].to(device)
|
111 |
outputs = model(test_input_ids, test_attention_mask)
|
112 |
-
p = torch.argmax(outputs
|
113 |
predictions = np.append(predictions, p.cpu().numpy())
|
114 |
|
115 |
print("Finished BERT model run")
|
|
|
109 |
test_input_ids = batch["input_ids"].to(device)
|
110 |
test_attention_mask = batch["attention_mask"].to(device)
|
111 |
outputs = model(test_input_ids, test_attention_mask)
|
112 |
+
p = torch.argmax(outputs, dim=1)
|
113 |
predictions = np.append(predictions, p.cpu().numpy())
|
114 |
|
115 |
print("Finished BERT model run")
|