Terry Zhang
committed on
Commit
·
5225d97
1
Parent(s):
0b1295b
add custom classifiers
Browse files- tasks/custom_classifiers.py +28 -0
- tasks/text.py +7 -2
tasks/custom_classifiers.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import RobertaModel, AutoTokenizer
|
2 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
3 |
+
from huggingface_hub import PyTorchModelHubMixin
|
4 |
+
from torch.nn import CrossEntropyLoss
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch
|
7 |
+
|
8 |
+
class SentenceBERTClassifier(nn.Module, PyTorchModelHubMixin):
    """Classification model: a pretrained RoBERTa encoder, dropout, and a linear head.

    The encoder's own config object is reused as ``self.config`` and is
    annotated with the number of labels and the classifier dropout rate.
    """

    def __init__(self, model_name="sentence-transformers/all-distilroberta-v1", num_labels=8):
        """Load the encoder named by ``model_name`` and build the head.

        Args:
            model_name: Identifier handed to ``RobertaModel.from_pretrained``.
            num_labels: Output width of the linear classification layer.
        """
        super().__init__()

        # Single source for the rate used by both the layer and the config.
        drop_p = 0.05

        encoder = RobertaModel.from_pretrained(model_name)
        cfg = encoder.config
        cfg.num_labels = num_labels
        cfg.classifier_dropout = drop_p

        # Submodules are registered in the order: sbert, dropout, classifier.
        self.sbert = encoder
        self.config = cfg
        self.dropout = nn.Dropout(drop_p)
        self.classifier = nn.Linear(cfg.hidden_size, cfg.num_labels)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and return classification logits.

        Args:
            input_ids: Forwarded unchanged to the encoder as ``input_ids``.
            attention_mask: Forwarded unchanged to the encoder as
                ``attention_mask``.

        Returns:
            A ``SequenceClassifierOutput`` carrying the logits together with
            the encoder's ``hidden_states`` and ``attentions`` fields.
        """
        encoded = self.sbert(input_ids=input_ids, attention_mask=attention_mask)

        # The head consumes the encoder's pooled output, after dropout.
        logits = self.classifier(self.dropout(encoded.pooler_output))

        return SequenceClassifierOutput(
            logits=logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
|
tasks/text.py
CHANGED
@@ -10,12 +10,12 @@ from fastapi import APIRouter
|
|
10 |
from datasets import load_dataset
|
11 |
from sklearn.metrics import accuracy_score
|
12 |
from skops.io import load
|
13 |
-
from huggingface_hub import PyTorchModelHubMixin
|
14 |
|
15 |
from .utils.evaluation import TextEvaluationRequest
|
16 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
17 |
from .utils.text_preprocessor import preprocess
|
18 |
from accelerate.test_utils.testing import get_backend
|
|
|
19 |
|
20 |
router = APIRouter()
|
21 |
|
@@ -86,7 +86,12 @@ def bert_classifier(test_dataset: dict, model: str):
|
|
86 |
|
87 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
88 |
|
89 |
-
model
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
# Use CUDA if available
|
92 |
device, _, _ = get_backend()
|
|
|
10 |
from datasets import load_dataset
|
11 |
from sklearn.metrics import accuracy_score
|
12 |
from skops.io import load
|
|
|
13 |
|
14 |
from .utils.evaluation import TextEvaluationRequest
|
15 |
from .utils.emissions import tracker, clean_emissions_data, get_space_info
|
16 |
from .utils.text_preprocessor import preprocess
|
17 |
from accelerate.test_utils.testing import get_backend
|
18 |
+
# custom_classifiers.py sits next to this module inside the tasks package, so
# import it relatively, like the other package-local imports in this file.
from .custom_classifiers import SentenceBERTClassifier
|
19 |
|
20 |
router = APIRouter()
|
21 |
|
|
|
86 |
|
87 |
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
88 |
|
89 |
+
if model in ["bert_base_pruned"]:
|
90 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_repo)
|
91 |
+
elif model in ["sbert_distilroberta"]:
|
92 |
+
model = SentenceBERTClassifier.from_pretrained(model_repo)
|
93 |
+
else:
|
94 |
+
raise(ValueError)
|
95 |
|
96 |
# Use CUDA if available
|
97 |
device, _, _ = get_backend()
|