# submission-template/tasks/custom_classifiers.py
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoTokenizer, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


class SentenceBERTClassifier(nn.Module, PyTorchModelHubMixin):
    """RoBERTa-based sentence classifier with a linear head, shareable via the Hub mixin."""

    def __init__(self, model_name="sentence-transformers/all-distilroberta-v1", num_labels=8):
        super().__init__()
        # Pretrained Sentence-BERT backbone (RoBERTa architecture).
        self.sbert = RobertaModel.from_pretrained(model_name)
        self.config = self.sbert.config
        self.config.num_labels = num_labels
        # Dropout applied to the pooled representation before the classification head.
        self.dropout = nn.Dropout(0.05)
        self.config.classifier_dropout = 0.05
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
        # Pooled output: first-token hidden state passed through the backbone's pooler.
        pooled_output = outputs.pooler_output
        dropout_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_output)
        return SequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )