KoFace-AI / clip /encoders /text_encoder.py
JuyeopDang's picture
Upload 35 files
5ab5cab verified
from typing import Optional

import torch
import torch.nn as nn
class TextEncoder(nn.Module):
    """Transformer text encoder that maps token-ID sequences to embeddings.

    Token IDs are embedded and summed with a learned positional embedding.
    They are then passed through a non-causal ``nn.TransformerEncoder``
    (``batch_first=True``). The hidden state at each sequence's
    end-of-text (EOT) token is linearly projected to ``embed_dim``.

    Args:
        number_of_features: Model width (``d_model``) of the transformer.
        number_of_heads: Number of attention heads per encoder layer.
        number_of_transformer_layers: Number of stacked encoder layers.
        context_length: Maximum supported sequence length (rows of the
            positional embedding table).
        embed_dim: Size of the projected output embedding.
        vocab_size: Rows in the token embedding table. The default of 32000
            is the value previously hard-coded for the
            ``koclip/koclip-base-pt`` tokenizer.
        eot_token_id: Token ID whose position selects the pooled hidden
            state. The default of 2 is the original code's assumption;
            confirm it against the tokenizer in use.
    """

    def __init__(
        self,
        number_of_features: int,
        number_of_heads: int,
        number_of_transformer_layers: int,
        context_length: int,
        embed_dim: int,
        vocab_size: int = 32000,
        eot_token_id: int = 2,
    ):
        super().__init__()
        self.vocab_size = vocab_size  # AutoTokenizer: "koclip/koclip-base-pt"
        self.eot_token_id = eot_token_id
        # Submodules are created in the original order so parameter names and
        # seeded initialization stay identical.
        self.token_embedding = nn.Embedding(self.vocab_size, number_of_features)
        self.positional_embedding = nn.Parameter(
            torch.zeros(context_length, number_of_features)
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=number_of_features, nhead=number_of_heads, batch_first=True
            ),
            num_layers=number_of_transformer_layers,
        )
        self.text_projection = nn.Linear(number_of_features, embed_dim)

        # Initialization.
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.xavier_uniform_(self.positional_embedding)
        # NOTE(review): the ReLU gain assumes a ReLU follows this layer, but the
        # projection output is returned directly. Kept to preserve the original
        # initialization.
        nn.init.kaiming_normal_(self.text_projection.weight, nonlinearity='relu')

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode a batch of token-ID sequences.

        Args:
            x: Integer token IDs of shape ``(batch, seq_len)``. Every row must
                contain at least one ``eot_token_id``.
            attention_mask: Optional ``(batch, seq_len)`` mask. Non-zero marks
                real tokens; 0 marks padding that attention should ignore
                (Hugging Face convention). ``None`` lets every position attend
                to every other, matching the previous behavior.

        Returns:
            Tensor of shape ``(batch, embed_dim)``: the projected hidden state
            at the first EOT token of each row.

        Raises:
            ValueError: If ``seq_len`` exceeds ``context_length`` or a row
                contains no EOT token.
        """
        seq_len = x.size(1)
        max_len = self.positional_embedding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {max_len}"
            )

        is_eot = x == self.eot_token_id
        if not bool(is_eot.any(dim=1).all()):
            raise ValueError(
                f"every sequence must contain the EOT token id {self.eot_token_id}"
            )
        # argmax returns the first maximal index. This selects exactly one
        # position per row (the first EOT), even when the EOT id repeats.
        eot_positions = is_eot.int().argmax(dim=1)

        # nn.TransformerEncoder expects True where a key should be ignored.
        padding_mask = None if attention_mask is None else attention_mask == 0

        hidden = self.token_embedding(x)
        hidden = hidden + self.positional_embedding[:seq_len, :]
        hidden = self.transformer(hidden, src_key_padding_mask=padding_mask)

        batch_index = torch.arange(hidden.size(0), device=hidden.device)
        pooled = hidden[batch_index, eot_positions]
        return self.text_projection(pooled)