Terry Zhang commited on
Commit
af86903
·
1 Parent(s): 286033a
Files changed (1) hide show
  1. tasks/text.py +2 -2
tasks/text.py CHANGED
@@ -109,7 +109,7 @@ def bert_classifier(test_dataset: dict, model: str):
109
  test_input_ids = batch["input_ids"].to(device)
110
  test_attention_mask = batch["attention_mask"].to(device)
111
  outputs = model(test_input_ids, test_attention_mask)
112
- p = torch.argmax(outputs, dim=1)
113
  predictions = np.append(predictions, p.cpu().numpy())
114
 
115
  print("Finished BERT model run")
@@ -149,7 +149,7 @@ def moe_classifier(test_dataset: dict, model: str):
149
  embeddings = embedding_outputs.last_hidden_state[:, 0, :]
150
 
151
  outputs = model(embeddings)
152
- p = torch.argmax(outputs.logits, dim=1)
153
  predictions = np.append(predictions, p.cpu().numpy())
154
 
155
  print("Finished running MoE Classifier")
 
109
  test_input_ids = batch["input_ids"].to(device)
110
  test_attention_mask = batch["attention_mask"].to(device)
111
  outputs = model(test_input_ids, test_attention_mask)
112
+ p = torch.argmax(outputs.logits, dim=1)
113
  predictions = np.append(predictions, p.cpu().numpy())
114
 
115
  print("Finished BERT model run")
 
149
  embeddings = embedding_outputs.last_hidden_state[:, 0, :]
150
 
151
  outputs = model(embeddings)
152
+ p = torch.argmax(outputs, dim=1)
153
  predictions = np.append(predictions, p.cpu().numpy())
154
 
155
  print("Finished running MoE Classifier")