Terry Zhang commited on
Commit
873e38f
·
1 Parent(s): 56d7bf2

add device

Browse files
Files changed (1) hide show
  1. tasks/text.py +2 -0
tasks/text.py CHANGED
@@ -122,6 +122,8 @@ def moe_classifier(test_dataset: dict, model: str):
122
  model_path = f"tasks/text_models/0131_MoE_final.pt"
123
 
124
  embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
 
 
125
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
126
 
127
  dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
 
122
  model_path = f"tasks/text_models/0131_MoE_final.pt"
123
 
124
  embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
125
+ embedding_model.to(device)
126
+
127
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
128
 
129
  dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)