|
import torch.nn as nn |
|
|
|
class Projections(nn.Module):
    """Project input features to width ``phi_embed`` and refine them with residual MLPs.

    The last dimension of the input is mapped from ``clip_embed`` to
    ``phi_embed`` by a linear layer and then layer-normalised. The result then
    passes through ``num_projection_layers`` blocks. Each block computes
    ``x + Linear(GELU(Linear(x)))``.

    Args:
        clip_embed: Size of the input's last dimension.
        phi_embed: Size of the output's last dimension. Also the hidden width
            of every residual block.
        num_projection_layers: Number of residual MLP blocks stacked after
            the normalisation.
    """

    def __init__(self, clip_embed, phi_embed, num_projection_layers=6):
        super().__init__()
        # Keep this creation order (linear, norm, blocks). It fixes the order
        # in which parameters are randomly initialised. It also fixes the
        # state_dict layout: output.*, norm.*, projection_layers.<i>.<j>.*
        self.output = nn.Linear(clip_embed, phi_embed)
        self.norm = nn.LayerNorm(phi_embed)

        blocks = nn.ModuleList()
        for _ in range(num_projection_layers):
            blocks.append(
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
            )
        self.projection_layers = blocks

    def forward(self, x):
        """Project ``x``, normalise it, and apply each residual block in turn.

        Args:
            x: Tensor whose last dimension has size ``clip_embed``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``phi_embed``.
        """
        hidden = self.norm(self.output(x))
        for block in self.projection_layers:
            # Skip connection: each block adds a refinement to its own input.
            hidden = block(hidden) + hidden
        return hidden