theterryzhang commited on
Commit
330e1bf
·
verified ·
1 Parent(s): df46342

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +2 -2
tasks/text.py CHANGED
@@ -106,9 +106,9 @@ def bert_classifier(test_dataset: dict, model: str):
106
 
107
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
108
 
109
- if model.isin(['bert_base_pruned']):
110
  model = AutoModelForSequenceClassification.from_pretrained(model_repo)
111
- elif model.isin(['sbert_distilroberta']):
112
  model = SentenceBERTMultiClass.from_pretrained(model_repo)
113
  else:
114
  raise(ValueError)
 
106
 
107
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
108
 
109
+ if model in ['bert_base_pruned']:
110
  model = AutoModelForSequenceClassification.from_pretrained(model_repo)
111
+ elif model in ['sbert_distilroberta']:
112
  model = SentenceBERTMultiClass.from_pretrained(model_repo)
113
  else:
114
  raise(ValueError)