Terry Zhang
commited on
Commit
·
873e38f
1
Parent(s):
56d7bf2
add device
Browse files- tasks/text.py +2 -0
tasks/text.py
CHANGED
@@ -122,6 +122,8 @@ def moe_classifier(test_dataset: dict, model: str):
|
|
122 |
model_path = f"tasks/text_models/0131_MoE_final.pt"
|
123 |
|
124 |
embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
|
|
|
|
125 |
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
126 |
|
127 |
dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
|
|
|
122 |
model_path = f"tasks/text_models/0131_MoE_final.pt"
|
123 |
|
124 |
embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
125 |
+
embedding_model.to(device)
|
126 |
+
|
127 |
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
|
128 |
|
129 |
dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
|