File size: 1,749 Bytes
e90dd4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch.nn as nn
from torchcrf import CRF

class BERT_BiLSTM_CRF(nn.Module):
    """Token-classification model: encoder -> BiLSTM -> linear emissions -> CRF.

    Args:
        base_model: Encoder module. ``base_model.config.hidden_size`` sets the
            BiLSTM input width, and its output must expose ``last_hidden_state``.
        config: Object providing ``label2id``, ``id2label`` and ``num_labels``.
            ``label2id`` must contain an ``'O'`` entry when training with labels.
        dropout_rate: Dropout applied to the encoder output before the BiLSTM.
        rnn_dim: Hidden size per LSTM direction; the classifier consumes the
            concatenated ``2 * rnn_dim`` features.
        rnn_layers: Number of stacked LSTM layers (default 2, the previous
            hard-coded value).
        lstm_dropout: Dropout between stacked LSTM layers (default 0.2, the
            previous hard-coded value). Ignored when ``rnn_layers == 1``,
            because PyTorch's LSTM only applies it between layers.
    """

    def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256,
                 rnn_layers=2, lstm_dropout=0.2):
        super().__init__()
        self.bert = base_model
        # Label vocabulary is taken from the config object.
        self.label2id = config.label2id
        self.id2label = config.id2label
        self.num_labels = config.num_labels

        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            rnn_dim,
            num_layers=rnn_layers,
            batch_first=True,
            bidirectional=True,
            # Inter-layer dropout only exists for stacked LSTMs; a non-zero
            # value with a single layer does nothing except emit a warning.
            dropout=lstm_dropout if rnn_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Compute the CRF loss (with labels) or Viterbi tag paths (without).

        Args:
            input_ids: Token ids indexed as ``(batch, seq_len)``.
            attention_mask: Passed to the encoder and, cast to bool, used as
                the CRF mask (nonzero positions are scored/decoded).
            token_type_ids: Optional segment ids forwarded to the encoder.
            labels: Optional gold tag ids; ``-100`` marks ignored positions.

        Returns:
            With ``labels``: ``{'loss': loss, 'logits': emissions}`` where
            ``loss`` is the negated CRF log-likelihood (torchcrf
            ``reduction='mean'``) and ``emissions`` are the per-tag scores from
            the classifier.
            Without ``labels``: ``{'logits': paths}`` where ``paths`` is a long
            tensor of shape ``(batch, seq_len)`` holding decoded tag ids,
            right-padded with 0.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # NOTE(review): assumes last_hidden_state is (batch, seq_len, hidden_size),
        # matching the batch_first BiLSTM — confirm for the encoder in use.
        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
        emissions = self.classifier(lstm_out)
        # NOTE(review): torchcrf requires the first timestep to be unmasked for
        # every sequence; right-padded inputs satisfy this, left-padded do not.
        mask = attention_mask.bool()

        if labels is not None:
            # The CRF cannot score negative tag ids, so ignored (-100) positions
            # are rewritten to the 'O' id. Positions that are still inside the
            # attention mask (e.g. special tokens) therefore contribute to the
            # loss as 'O'.
            safe_labels = labels.clone()
            safe_labels[labels == -100] = self.label2id['O']
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}

        decoded = self.crf.decode(emissions, mask=mask)
        logits = self._pad_decoded(decoded, input_ids.shape[1], input_ids.device)
        return {'logits': logits}

    @staticmethod
    def _pad_decoded(paths, max_len, device, pad_id=0):
        """Right-pad variable-length tag paths into a ``(len(paths), max_len)`` long tensor.

        The explicit dtype and reshape keep the result 2-D and integer-typed
        even for an empty batch, where ``torch.tensor([])`` alone would give a
        1-D float tensor. ``pad_id`` may coincide with a real label id, so
        callers should rely on the attention mask to ignore padded positions.
        """
        rows = [list(path) + [pad_id] * (max_len - len(path)) for path in paths]
        return torch.tensor(rows, dtype=torch.long, device=device).reshape(len(paths), max_len)