Terry Zhang
commited on
Commit
·
af86903
1
Parent(s):
286033a
fix
Browse files- tasks/text.py +2 -2
tasks/text.py
CHANGED
@@ -109,7 +109,7 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
109 |
test_input_ids = batch["input_ids"].to(device)
|
110 |
test_attention_mask = batch["attention_mask"].to(device)
|
111 |
outputs = model(test_input_ids, test_attention_mask)
|
112 |
-
p = torch.argmax(outputs, dim=1)
|
113 |
predictions = np.append(predictions, p.cpu().numpy())
|
114 |
|
115 |
print("Finished BERT model run")
|
@@ -149,7 +149,7 @@ def moe_classifier(test_dataset: dict, model: str):
|
|
149 |
embeddings = embedding_outputs.last_hidden_state[:, 0, :]
|
150 |
|
151 |
outputs = model(embeddings)
|
152 |
-
p = torch.argmax(outputs
|
153 |
predictions = np.append(predictions, p.cpu().numpy())
|
154 |
|
155 |
print("Finished running MoE Classifier")
|
|
|
109 |
test_input_ids = batch["input_ids"].to(device)
|
110 |
test_attention_mask = batch["attention_mask"].to(device)
|
111 |
outputs = model(test_input_ids, test_attention_mask)
|
112 |
+
p = torch.argmax(outputs.logits, dim=1)
|
113 |
predictions = np.append(predictions, p.cpu().numpy())
|
114 |
|
115 |
print("Finished BERT model run")
|
|
|
149 |
embeddings = embedding_outputs.last_hidden_state[:, 0, :]
|
150 |
|
151 |
outputs = model(embeddings)
|
152 |
+
p = torch.argmax(outputs, dim=1)
|
153 |
predictions = np.append(predictions, p.cpu().numpy())
|
154 |
|
155 |
print("Finished running MoE Classifier")
|