import torch
import torch.nn as nn
from torchcrf import CRF  # pip install pytorch-crf

class BERT_BiLSTM_CRF(nn.Module):
    def __init__(self, base_model, num_labels, rnn_dim=256, dropout_rate=0.2):
        super().__init__()
        self.bert = base_model
        # BiLSTM over BERT's token representations; bidirectional, so each
        # token ends up with a 2 * rnn_dim feature vector.
        self.bilstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(rnn_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        # Forward token_type_ids only when provided; not every base model
        # accepts them.
        bert_kwargs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        if token_type_ids is not None:
            bert_kwargs['token_type_ids'] = token_type_ids
        bert_output = self.bert(**bert_kwargs).last_hidden_state
        lstm_out, _ = self.bilstm(self.dropout(bert_output))
        emissions = self.classifier(lstm_out)
        mask = attention_mask.bool()
        if labels is not None:
            # The CRF cannot score the -100 ignore index, so map those
            # positions to label 0. They remain inside the attention mask and
            # are scored as label 0 -- a common workaround for subword and
            # special-token positions.
            safe_labels = labels.clone()
            safe_labels[labels == -100] = 0
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}
        # crf.decode returns one variable-length Python list per sequence
        # (truncated at the mask), so pad them before stacking into a tensor.
        decoded = self.crf.decode(emissions, mask=mask)
        max_len = emissions.size(1)
        padded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
        return {'logits': torch.tensor(padded, device=emissions.device)}
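
A minimal usage sketch, assuming the Hugging Face transformers AutoModel/AutoTokenizer API and a bert-base-cased checkpoint; the checkpoint name, the sample sentence, and the label count of 9 (e.g. CoNLL-2003 BIO tags) are illustrative assumptions, not taken from the code above:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # assumed checkpoint
base = AutoModel.from_pretrained("bert-base-cased")
model = BERT_BiLSTM_CRF(base, num_labels=9)  # 9 labels is illustrative

# Without labels, forward() runs CRF Viterbi decoding instead of computing a loss.
enc = tokenizer(["Alice works at Acme Corp."], return_tensors="pt")
model.eval()
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out["logits"])  # padded tensor of predicted label indices, one row per sequence

Note that in inference mode "logits" holds decoded label indices rather than emission scores, mirroring the class's own return convention; the emissions themselves are only returned from the training branch.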