import torch
import torch.nn as nn
from torchcrf import CRF


class BERT_BiLSTM_CRF(nn.Module):
    """BERT encoder -> BiLSTM -> linear emission layer -> CRF."""

    def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
        super().__init__()
        self.bert = base_model
        self.label2id = config.label2id  # label maps pulled from the HF config
        self.id2label = config.id2label
        self.num_labels = config.num_labels
        # Bidirectional LSTM over the contextual BERT embeddings.
        self.bilstm = nn.LSTM(
            self.bert.config.hidden_size,
            rnn_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout_rate,  # was hard-coded 0.2; reuse the constructor arg
        )
        self.dropout = nn.Dropout(dropout_rate)
        # Per-token emission scores; *2 because the LSTM is bidirectional.
        self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
        emissions = self.classifier(lstm_out)  # (batch, seq_len, num_labels)
        mask = attention_mask.bool()
        if labels is not None:
            # The CRF cannot index the -100 ignore positions, so map them to
            # 'O'; they still fall under the attention mask and contribute to
            # the loss as 'O' tokens.
            safe_labels = labels.clone()
            safe_labels[labels == -100] = self.label2id['O']
            # torchcrf returns the log-likelihood; negate it for a loss.
            loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
            return {'loss': loss, 'logits': emissions}
        else:
            # Viterbi decoding yields variable-length tag lists; pad them back
            # to the input length so callers get a rectangular tensor.
            decoded = self.crf.decode(emissions, mask=mask)
            max_len = input_ids.shape[1]
            padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
            # Note: these are decoded label ids, not raw scores, despite the
            # 'logits' key (kept for Trainer/pipeline compatibility).
            logits = torch.tensor(padded_decoded, device=input_ids.device)
            return {'logits': logits}
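

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal end-to-end check, assuming a BERT-style encoder from Hugging Face
# `transformers`. The model name and the tag set below are placeholders.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    model_name = "bert-base-cased"    # hypothetical encoder choice
    labels = ["O", "B-PER", "I-PER"]  # hypothetical tag set; must include 'O'
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=len(labels),
        id2label=dict(enumerate(labels)),
        label2id={l: i for i, l in enumerate(labels)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name, config=config)
    model = BERT_BiLSTM_CRF(encoder, config)
    model.eval()  # disable dropout for deterministic decoding

    batch = tokenizer(["Ada Lovelace wrote the first program."], return_tensors="pt")
    with torch.no_grad():
        out = model(batch["input_ids"], batch["attention_mask"],
                    batch.get("token_type_ids"))
    print(out["logits"])  # decoded label ids, padded to the input length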