File size: 1,354 Bytes
1c18be2
e3a1adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
from torchcrf import CRF

class BERT_BiLSTM_CRF(nn.Module):
    """Token tagger: encoder -> 2-layer BiLSTM -> linear emissions -> CRF.

    The encoder passed as ``base_model`` must expose ``config.hidden_size``
    and return an object with a ``last_hidden_state`` attribute when called
    with ``input_ids`` and ``attention_mask`` keyword arguments.
    """

    # Label id that marks positions without a real tag. It is replaced by 0
    # before labels are handed to the CRF, and it is also used to pad decoded
    # tag paths up to the full sequence length.
    IGNORE_INDEX = -100

    def __init__(self, base_model, num_labels, rnn_dim=256, dropout_rate=0.2):
        """Build the tagging head on top of ``base_model``.

        Args:
            base_model: Encoder module; its ``config.hidden_size`` sets the
                LSTM input size.
            num_labels: Number of tags scored by the classifier and the CRF.
            rnn_dim: Hidden size of each LSTM direction; the classifier
                input is ``rnn_dim * 2`` because the LSTM is bidirectional.
            dropout_rate: Dropout applied to the encoder output and between
                the two LSTM layers.
        """
        super().__init__()
        self.bert = base_model
        self.bilstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(rnn_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    @staticmethod
    def _pad_decoded(paths, seq_len, pad_value):
        """Stack variable-length tag paths into a (len(paths), seq_len) int64 tensor."""
        # Built on CPU, exactly like the torch.tensor(list) call this replaces,
        # so callers that convert the result straight to NumPy keep working.
        tags = torch.full((len(paths), seq_len), pad_value, dtype=torch.long)
        for row, path in enumerate(paths):
            tags[row, :len(path)] = torch.tensor(path, dtype=torch.long)
        return tags

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Compute the CRF loss (with labels) or the best tag paths (without).

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Forwarded to the encoder and, converted to bool,
                used as the CRF mask.
            token_type_ids: Accepted for call compatibility only; it is NOT
                forwarded to the encoder.
            labels: Optional tag ids. Entries equal to ``IGNORE_INDEX`` are
                replaced by 0 before scoring; positions that remain unmasked
                still contribute to the loss with tag 0.

        Returns:
            With ``labels``: ``{'loss': negated CRF score, 'logits': emissions}``.
            Without ``labels``: ``{'logits': tags}`` where ``tags`` is an int64
            tensor of decoded tag ids with one row per sequence, right-padded
            with ``IGNORE_INDEX`` to the emission sequence length.
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.bilstm(self.dropout(hidden))
        emissions = self.classifier(lstm_out)
        mask = attention_mask.bool()

        if labels is not None:
            # masked_fill returns a new tensor, so the caller's labels are
            # left untouched; the CRF needs every tag id to be a valid index.
            safe_labels = labels.masked_fill(labels == self.IGNORE_INDEX, 0)
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}

        # decode() yields one list per sequence whose length is the number of
        # unmasked positions, so padded batches give ragged lists that
        # torch.tensor() cannot stack directly. Pad them to a rectangle first.
        decoded = self.crf.decode(emissions, mask=mask)
        tags = self._pad_decoded(decoded, emissions.shape[1], self.IGNORE_INDEX)
        return {'logits': tags}