import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
import numpy as np

from clip.encoders.image_encoder import ImageEncoder
from clip.encoders.text_encoder import TextEncoder
from helper.tokenizer import Tokenizer


class CLIP(nn.Module):
    def __init__(self, config_path):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)

        self.image_encoder = ImageEncoder(**config["image_encoder"])
        self.text_encoder = TextEncoder(**config["text_encoder"])
        self.tokenizer = Tokenizer()
        # Learned temperature, initialized to 1/0.07 as in the original CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        # Initialize weights for linear, convolutional, and normalization layers.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def loss(self, image, text):
        image_features, text_features = self(image, text, tokenize=False)

        # Normalize features
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)

        # Cosine similarity as logits with learned temperature
        logits = torch.matmul(image_features, text_features.t()) * self.logit_scale.exp()
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)

        # Symmetric cross-entropy loss (image-to-text and text-to-image)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        return (loss_i2t + loss_t2i) / 2

    def text_encode(self, text, tokenize=True):
        if tokenize:
            tokens = self.tokenizer.tokenize(text)
        else:
            tokens = text
        text_features = self.text_encoder(tokens)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)
        return text_features

    def forward(self, image, text, tokenize=True):
        image_features = self.image_encoder(image)
        # Route text through the tokenizing helper so raw strings and
        # pre-tokenized input are both handled consistently.
        text_features = self.text_encode(text, tokenize)
        if image_features.dim() < 2:
            image_features = image_features.unsqueeze(0)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)
        return image_features, text_features
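

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not guaranteed by this file: a YAML config
# at "configs/clip.yaml" with "image_encoder" and "text_encoder" sections that
# match the encoder constructors, 224x224 RGB inputs, and a Tokenizer whose
# tokenize() accepts a list of strings. Shapes and the config path are
# illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = CLIP("configs/clip.yaml")  # hypothetical config path

    # Dummy batch: 4 images and 4 matching captions.
    images = torch.randn(4, 3, 224, 224)
    captions = [
        "a photo of a dog",
        "a photo of a cat",
        "a red sports car",
        "a bowl of fruit",
    ]
    tokens = model.tokenizer.tokenize(captions)

    # Contrastive loss over the batch; loss() expects pre-tokenized text
    # because it calls forward() with tokenize=False.
    loss = model.loss(images, tokens)
    loss.backward()
    print(f"contrastive loss: {loss.item():.4f}")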