from transformers import RobertaModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
from huggingface_hub import PyTorchModelHubMixin
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch


# Classifier that mixes in PyTorchModelHubMixin so it can be pushed to
# and loaded from the Hugging Face Hub like any built-in model.
class SentenceBERTClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, model_name="sentence-transformers/all-distilroberta-v1", num_labels=8):
        super().__init__()
        # Pretrained sentence-transformers backbone (a RoBERTa-architecture encoder)
        self.sbert = RobertaModel.from_pretrained(model_name)
        self.config = self.sbert.config
        self.config.num_labels = num_labels
        self.config.classifier_dropout = 0.05
        self.dropout = nn.Dropout(self.config.classifier_dropout)
        # Linear classification head over the pooled encoder output
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
        # Pooled representation of the first token (dense + tanh)
        pooled_output = outputs.pooler_output
        dropout_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_output)

        # Compute a cross-entropy loss when labels are passed, so the imported
        # CrossEntropyLoss is actually used and the model can be fine-tuned
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
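

# Minimal usage sketch: the input sentence is illustrative, and the checkpoint
# name and eight-way label space are just the class defaults from above.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
    model = SentenceBERTClassifier()
    model.eval()
    inputs = tokenizer("An example sentence to classify.", return_tensors="pt")
    with torch.no_grad():
        output = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # Pick the highest-scoring class index from the logits
    predicted_label = output.logits.argmax(dim=-1).item()
    print(f"Predicted label id: {predicted_label}")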