Terry Zhang commited on
Commit
f7c276d
·
1 Parent(s): 873e38f
Files changed (1) hide show
  1. tasks/text.py +6 -6
tasks/text.py CHANGED
@@ -118,20 +118,20 @@ def bert_classifier(test_dataset: dict, model: str):
118
 
119
  def moe_classifier(test_dataset: dict, model: str):
120
  print("Starting MoE run")
 
 
 
 
121
  texts = test_dataset["quote"]
 
122
  model_path = f"tasks/text_models/0131_MoE_final.pt"
123
 
124
- embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
125
- embedding_model.to(device)
126
-
127
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
128
 
129
  dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
130
  dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
131
 
132
- # Use CUDA if available
133
- device, _, _ = get_backend()
134
-
135
  model = MoEClassifier(3, 0.05)
136
  model.load_state_dict(torch.load(model_path))
137
  model = model.to(device)
 
118
 
119
  def moe_classifier(test_dataset: dict, model: str):
120
  print("Starting MoE run")
121
+
122
+ # Use CUDA if available
123
+ device, _, _ = get_backend()
124
+
125
  texts = test_dataset["quote"]
126
+
127
  model_path = f"tasks/text_models/0131_MoE_final.pt"
128
 
129
+ embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
 
 
130
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
131
 
132
  dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
133
  dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
134
 
 
 
 
135
  model = MoEClassifier(3, 0.05)
136
  model.load_state_dict(torch.load(model_path))
137
  model = model.to(device)