#!/usr/bin/env python3
"""
Tiny VQAScore Model Wrapper
"""
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
class TinyVQAScore:
"""A tiny random version of the VQAScore model."""
    def __init__(self, model="tiny-random", device="cpu"):
        # `model` is kept for interface compatibility; a randomly initialized tiny
        # network is built regardless of its value.
        self.device = torch.device(device)
        self.model = self._create_tiny_model()
        self.model.to(self.device)
        self.model.eval()
    def _create_tiny_model(self):
        class TinyCLIPT5(nn.Module):
            """Two-tower (vision + text) network with a scalar scoring head."""
            def __init__(self):
                super().__init__()
                # Vision tower: 16x16 patch convolution, global average pooling,
                # then a projection to a 256-d image feature.
                self.vision_encoder = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=16, stride=16),
                    nn.AdaptiveAvgPool2d((1, 1)),  # Global average pooling
                    nn.Flatten(),
                    nn.Linear(64, 256)
                )
                # Text tower: token embedding over a 32128-token (T5-sized) vocabulary,
                # followed by a single Transformer encoder layer.
                self.text_encoder = nn.Sequential(
                    nn.Embedding(32128, 256),
                    nn.LayerNorm(256),
                    nn.TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=512, dropout=0.1, batch_first=True)
                )
                # Fusion head: a small MLP mapping the fused 256-d feature to one scalar score.
                self.multimodal_projector = nn.Sequential(
                    nn.Linear(256, 128), nn.GELU(),
                    nn.Linear(128, 64), nn.GELU(),
                    nn.Linear(64, 1)
                )
                self._init_weights()
            def _init_weights(self):
                # Small-magnitude random initialization so raw scores stay near zero.
                for module in self.modules():
                    if isinstance(module, (nn.Linear, nn.Conv2d)):
                        nn.init.xavier_uniform_(module.weight, gain=0.1)
                        if module.bias is not None:
                            nn.init.uniform_(module.bias, -0.1, 0.1)
                    elif isinstance(module, nn.Embedding):
                        nn.init.uniform_(module.weight, -0.1, 0.1)
            def forward(self, pixel_values, input_ids):
                vision_features = self.vision_encoder(pixel_values)  # (B, 256)
                text_features = self.text_encoder(input_ids)         # (B, T, 256)
                text_features = text_features.mean(dim=1)            # (B, 256)
                combined_features = vision_features + text_features  # late fusion by addition
                score = self.multimodal_projector(combined_features) # (B, 1)
                return score.squeeze(-1)                             # (B,)
return TinyCLIPT5()
    def score(self, image, question):
        """Return a score in (0, 1) for an image/question pair (random weights, so not meaningful)."""
        if isinstance(image, Image.Image):
            # Force three channels so the conv stem always receives an RGB tensor.
            image = image.convert("RGB").resize((224, 224))
            image_tensor = torch.from_numpy(np.array(image)).float()
            image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) / 255.0
        else:
            # Assume an already-preprocessed tensor of shape (1, 3, 224, 224).
            image_tensor = image
        # The question text is not tokenized; random token ids stand in for it,
        # in keeping with the model's random weights.
        input_ids = torch.randint(0, 32128, (1, 10)).to(self.device)
        with torch.no_grad():
            score = self.model(image_tensor.to(self.device), input_ids)
        return torch.sigmoid(score).item()
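
# Illustrative sketch (not part of the original wrapper): one way to derive
# deterministic token ids from the question text instead of the random ids used in
# TinyVQAScore.score(). The hashing scheme, the max_length of 10, and the 32128
# vocabulary size (matching the embedding above) are assumptions for demonstration;
# nothing in the class calls this helper.
def encode_question_sketch(question, max_length=10, vocab_size=32128):
    """Hash each word to a pseudo-token id and pad/truncate to max_length."""
    import zlib  # local import keeps the sketch self-contained
    ids = [zlib.crc32(word.encode("utf-8")) % vocab_size for word in question.lower().split()]
    ids = (ids + [0] * max_length)[:max_length]  # pad with 0, then cut to length
    return torch.tensor([ids], dtype=torch.long)  # shape (1, max_length)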
if __name__ == "__main__":
    # Quick smoke test: build the model and score a solid-color image.
    model = TinyVQAScore(device="cpu")
    test_image = Image.new('RGB', (224, 224), color='red')
    score = model.score(test_image, "What color is this image?")
    print(f"Test score: {score:.4f}")
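    # Additional usage sketches (not in the original demo): score() also accepts an
    # already-preprocessed tensor of shape (1, 3, 224, 224), and because the weights
    # and token ids are random, the returned values carry no semantic meaning.
    tensor_image = torch.rand(1, 3, 224, 224)
    tensor_score = model.score(tensor_image, "Is there a dog in the picture?")
    print(f"Tensor-input score: {tensor_score:.4f}")
    for q in ["What color is this image?", "How many objects are shown?"]:
        print(f"{q!r} -> {model.score(test_image, q):.4f}")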