Terry Zhang committed on
Commit
5225d97
·
1 Parent(s): 0b1295b

add custom classifiers

Browse files
Files changed (2) hide show
  1. tasks/custom_classifiers.py +28 -0
  2. tasks/text.py +7 -2
tasks/custom_classifiers.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaModel, AutoTokenizer
2
+ from transformers.modeling_outputs import SequenceClassifierOutput
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+ from torch.nn import CrossEntropyLoss
5
+ import torch.nn as nn
6
+ import torch
7
+
8
class SentenceBERTClassifier(nn.Module, PyTorchModelHubMixin):
    """Sequence classifier: a RoBERTa encoder followed by dropout and a linear head.

    The encoder is loaded with ``RobertaModel.from_pretrained(model_name)``; its
    ``pooler_output`` is passed through dropout and a single ``nn.Linear`` layer
    that produces ``num_labels`` logits.

    Args:
        model_name: Checkpoint identifier handed to ``RobertaModel.from_pretrained``.
        num_labels: Number of output classes (width of the linear head).
    """

    def __init__(self, model_name="sentence-transformers/all-distilroberta-v1", num_labels=8):
        super().__init__()
        # Single source of truth for the dropout rate, so the module and the
        # value recorded on the config cannot drift apart.
        classifier_dropout = 0.05

        self.sbert = RobertaModel.from_pretrained(model_name)

        # Reuse the encoder's config object and record the head's settings on it.
        self.config = self.sbert.config
        self.config.num_labels = num_labels
        self.config.classifier_dropout = classifier_dropout

        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: Token ids forwarded to the encoder
                (presumably shaped (batch, seq_len) -- TODO confirm with caller).
            attention_mask: Attention mask forwarded to the encoder alongside
                ``input_ids``.
            labels: Optional target class indices. When given, a cross-entropy
                loss is computed and returned in the output; when omitted the
                output's ``loss`` is ``None``, matching the previous behaviour.

        Returns:
            SequenceClassifierOutput carrying ``loss`` (or ``None``), ``logits``,
            and the encoder's ``hidden_states`` / ``attentions``.
        """
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(self.dropout(pooled_output))

        loss = None
        if labels is not None:
            # Flatten so any leading batch layout maps to (N, num_labels) vs (N,).
            loss = CrossEntropyLoss()(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
tasks/text.py CHANGED
@@ -10,12 +10,12 @@ from fastapi import APIRouter
10
  from datasets import load_dataset
11
  from sklearn.metrics import accuracy_score
12
  from skops.io import load
13
- from huggingface_hub import PyTorchModelHubMixin
14
 
15
  from .utils.evaluation import TextEvaluationRequest
16
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
17
  from .utils.text_preprocessor import preprocess
18
  from accelerate.test_utils.testing import get_backend
 
19
 
20
  router = APIRouter()
21
 
@@ -86,7 +86,12 @@ def bert_classifier(test_dataset: dict, model: str):
86
 
87
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
88
 
89
- model = AutoModelForSequenceClassification.from_pretrained(model_repo)
 
 
 
 
 
90
 
91
  # Use CUDA if available
92
  device, _, _ = get_backend()
 
10
  from datasets import load_dataset
11
  from sklearn.metrics import accuracy_score
12
  from skops.io import load
 
13
 
14
  from .utils.evaluation import TextEvaluationRequest
15
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
16
  from .utils.text_preprocessor import preprocess
17
  from accelerate.test_utils.testing import get_backend
18
+ from custom_classifiers import SentenceBERTClassifier
19
 
20
  router = APIRouter()
21
 
 
86
 
87
  tokenizer = AutoTokenizer.from_pretrained(model_repo)
88
 
89
+ if model in ["bert_base_pruned"]:
90
+ model = AutoModelForSequenceClassification.from_pretrained(model_repo)
91
+ elif model in ["sbert_distilroberta"]:
92
+ model = SentenceBERTClassifier.from_pretrained(model_repo)
93
+ else:
94
+ raise(ValueError)
95
 
96
  # Use CUDA if available
97
  device, _, _ = get_backend()