# absa-app / model.py
# Author: asmashayea — revision 1c18be2 (message: "t"), 1.35 kB.
import torch
import torch.nn as nn
from torchcrf import CRF
class BERT_BiLSTM_CRF(nn.Module):
    """Sequence tagger: encoder -> dropout -> 2-layer BiLSTM -> linear -> CRF.

    The encoder's ``last_hidden_state`` is passed through dropout and a
    two-layer bidirectional LSTM. A linear layer maps every LSTM output to
    per-tag emission scores, and a ``torchcrf.CRF`` layer either scores the
    gold tag sequence (training) or Viterbi-decodes the best one (inference).

    Args:
        base_model: Encoder module. It must expose ``config.hidden_size`` and,
            when called with ``input_ids``/``attention_mask``, return an object
            with a ``last_hidden_state`` attribute (presumably a Hugging Face
            transformer — confirm against the caller).
        num_labels: Number of tags scored by the classifier and the CRF.
        rnn_dim: Hidden size of each LSTM direction; the classifier therefore
            receives ``2 * rnn_dim`` features per token.
        dropout_rate: Dropout applied to the encoder output and between the
            two LSTM layers.
    """

    def __init__(self, base_model, num_labels, rnn_dim=256, dropout_rate=0.2):
        super().__init__()
        self.bert = base_model
        self.bilstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate,  # only acts between the stacked LSTM layers
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(rnn_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    @staticmethod
    def _pad_decoded(decoded, pad_value=-100):
        """Stack ragged Viterbi paths into one ``(batch, max_len)`` LongTensor.

        ``max_len`` is the length of the longest path. Shorter paths are
        right-padded with ``pad_value``. For a non-empty batch whose paths
        all have the same length, the result equals ``torch.tensor(decoded)``.
        The tensor is created on the CPU.

        Args:
            decoded: List of per-sequence tag-index lists.
            pad_value: Fill value for positions past a path's end.

        Returns:
            A ``torch.long`` tensor of shape ``(len(decoded), max_len)``.
        """
        max_len = max((len(path) for path in decoded), default=0)
        padded = torch.full((len(decoded), max_len), pad_value, dtype=torch.long)
        for row, path in enumerate(decoded):
            padded[row, : len(path)] = torch.as_tensor(path, dtype=torch.long)
        return padded

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Compute the CRF loss (when ``labels`` is given) or decode tags.

        Args:
            input_ids: Token ids, batch-first (assumed ``(batch, seq_len)``).
            attention_mask: 1 for real tokens, 0 for padding. It is used both
                by the encoder and as the CRF mask. NOTE(review): pytorch-crf
                rejects masks whose first timestep is off, so left-padded
                batches are presumably unsupported — confirm.
            token_type_ids: Accepted for call-site compatibility (e.g. passing
                a tokenizer's full output) but NOT forwarded to the encoder.
            labels: Optional gold tag ids with the same shape as
                ``input_ids``. Entries equal to -100 are replaced by tag 0
                before scoring. Where such an entry sits under an active mask
                position (e.g. special or sub-word tokens), the CRF is
                therefore trained towards tag 0 there.

        Returns:
            With ``labels``: ``{'loss': ..., 'logits': emissions}``. Here
            ``loss`` is the negated CRF log-likelihood using
            ``reduction='mean'``, and ``emissions`` are the per-token tag
            scores of shape ``(batch, seq_len, num_labels)``.

            Without ``labels``: ``{'logits': tags}``. ``tags`` is a CPU
            LongTensor of Viterbi tag indices of shape ``(batch, max_len)``,
            where ``max_len`` is the longest unmasked length in the batch.
            Shorter rows are right-padded with -100.
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.bilstm(self.dropout(hidden))
        emissions = self.classifier(lstm_out)
        mask = attention_mask.bool()

        if labels is not None:
            # The CRF indexes its transition matrix with the tags, so -100
            # ("ignore" in the HF convention) must become a valid tag id.
            safe_labels = labels.masked_fill(labels == -100, 0)
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}

        # decode() returns one list per sequence, sized by that sequence's
        # mask. Lengths differ across a padded batch, so pad before stacking.
        decoded = self.crf.decode(emissions, mask=mask)
        return {'logits': self._pad_decoded(decoded)}