import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
import numpy as np

from clip.encoders.image_encoder import ImageEncoder
from clip.encoders.text_encoder import TextEncoder
from helper.tokenizer import Tokenizer
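
# The YAML config is expected to provide "image_encoder" and "text_encoder"
# sections whose keys are expanded into the encoder constructors. A minimal
# sketch (the inner keys are illustrative only; they must match the actual
# ImageEncoder/TextEncoder signatures):
#
#   image_encoder:
#     embed_dim: 512
#     ...
#   text_encoder:
#     embed_dim: 512
#     ...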

class CLIP(nn.Module):
    """CLIP-style dual-encoder: an image encoder and a text encoder trained
    contrastively, with a learned temperature (logit_scale)."""

    def __init__(self, config_path):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)

        self.image_encoder = ImageEncoder(**config["image_encoder"])
        self.text_encoder = TextEncoder(**config["text_encoder"])
        self.tokenizer = Tokenizer()
        # Learnable temperature, initialized so logit_scale.exp() == 1/0.07 (as in the CLIP paper)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        
        # Initialize weights: Xavier for linear layers, Kaiming for convolutions,
        # unit scale / zero shift for normalization layers
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                
    def loss(self, image, text):
        """Symmetric contrastive (InfoNCE) loss over a batch of matched
        image/text pairs. `text` must already be tokenized, since forward
        is called with tokenize=False."""
        image_features, text_features = self(image, text, tokenize=False)

        # Normalize features
        image_features = F.normalize(image_features, dim=1)
        text_features = F.normalize(text_features, dim=1)

        # Cosine similarity as logits with learned temperature
        logits = torch.matmul(image_features, text_features.t()) * self.logit_scale.exp()
        # Matched pairs lie on the diagonal, so the target for row i is index i
        labels = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)

        # Cross-entropy loss
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)

        return (loss_i2t + loss_t2i) / 2

    def text_encode(self, text, tokenize=True):
        """Encode text into features, optionally tokenizing raw strings first."""
        if tokenize:
            tokens = self.tokenizer.tokenize(text)
        else:
            tokens = text
        text_features = self.text_encoder(tokens)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)
        return text_features
    
    def forward(self, image, text, tokenize=True):
        image_features = self.image_encoder(image)
        # Route through text_encode so raw strings are tokenized when requested
        text_features = self.text_encode(text, tokenize)
        
        if image_features.dim() < 2:
            image_features = image_features.unsqueeze(0)
        if text_features.dim() < 2:
            text_features = text_features.unsqueeze(0)
            
        return image_features, text_features
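

# A minimal usage sketch. The config path, image resolution, and captions below
# are hypothetical placeholders; the real values depend on the encoder configs
# and on the Tokenizer's expected input format.
if __name__ == "__main__":
    model = CLIP("config.yaml")

    images = torch.randn(8, 3, 224, 224)  # dummy batch of 8 RGB images
    captions = [f"a photo of object {i}" for i in range(8)]  # distinct matched captions

    # Pre-tokenize, since loss() invokes forward with tokenize=False
    tokens = model.tokenizer.tokenize(captions)

    loss = model.loss(images, tokens)
    loss.backward()
    print(f"contrastive loss: {loss.item():.4f}")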