File size: 3,104 Bytes
bdec2d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
"""
Tiny VQAScore Model Wrapper
"""

import torch
import torch.nn as nn
from PIL import Image
import numpy as np

class TinyVQAScore:
    """A tiny, randomly initialised version of the VQAScore model.

    A small convolutional image encoder and a one-layer transformer text
    encoder each produce 256-d vectors. These are summed and passed through an
    MLP head that emits one logit. The weights are random, so a score is only
    a deterministic function of (weights, image, question); it carries no
    semantic meaning.
    """

    # Rows in the token-embedding table; every token id must be below this.
    _VOCAB_SIZE = 32128
    # Side length, in pixels, that PIL images are resized to before encoding.
    _IMAGE_SIZE = 224

    def __init__(self, model="tiny-random", device="cpu"):
        """Build the random network on *device* and switch it to eval mode.

        Args:
            model: Accepted for interface compatibility; not used.
            device: Any value accepted by ``torch.device`` (e.g. ``"cpu"``).
        """
        self.device = torch.device(device)
        self.model = self._create_tiny_model()
        self.model.to(self.device)
        # eval() turns off the dropout inside the transformer encoder layer.
        self.model.eval()

    def _create_tiny_model(self):
        """Return a freshly, randomly initialised ``TinyCLIPT5`` module."""
        vocab_size = self._VOCAB_SIZE

        class TinyCLIPT5(nn.Module):
            def __init__(self):
                super().__init__()
                # 16x16 stride-16 conv, then global average pooling, so any
                # image of at least 16x16 pixels maps to one 256-d vector.
                self.vision_encoder = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=16, stride=16),
                    nn.AdaptiveAvgPool2d((1, 1)),  # Global average pooling
                    nn.Flatten(),
                    nn.Linear(64, 256),
                )
                # (batch, seq) token ids -> (batch, seq, 256) features.
                self.text_encoder = nn.Sequential(
                    nn.Embedding(vocab_size, 256),
                    nn.LayerNorm(256),
                    nn.TransformerEncoderLayer(
                        d_model=256,
                        nhead=8,
                        dim_feedforward=512,
                        dropout=0.1,
                        batch_first=True,
                    ),
                )
                # 256 -> 1 MLP head producing one logit per example.
                self.multimodal_projector = nn.Sequential(
                    nn.Linear(256, 128), nn.GELU(),
                    nn.Linear(128, 64), nn.GELU(),
                    nn.Linear(64, 1),
                )
                self._init_weights()

            def _init_weights(self):
                """Re-initialise linear/conv/embedding weights with small values."""
                for module in self.modules():
                    if isinstance(module, (nn.Linear, nn.Conv2d)):
                        nn.init.xavier_uniform_(module.weight, gain=0.1)
                        if module.bias is not None:
                            nn.init.uniform_(module.bias, -0.1, 0.1)
                    elif isinstance(module, nn.Embedding):
                        nn.init.uniform_(module.weight, -0.1, 0.1)

            def forward(self, pixel_values, input_ids):
                """Return one logit per example, shape ``(batch,)``."""
                vision_features = self.vision_encoder(pixel_values)
                # Mean-pool over the sequence axis (dim 1, batch_first=True).
                text_features = self.text_encoder(input_ids).mean(dim=1)
                combined_features = vision_features + text_features
                score = self.multimodal_projector(combined_features)
                return score.squeeze(-1)

        return TinyCLIPT5()

    def _pil_to_tensor(self, image):
        """Convert a PIL image to a ``(1, 3, 224, 224)`` float tensor in [0, 1]."""
        # convert("RGB") guarantees three channels: a grayscale ("L") image
        # would otherwise give a 2-D array and an RGBA image four channels.
        rgb = image.convert("RGB").resize((self._IMAGE_SIZE, self._IMAGE_SIZE))
        pixels = torch.from_numpy(np.array(rgb)).float()
        return pixels.permute(2, 0, 1).unsqueeze(0) / 255.0

    def _encode_question(self, question):
        """Map *question* to a deterministic ``(1, L)`` LongTensor of token ids.

        Each UTF-8 byte becomes one token id (0-255, all below the vocabulary
        size). An empty question yields the single id 0, so the encoder's mean
        over the sequence axis is never taken over zero tokens.
        """
        text = "" if question is None else str(question)
        token_ids = list(text.encode("utf-8")) or [0]
        return torch.tensor([token_ids], dtype=torch.long)

    def score(self, image, question):
        """Return a pseudo-VQA score in (0, 1) for *image* and *question*.

        Args:
            image: Either a PIL image (any mode; it is converted to RGB,
                resized to 224x224 and scaled to [0, 1]), or a tensor of shape
                ``(3, H, W)`` or ``(1, 3, H, W)`` with H, W >= 16. Tensors are
                only given a batch axis if missing and cast to float32. Their
                values are NOT rescaled, so pass [0, 1] data to match the PIL
                path.
            question: Text whose UTF-8 bytes become the token ids. ``None``
                is treated as the empty string.

        Returns:
            ``sigmoid(logit)`` as a Python float. Only a single image is
            supported, because the result is reduced with ``.item()``.
        """
        if isinstance(image, torch.Tensor):
            image_tensor = image
            if image_tensor.dim() == 3:
                # Unbatched (C, H, W): add the batch axis the encoder expects.
                image_tensor = image_tensor.unsqueeze(0)
        else:
            image_tensor = self._pil_to_tensor(image)
        # The conv weights are float32; cast so that e.g. float64 input does
        # not raise a dtype-mismatch error.
        image_tensor = image_tensor.to(device=self.device, dtype=torch.float32)

        # Derived from the question text (not random) so repeated calls with
        # the same inputs return the same score.
        input_ids = self._encode_question(question).to(self.device)

        with torch.no_grad():
            logit = self.model(image_tensor, input_ids)

        return torch.sigmoid(logit).item()

def _smoke_test():
    """Score a solid red 224x224 image on CPU and print the result."""
    scorer = TinyVQAScore(device="cpu")
    red_square = Image.new('RGB', (224, 224), color='red')
    score = scorer.score(red_square, "What color is this image?")
    print(f"Test score: {score}")


if __name__ == "__main__":
    _smoke_test()