Update tasks/text.py
Browse files- tasks/text.py +2 -2
tasks/text.py
CHANGED
@@ -106,9 +106,9 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
106 |
|
107 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
108 |
|
109 |
-
if model
|
110 |
model = AutoModelForSequenceClassification.from_pretrained(model_repo)
|
111 |
-
elif model
|
112 |
model = SentenceBERTMultiClass.from_pretrained(model_repo)
|
113 |
else:
|
114 |
raise(ValueError)
|
|
|
106 |
|
107 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
108 |
|
109 |
+
if model in ['bert_base_pruned']:
|
110 |
model = AutoModelForSequenceClassification.from_pretrained(model_repo)
|
111 |
+
elif model in ['sbert_distilroberta']:
|
112 |
model = SentenceBERTMultiClass.from_pretrained(model_repo)
|
113 |
else:
|
114 |
raise(ValueError)
|