import logging

import torch
import torch.nn as nn

from transformers import AutoModel

from utils.nn_utils import gelu
from modules.token_embedders.bert_encoder import BertLinear

logger = logging.getLogger(__name__)


class PretrainedEncoder(nn.Module):
    """This class using pre-trained model to encode token,
    then fine-tuning the pre-trained model
    """
    def __init__(self, pretrained_model_name, trainable=False, output_size=0, activation=gelu, dropout=0.0):
        """This function initialize pertrained model

        Arguments:
            pretrained_model_name {str} -- pre-trained model name

        Keyword Arguments:
            output_size {float} -- output size (default: {None})
            activation {nn.Module} -- activation function (default: {gelu})
            dropout {float} -- dropout rate (default: {0.0})
        """

        super().__init__()
        self.pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
        logger.info("Load pre-trained model {} successfully.".format(pretrained_model_name))

        self.output_size = output_size

        if trainable:
            logger.info("Fine-tuning pre-trained model {}.".format(pretrained_model_name))
        else:
            logger.info("Keeping pre-trained model {} frozen.".format(pretrained_model_name))

        # Freeze or unfreeze all pre-trained parameters according to `trainable`.
        for param in self.pretrained_model.parameters():
            param.requires_grad = trainable

        if self.output_size > 0:
            self.mlp = BertLinear(input_size=self.pretrained_model.config.hidden_size,
                                  output_size=self.output_size,
                                  activation=activation)
        else:
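            # No projection layer: keep the pre-trained hidden size and pass features through unchanged.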
            self.output_size = self.pretrained_model.config.hidden_size
            self.mlp = lambda x: x

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = lambda x: x

    def get_output_dims(self):
        return self.output_size

    def forward(self, seq_inputs, token_type_inputs=None):
        """forward calculates forward propagation results, get token embedding

        Args:
            seq_inputs {tensor} -- sequence inputs (tokenized)
            token_type_inputs (tensor, optional): token type inputs. Defaults to None.

        Returns:
            tensor: bert output for tokens
        """

        if token_type_inputs is None:
            token_type_inputs = torch.zeros_like(seq_inputs)
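        # Build the attention mask, assuming padding positions have token id 0.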
        mask_inputs = (seq_inputs != 0).long()

        outputs = self.pretrained_model(input_ids=seq_inputs,
                                        token_type_ids=token_type_inputs,
                                        attention_mask=mask_inputs)
        last_hidden_state = outputs[0]  # token-level representations: (batch, seq_len, hidden_size)
        pooled_output = outputs[1]  # pooled representation: (batch, hidden_size)

        return self.dropout(self.mlp(last_hidden_state)), self.dropout(self.mlp(pooled_output))
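

# Minimal usage sketch (not part of the module): assumes a BERT-style
# checkpoint such as "bert-base-uncased" and the standard `transformers`
# AutoTokenizer, and is shown only to illustrate the expected inputs/outputs.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = PretrainedEncoder(model_name, trainable=False, dropout=0.1)

    batch = tokenizer(["An example sentence."], return_tensors="pt", padding=True)
    with torch.no_grad():
        token_repr, pooled_repr = encoder(batch["input_ids"], batch.get("token_type_ids"))

    # token_repr: (batch, seq_len, hidden_size); pooled_repr: (batch, hidden_size)
    print(token_repr.shape, pooled_repr.shape, encoder.get_output_dims())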