# Source: Hugging Face file view — author AkashDataScience,
# commit dfae4e0 ("First commit"), 791 bytes.
import torch.nn as nn
class Projections(nn.Module):
    """Project ``clip_embed``-wide features to ``phi_embed`` width, then refine them.

    The names suggest the input is a CLIP image embedding and the output width
    is the Phi language-model hidden size. That is inferred from the parameter
    names only — confirm with callers.

    Pipeline, applied over the last dimension of the input:
        1. ``output``: ``nn.Linear(clip_embed, phi_embed)`` changes the width.
        2. ``norm``: a single ``nn.LayerNorm(phi_embed)``. It is applied once
           here and not inside the residual blocks.
        3. ``projection_layers``: ``num_projection_layers`` residual blocks.
           Each computes ``h + Linear(GELU(Linear(h)))`` at constant width
           ``phi_embed``.

    Args:
        clip_embed: Size of the last dimension of the input tensor.
        phi_embed: Size of the last dimension of the output tensor.
        num_projection_layers: Number of residual blocks (default 6). With 0,
            the module reduces to ``norm(output(x))``.
    """

    def __init__(self, clip_embed, phi_embed, num_projection_layers=6):
        super().__init__()
        # Creation order (output, norm, then each block) is kept as-is so that
        # seeded parameter initialisation stays reproducible.
        self.output = nn.Linear(clip_embed, phi_embed)
        self.norm = nn.LayerNorm(phi_embed)
        self.projection_layers = nn.ModuleList(
            self._residual_branch(phi_embed) for _ in range(num_projection_layers)
        )

    @staticmethod
    def _residual_branch(width):
        """Return one Linear -> GELU -> Linear branch that keeps ``width``.

        The Sequential layout (indices 0, 1, 2) determines the state_dict keys
        ``projection_layers.<i>.0.*`` and ``projection_layers.<i>.2.*``.
        """
        return nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )

    def forward(self, x):
        """Apply projection, normalisation, and the residual blocks to ``x``.

        Args:
            x: Tensor whose last dimension is ``clip_embed``. Leading
                dimensions are passed through unchanged by Linear/LayerNorm.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``phi_embed``.
        """
        hidden = self.norm(self.output(x))
        for branch in self.projection_layers:
            # Residual connection around each MLP branch.
            hidden = hidden + branch(hidden)
        return hidden