import torch
import torch.nn as nn
from torchcrf import CRF  # pip install pytorch-crf


class BERT_BiLSTM_CRF(nn.Module):
    """BERT encoder -> BiLSTM -> linear emission layer -> CRF."""

    def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
        super().__init__()
        self.bert = base_model
        self.label2id = config.label2id  # <-- pulled from config
        self.id2label = config.id2label
        self.num_labels = config.num_labels
        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate,  # inter-layer LSTM dropout (was hardcoded 0.2)
        )
        self.dropout = nn.Dropout(dropout_rate)
        # The bidirectional LSTM concatenates both directions, hence rnn_dim * 2.
        self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
        emissions = self.classifier(lstm_out)  # (batch, seq_len, num_labels)
        mask = attention_mask.bool()
        if labels is not None:
            # The CRF cannot score the -100 ignore index, so remap those
            # positions (special tokens, subword continuations) to 'O'.
            # They still fall inside the attention mask, so this is a
            # deliberate approximation, not a no-op.
            safe_labels = labels.clone()
            safe_labels[labels == -100] = self.label2id['O']
            # torchcrf returns the log-likelihood; negate it for a loss.
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}
        else:
            # Inference: Viterbi decoding returns variable-length tag lists,
            # so pad each back to the input length before stacking.
            decoded = self.crf.decode(emissions, mask=mask)
            max_len = input_ids.shape[1]
            padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
            logits = torch.tensor(padded_decoded, device=input_ids.device)
            return {'logits': logits}
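

# --- Usage sketch (illustrative, not from the original source) ---
# Assumes the Hugging Face `transformers` package; the checkpoint name
# 'bert-base-cased' and the tag set below are placeholder assumptions
# chosen only to make the example self-contained and runnable.
if __name__ == '__main__':
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    labels = ['O', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']  # hypothetical tag set
    config = AutoConfig.from_pretrained(
        'bert-base-cased',
        id2label=dict(enumerate(labels)),
        label2id={label: i for i, label in enumerate(labels)},
    )
    base_model = AutoModel.from_pretrained('bert-base-cased')
    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

    model = BERT_BiLSTM_CRF(base_model, config)
    model.eval()  # disable dropout for deterministic decoding

    batch = tokenizer(['Alice went to Paris'], return_tensors='pt')
    with torch.no_grad():
        out = model(**batch)  # no labels -> Viterbi tag ids in out['logits']
    print(out['logits'])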