theterryzhang commited on
Commit
0b1295b
·
verified ·
1 Parent(s): 330e1bf

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +1 -26
tasks/text.py CHANGED
@@ -30,26 +30,6 @@ models_descriptions = {
30
  "sbert_distilroberta": "Fine-tuned sentence transformer DistilRoBERTa"
31
  }
32
 
33
- class SentenceBERTMultiClass(nn.Module, PyTorchModelHubMixin):
34
- def __init__(self, model_name, num_labels=8):
35
- super().__init__()
36
- self.sbert = AutoModel.from_pretrained(model_name)
37
- self.config = self.sbert.config
38
- self.dropout = nn.Dropout(0.05)
39
- self.classifier = nn.Linear(self.sbert.config.hidden_size, num_labels)
40
-
41
- def forward(self, input_ids, attention_mask):
42
- outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
43
- if hasattr(outputs, "pooler_output"):
44
- pooled_output = outputs.pooler_output
45
- else:
46
- pooled_output = outputs.last_hidden_state.mean(dim=1)
47
-
48
- dropout_output = self.dropout(pooled_output)
49
- logits = self.classifier(dropout_output)
50
-
51
- return logits
52
-
53
 
54
  def baseline_model(dataset_length: int):
55
  # Make random predictions (placeholder for actual model inference)
@@ -106,12 +86,7 @@ def bert_classifier(test_dataset: dict, model: str):
106
 
107
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
108
 
109
- if model in ['bert_base_pruned']:
110
- model = AutoModelForSequenceClassification.from_pretrained(model_repo)
111
- elif model in ['sbert_distilroberta']:
112
- model = SentenceBERTMultiClass.from_pretrained(model_repo)
113
- else:
114
- raise(ValueError)
115
 
116
  # Use CUDA if available
117
  device, _, _ = get_backend()
 
30
  "sbert_distilroberta": "Fine-tuned sentence transformer DistilRoBERTa"
31
  }
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def baseline_model(dataset_length: int):
35
  # Make random predictions (placeholder for actual model inference)
 
86
 
87
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
88
 
89
+ model = AutoModelForSequenceClassification.from_pretrained(model_repo)
 
 
 
 
 
90
 
91
  # Use CUDA if available
92
  device, _, _ = get_backend()