import torch
import torch.nn as nn


class TextEncoder(nn.Module):
    def __init__(self, number_of_features: int, number_of_heads: int, number_of_transformer_layers: int,
                 context_length: int, embed_dim: int):
        super().__init__()
        self.vocab_size = 32000  # AutoTokenizer: "koclip/koclip-base-pt"

        # Token embedding and learned positional embedding
        self.token_embedding = nn.Embedding(self.vocab_size, number_of_features)
        self.positional_embedding = nn.Parameter(torch.zeros(context_length, number_of_features))

        # Transformer encoder stack (batch_first: inputs are [batch, seq, features])
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=number_of_features, nhead=number_of_heads, batch_first=True),
            num_layers=number_of_transformer_layers,
        )

        # Projection from transformer width into the shared embedding space
        self.text_projection = nn.Linear(number_of_features, embed_dim)

        # Weight initialization
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.xavier_uniform_(self.positional_embedding)
        nn.init.kaiming_normal_(self.text_projection.weight, nonlinearity='relu')

    def forward(self, x):
        # x: [batch, seq] of token IDs. Assume the EOT token ID is 2; take the
        # first occurrence in each row as the end-of-text position.
        eot_token_idx = (x == 2).int().argmax(dim=-1)

        x = self.token_embedding(x)                       # [batch, seq, features]
        x = x + self.positional_embedding[:x.size(1), :]  # positions for the actual sequence length
        x = self.transformer(x)

        # Use the hidden state at the EOT position as the sequence representation
        x = x[torch.arange(x.shape[0]), eot_token_idx]
        x = self.text_projection(x)                       # [batch, embed_dim]
        return x
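
A minimal usage sketch. The hyperparameters below (512 features, 8 heads, 6 layers, context length 77, 256-dim embedding) and the dummy batch are illustrative assumptions, not values from the original; real token IDs would come from the "koclip/koclip-base-pt" tokenizer.

# Usage sketch with assumed hyperparameters and dummy token IDs.
encoder = TextEncoder(number_of_features=512, number_of_heads=8, number_of_transformer_layers=6,
                      context_length=77, embed_dim=256)

# Dummy batch of 4 sequences of length 16; token ID 2 (EOT) placed at the last position.
dummy_ids = torch.randint(3, 32000, (4, 16))
dummy_ids[:, -1] = 2

with torch.no_grad():
    text_features = encoder(dummy_ids)

print(text_features.shape)  # torch.Size([4, 256])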