#!/usr/bin/env python3
"""
Tiny VQAScore Model Wrapper

A very small, randomly initialised stand-in for a VQAScore-style model.
The weights are random, so the scores carry no meaning; the wrapper only
reproduces the call shape ``score(image, question) -> float``.
"""

import zlib

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

# Size of the embedding table used by the text encoder.
_VOCAB_SIZE = 32128
# Spatial size PIL images are resized to before being fed to the model.
_IMAGE_SIZE = (224, 224)
# Upper bound on the number of question tokens fed to the text encoder.
_MAX_QUESTION_TOKENS = 64
# Token id used when the question contains no words at all.
_PAD_TOKEN_ID = 0


class TinyVQAScore:
    """A tiny random version of the VQAScore model."""

    def __init__(self, model="tiny-random", device="cpu"):
        """Build the tiny model, move it to *device* and switch it to eval mode.

        Args:
            model: Model name. It is only stored on the instance; the same
                tiny architecture is always built.
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``).
        """
        self.model_name = model
        self.device = torch.device(device)
        self.model = self._create_tiny_model()
        self.model.to(self.device)
        self.model.eval()

    def _create_tiny_model(self):
        """Return a freshly initialised ``TinyCLIPT5`` module."""

        class TinyCLIPT5(nn.Module):
            """Minimal vision + text network that emits one logit per sample."""

            def __init__(self):
                super().__init__()
                self.vision_encoder = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=16, stride=16),
                    nn.AdaptiveAvgPool2d((1, 1)),  # Global average pooling
                    nn.Flatten(),
                    nn.Linear(64, 256),
                )
                self.text_encoder = nn.Sequential(
                    nn.Embedding(_VOCAB_SIZE, 256),
                    nn.LayerNorm(256),
                    nn.TransformerEncoderLayer(
                        d_model=256,
                        nhead=8,
                        dim_feedforward=512,
                        dropout=0.1,
                        batch_first=True,
                    ),
                )
                self.multimodal_projector = nn.Sequential(
                    nn.Linear(256, 128),
                    nn.GELU(),
                    nn.Linear(128, 64),
                    nn.GELU(),
                    nn.Linear(64, 1),
                )
                self._init_weights()

            def _init_weights(self):
                """Re-initialise linear, conv and embedding weights with small values."""
                for module in self.modules():
                    if isinstance(module, (nn.Linear, nn.Conv2d)):
                        nn.init.xavier_uniform_(module.weight, gain=0.1)
                        if module.bias is not None:
                            nn.init.uniform_(module.bias, -0.1, 0.1)
                    elif isinstance(module, nn.Embedding):
                        nn.init.uniform_(module.weight, -0.1, 0.1)

            def forward(self, pixel_values, input_ids):
                """Return one raw (pre-sigmoid) score per batch element."""
                vision_features = self.vision_encoder(pixel_values)
                text_features = self.text_encoder(input_ids)
                # Mean-pool over the token dimension.
                text_features = text_features.mean(dim=1)
                combined_features = vision_features + text_features
                score = self.multimodal_projector(combined_features)
                return score.squeeze(-1)

        return TinyCLIPT5()

    def _encode_question(self, question):
        """Map *question* to a ``(1, n_tokens)`` tensor of token ids.

        There is no real tokenizer here: the text is lower-cased, split on
        whitespace, and each word is hashed into the vocabulary range.
        CRC32 is used instead of the built-in ``hash`` because the latter is
        salted per process, which would make scores differ between runs.
        """
        text = "" if question is None else str(question)
        words = text.lower().split()[:_MAX_QUESTION_TOKENS]
        token_ids = [
            zlib.crc32(word.encode("utf-8")) % _VOCAB_SIZE for word in words
        ]
        if not token_ids:
            # The text encoder mean-pools over tokens, so it needs at least one.
            token_ids = [_PAD_TOKEN_ID]
        return torch.tensor([token_ids], dtype=torch.long, device=self.device)

    def _prepare_image(self, image):
        """Convert *image* to a float32 ``(batch, 3, H, W)`` tensor on the model device.

        Raises:
            TypeError: If *image* is neither a PIL image nor a ``torch.Tensor``.
        """
        if isinstance(image, Image.Image):
            # Force three channels so grayscale / RGBA / palette images work.
            rgb = image.convert("RGB").resize(_IMAGE_SIZE)
            pixels = torch.from_numpy(np.array(rgb)).float()
            pixels = pixels.permute(2, 0, 1).unsqueeze(0) / 255.0
        elif isinstance(image, torch.Tensor):
            # Accept an unbatched (C, H, W) tensor by adding the batch axis.
            # Tensor values are used as given (no rescaling).
            pixels = image.unsqueeze(0) if image.dim() == 3 else image
        else:
            raise TypeError(
                "image must be a PIL.Image.Image or torch.Tensor, "
                f"got {type(image).__name__}"
            )
        return pixels.to(device=self.device, dtype=torch.float32)

    def score(self, image, question):
        """Return the sigmoid score for one image/question pair.

        Args:
            image: A PIL image (converted to RGB and resized to 224x224) or a
                tensor shaped ``(3, H, W)`` or ``(1, 3, H, W)``.
            question: Question text; ``None`` is treated as an empty string.

        Returns:
            A Python float in ``[0, 1]``. For a given instance the same inputs
            always yield the same score.

        Raises:
            TypeError: If *image* is neither a PIL image nor a ``torch.Tensor``.
        """
        pixel_values = self._prepare_image(image)
        input_ids = self._encode_question(question)
        with torch.no_grad():
            logit = self.model(pixel_values, input_ids)
        return torch.sigmoid(logit).item()


def main():
    """Smoke-test the model on a solid red image."""
    model = TinyVQAScore(device="cpu")
    test_image = Image.new('RGB', (224, 224), color='red')
    score = model.score(test_image, "What color is this image?")
    print(f"Test score: {score}")


if __name__ == "__main__":
    main()