File size: 3,619 Bytes
4fb0bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging

import torch
import torch.nn as nn
import numpy as np

from models.embedding_models.bert_embedding_model import BertEmbedModel
from models.embedding_models.pretrained_embedding_model import PretrainedEmbedModel
from modules.token_embedders.bert_encoder import BertLinear
from modules.token_embedders.bert_encoder import BertLayerNorm
from transformers import BertModel

logger = logging.getLogger(__name__)


class RelDecoder(nn.Module):
    """Classify (relation token, argument token) pairs into ``num_labels`` classes.

    For every pair, the encoder representations of the two referenced tokens
    are concatenated, layer-normalised, optionally passed through dropout and
    fed to a linear classifier.
    """

    def __init__(self, cfg, vocab, ent_rel_file):
        """Build the embedding model, the pair classifier and the loss.

        Args:
            cfg: config object; ``max_span_length``, ``device`` and
                ``logit_dropout`` are read here, and the object is also
                forwarded to ``BertEmbedModel``.
            vocab (Vocabulary): vocabulary; must provide ``get_token_index``.
            ent_rel_file (dict): entity and relation file; only its
                ``"relation"`` entry is read by this constructor.
        """

        super().__init__()

        self.num_labels = 3
        self.vocab = vocab
        self.max_span_length = cfg.max_span_length
        self.device = cfg.device

        # The BERT embedding model is always used; selecting a different
        # embedding model through the config is not supported here.
        self.embedding_model = BertEmbedModel(cfg, vocab, True)
        self.encoder_output_size = self.embedding_model.get_hidden_size()

        # Each classifier input is the concatenation of two token
        # representations, hence twice the encoder hidden size.
        pair_repr_size = self.encoder_output_size * 2
        self.layer_norm = BertLayerNorm(pair_repr_size)
        self.classifier = nn.Linear(pair_repr_size, self.num_labels)
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        self.classifier.bias.data.zero_()

        # nn.Identity (rather than a lambda) keeps the module picklable when
        # dropout is disabled; it holds no parameters.
        if cfg.logit_dropout > 0:
            self.dropout = nn.Dropout(p=cfg.logit_dropout)
        else:
            self.dropout = nn.Identity()

        self.none_idx = self.vocab.get_token_index('None', 'ent_rel_id')

        self.rel_label = torch.LongTensor(ent_rel_file["relation"])
        # A non-negative ``cfg.device`` is treated as a CUDA device index.
        if self.device > -1:
            self.rel_label = self.rel_label.cuda(device=self.device, non_blocking=True)

        self.element_loss = nn.CrossEntropyLoss()

    def forward(self, batch_inputs):
        """Score relation/argument token pairs and return predictions or loss.

        Arguments:
            batch_inputs {dict} -- batch input data. Read here:
                ``relation_ids`` and ``argument_ids`` (token positions used
                to index the encoder output), ``label_ids_mask`` and, in
                training mode, ``label_ids``. ``seq_encoder_reprs`` is read
                after calling the embedding model, which is expected to
                store it in this dict.

        Returns:
            dict -- ``{'label_preds': ...}`` in evaluation mode, or
            ``{'loss': ...}`` in training mode.
        """

        results = {}

        self.embedding_model(batch_inputs)
        encoder_reprs = batch_inputs['seq_encoder_reprs']

        # Column vector of batch indices, built once on the encoder output's
        # device; it broadcasts against the per-example position ids so that
        # row i of the ids selects tokens from sequence i.
        batch_index = torch.arange(
            encoder_reprs.shape[0], device=encoder_reprs.device).unsqueeze(-1)
        relation_tokens = encoder_reprs[batch_index, batch_inputs["relation_ids"]]
        argument_tokens = encoder_reprs[batch_index, batch_inputs["argument_ids"]]

        pair_reprs = torch.cat((relation_tokens, argument_tokens), dim=-1)
        pair_reprs = self.layer_norm(pair_reprs)
        pair_reprs = self.dropout(pair_reprs)
        batch_logits = self.classifier(pair_reprs)

        label_mask = batch_inputs['label_ids_mask']

        if not self.training:
            # Multiplying by the mask zeroes the prediction wherever the
            # mask is 0/False.
            results['label_preds'] = torch.argmax(batch_logits, dim=-1) * label_mask
            return results

        # Only positions selected by the mask contribute to the loss.
        results['loss'] = self.element_loss(
            batch_logits[label_mask],
            batch_inputs['label_ids'][label_mask])

        return results