"""
Tiny VQAScore Model Wrapper
"""

import torch
import torch.nn as nn
from PIL import Image
import numpy as np


class TinyVQAScore:
    """A tiny random version of the VQAScore model."""

    def __init__(self, model="tiny-random", device="cpu"):
        # The model name is accepted for API compatibility only; no weights are
        # loaded, and the tiny model below is always randomly initialized.
        self.device = torch.device(device)
        self.model = self._create_tiny_model()
        self.model.to(self.device)
        self.model.eval()

    def _create_tiny_model(self):
        class TinyCLIPT5(nn.Module):
            def __init__(self):
                super().__init__()
                # Vision branch: patchify with a strided conv, pool to a single
                # token, and project to a 256-dim feature.
                self.vision_encoder = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=16, stride=16),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(64, 256)
                )
                # Text branch: T5-sized vocabulary (32128) embedded into the same
                # 256-dim space, followed by a single transformer encoder layer.
                self.text_encoder = nn.Sequential(
                    nn.Embedding(32128, 256),
                    nn.LayerNorm(256),
                    nn.TransformerEncoderLayer(
                        d_model=256, nhead=8, dim_feedforward=512,
                        dropout=0.1, batch_first=True
                    )
                )
                # Fuse the two 256-dim features down to a single scalar score.
                self.multimodal_projector = nn.Sequential(
                    nn.Linear(256, 128), nn.GELU(),
                    nn.Linear(128, 64), nn.GELU(),
                    nn.Linear(64, 1)
                )
                self._init_weights()

            def _init_weights(self):
                # Small-gain Xavier init keeps the random outputs near zero.
                for module in self.modules():
                    if isinstance(module, (nn.Linear, nn.Conv2d)):
                        nn.init.xavier_uniform_(module.weight, gain=0.1)
                        if module.bias is not None:
                            nn.init.uniform_(module.bias, -0.1, 0.1)
                    elif isinstance(module, nn.Embedding):
                        nn.init.uniform_(module.weight, -0.1, 0.1)

            def forward(self, pixel_values, input_ids):
                vision_features = self.vision_encoder(pixel_values)   # (B, 256)
                text_features = self.text_encoder(input_ids)          # (B, T, 256)
                text_features = text_features.mean(dim=1)             # (B, 256)
                combined_features = vision_features + text_features
                score = self.multimodal_projector(combined_features)  # (B, 1)
                return score.squeeze(-1)                              # (B,)

        return TinyCLIPT5()
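
    # Illustrative aside, not part of the original wrapper: the tiny model is
    # small enough that a hypothetical helper can report its parameter count.
    def num_parameters(self):
        """Return the total parameter count of the tiny model (illustrative)."""
        return sum(p.numel() for p in self.model.parameters())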

    def score(self, image, question):
        if isinstance(image, Image.Image):
            # Convert the PIL image to a (1, 3, 224, 224) float tensor in [0, 1].
            image = image.convert("RGB").resize((224, 224))
            image_tensor = torch.from_numpy(np.array(image)).float()
            image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0
        else:
            # Assume a preprocessed (1, 3, H, W) tensor was passed directly.
            image_tensor = image

        # The tiny model has no tokenizer, so the question text is ignored and
        # replaced with random token ids from the T5-sized vocabulary.
        input_ids = torch.randint(0, 32128, (1, 10)).to(self.device)

        with torch.no_grad():
            score = self.model(image_tensor.to(self.device), input_ids)

        return torch.sigmoid(score).item()
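
    # Illustrative extension, not part of the original wrapper: a hypothetical
    # convenience helper that scores several images against the same question
    # by looping over score(). The name and signature are assumptions.
    def score_batch(self, images, question):
        """Score a list of images against a single question (illustrative sketch)."""
        return [self.score(img, question) for img in images]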


if __name__ == "__main__":
    # Quick smoke test: a solid red image and a simple question should yield a
    # probability-like score from the randomly initialized model.
    model = TinyVQAScore(device="cpu")
    test_image = Image.new('RGB', (224, 224), color='red')
    score = model.score(test_image, "What color is this image?")
    print(f"Test score: {score}")
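
    # Illustrative follow-up using the hypothetical helpers sketched above:
    # report the parameter count and score two more solid-color test images.
    print(f"Parameter count: {model.num_parameters()}")
    batch = [Image.new('RGB', (224, 224), color=c) for c in ('green', 'blue')]
    print(f"Batch scores: {model.score_batch(batch, 'What color is this image?')}")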
|