import torch.nn as nn


class Projections(nn.Module):
    """Project embeddings to a new width and refine them with residual MLP blocks.

    The input is mapped from ``clip_embed`` to ``phi_embed`` features by a
    single linear layer, layer-normalized, and then passed through
    ``num_projection_layers`` blocks of ``Linear -> GELU -> Linear``, each
    wrapped in a skip connection.

    NOTE(review): the parameter names suggest CLIP features being mapped into
    a Phi model's embedding space -- confirm against the caller.

    Args:
        clip_embed: Size of the last dimension of the input tensor.
        phi_embed: Size of the last dimension of the output tensor; also the
            width of every residual block.
        num_projection_layers: Number of residual blocks applied after the
            normalization. With ``0`` the module reduces to linear + norm.
    """

    def __init__(
        self,
        clip_embed: int,
        phi_embed: int,
        num_projection_layers: int = 6,
    ) -> None:
        super().__init__()

        # Attribute names and creation order are kept as-is so that
        # state_dict keys and parameter ordering stay the same.
        self.output = nn.Linear(clip_embed, phi_embed)
        self.norm = nn.LayerNorm(phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        """Return the projected, normalized and residually refined tensor.

        Args:
            x: Tensor whose last dimension has ``clip_embed`` features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``phi_embed`` features.
        """
        x = self.norm(self.output(x))
        for block in self.projection_layers:
            # Skip connection: each block adds its output to its own input.
            x = x + block(x)
        return x