Terry Zhang commited on
Commit
6737c70
·
1 Parent(s): f7c276d

fix device

Browse files
Files changed (1) hide show
  1. tasks/text.py +3 -2
tasks/text.py CHANGED
@@ -121,12 +121,13 @@ def moe_classifier(test_dataset: dict, model: str):
121
 
122
  # Use CUDA if available
123
  device, _, _ = get_backend()
124
-
125
  texts = test_dataset["quote"]
126
 
127
  model_path = f"tasks/text_models/0131_MoE_final.pt"
128
 
129
- embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
 
130
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
131
 
132
  dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)
 
121
 
122
  # Use CUDA if available
123
  device, _, _ = get_backend()
124
+
125
  texts = test_dataset["quote"]
126
 
127
  model_path = f"tasks/text_models/0131_MoE_final.pt"
128
 
129
+ embedding_model = AutoModel.from_pretrained("sentence-transformers/all-distilroberta-v1")
130
+ embedding_model = embedding_model.to(device)
131
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
132
 
133
  dataset = TextDataset(texts, tokenizer=tokenizer, max_length=512)