import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    def __init__(self, number_of_features: int, number_of_heads: int, number_of_transformer_layers: int,
                 context_length: int, embed_dim: int):
        super().__init__()
        self.vocab_size = 32000  # vocabulary size of AutoTokenizer "koclip/koclip-base-pt"
        self.token_embedding = nn.Embedding(self.vocab_size, number_of_features)
        # Learned positional embedding, one vector per position up to context_length
        self.positional_embedding = nn.Parameter(torch.zeros(context_length, number_of_features))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=number_of_features, nhead=number_of_heads, batch_first=True),
            num_layers=number_of_transformer_layers
        )
        # Projects the pooled token representation to the output embedding dimension (embed_dim)
        self.text_projection = nn.Linear(number_of_features, embed_dim)

        # initialize
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.xavier_uniform_(self.positional_embedding)
        nn.init.kaiming_normal_(self.text_projection.weight, nonlinearity='relu')

    def forward(self, x):
        # x: (batch, seq_len) token IDs. Each sequence is assumed to contain
        # exactly one EOT token (ID 2); its position is used for pooling.
        eot_token_idx = (x == 2).nonzero(as_tuple=True)[1]
        x = self.token_embedding(x)                         # (batch, seq_len, number_of_features)
        x = x + self.positional_embedding[:x.size(1), :]    # add positional information
        x = self.transformer(x)                             # contextualize tokens
        x = x[torch.arange(x.shape[0]), eot_token_idx]      # pool: hidden state at the EOT position
        x = self.text_projection(x)                         # (batch, embed_dim)
        return x
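

# Minimal usage sketch with illustrative, hypothetical hyperparameters:
# runs a dummy batch through TextEncoder to check the output shape.
if __name__ == "__main__":
    encoder = TextEncoder(
        number_of_features=512,
        number_of_heads=8,
        number_of_transformer_layers=6,
        context_length=77,
        embed_dim=256,
    )
    # Dummy batch of 4 sequences, 16 tokens each; token IDs start at 3 so the
    # only occurrence of the EOT token (ID 2) is the one placed at the end.
    tokens = torch.randint(3, 32000, (4, 16))
    tokens[:, -1] = 2
    out = encoder(tokens)
    print(out.shape)  # torch.Size([4, 256])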