from typing import Optional, List

import os
import json
import functools
from io import BytesIO
from pathlib import Path

import requests
from urllib3 import disable_warnings
from urllib3.exceptions import InsecureRequestWarning

import torch
import torchvision
from torch import Tensor
from torch.utils.data import Dataset, Subset, DataLoader
from torchtext.datasets import IMDB  # required by IMDBDataset below
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import ViltForQuestionAnswering, ViltProcessor
from tqdm import tqdm
from PIL import Image

# datasets
class ImageNetDataset(Dataset):
    """A small folder of ImageNet sample images, one per class index."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.img_dir = os.path.join(self.root_dir, 'samples/')
        self.label_path = os.path.join(self.root_dir, 'imagenet_class_index.json')
        with open(self.label_path) as json_data:
            self.idx_to_labels = json.load(json_data)
        self.img_names = os.listdir(self.img_dir)
        self.img_names.sort()
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert('RGB')
        # The sorted sample files are assumed to line up with the ImageNet
        # class indices, so the index doubles as the label.
        label = idx
        if self.transform:
            image = self.transform(image)
        return image, label

    def idx_to_label(self, idx):
        # imagenet_class_index.json maps "idx" -> [synset_id, readable_name]
        return self.idx_to_labels[str(idx)][1]

def get_imagenet_dataset(
    transform,
    subset_size: int = 100,  # ignored if indices is not None
    root_dir="./data/ImageNet",
    indices: Optional[List[int]] = None,
):
    os.chdir(Path(__file__).parent)  # resolve the relative root_dir against this file
    dataset = ImageNetDataset(root_dir=root_dir, transform=transform)
    if indices is not None:
        return Subset(dataset, indices=indices)
    indices = list(range(len(dataset)))
    return Subset(dataset, indices=indices[:subset_size])
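
# A minimal usage sketch (hypothetical helper, not part of the original API):
# build a hand-rolled preprocessing pipeline and inspect one sample. Assumes
# ./data/ImageNet/samples and imagenet_class_index.json exist locally.
def _demo_imagenet_subset():
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    subset = get_imagenet_dataset(transform, subset_size=4)
    image, label = subset[0]
    print(image.shape, subset.dataset.idx_to_label(label))  # Subset keeps the base dataset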

class IMDBDataset(Dataset):
    def __init__(self, split='test'):
        super().__init__()
        data_iter = IMDB(split=split)
        # torchtext labels are 1 (neg) and 2 (pos); shift them to 0/1.
        self.annotations = [(line, label - 1) for label, line in tqdm(data_iter)]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        return self.annotations[idx]


def get_imdb_dataset(split='test'):
    return IMDBDataset(split=split)

# The VQA images below are fetched with verify=False; silence the per-request
# warning urllib3 would otherwise emit.
disable_warnings(InsecureRequestWarning)


class VQADataset(Dataset):
    def __init__(self):
        super().__init__()
        res = requests.get('https://visualqa.org/balanced_data.json')
        # Parse the payload as JSON rather than eval-ing remote content.
        self.annotations = json.loads(res.text)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        data = self.annotations[idx]
        # Images are fetched lazily on first access and cached in place.
        if isinstance(data['original_image'], str):
            print(f"Requesting {data['original_image']}...")
            res = requests.get(data['original_image'], verify=False)
            img = Image.open(BytesIO(res.content)).convert('RGB')
            data['original_image'] = img
        return data['original_image'], data['question'], data['original_answer']


def get_vqa_dataset():
    return VQADataset()

# models
def get_torchvision_model(model_name):
    # Load the model with its default pretrained weights and the matching
    # preprocessing transform bundled with those weights.
    weights = torchvision.models.get_model_weights(model_name).DEFAULT
    model = torchvision.models.get_model(model_name, weights=weights).eval()
    transform = weights.transforms()
    return model, transform
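
# A quick sanity check (hypothetical, not part of the original API): classify a
# blank image with resnet18 using the transform bundled with its weights.
def _demo_torchvision_model():
    model, transform = get_torchvision_model('resnet18')
    x = transform(Image.new('RGB', (224, 224))).unsqueeze(0)  # add a batch dim
    with torch.no_grad():
        pred = model(x).argmax(dim=1)
    print(pred.item())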

class Bert(BertForSequenceClassification):
    # Positional forward that returns bare logits, so the model can be called
    # directly on the tuple produced by bert_collate_fn.
    def forward(self, input_ids, token_type_ids, attention_mask):
        return super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).logits


def get_bert_model(model_name, num_labels):
    return Bert.from_pretrained(model_name, num_labels=num_labels)

class Vilt(ViltForQuestionAnswering):
    # Same idea as Bert above: positional arguments in, bare logits out.
    def forward(
        self,
        pixel_values,
        input_ids,
        token_type_ids,
        attention_mask,
        pixel_mask,
    ):
        return super().forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        ).logits


def get_vilt_model(model_name):
    return Vilt.from_pretrained(model_name)

# utils
def img_to_np(img):
    # CHW tensor -> HWC numpy array for plotting.
    return img.permute(1, 2, 0).detach().numpy()


def denormalize_image(inputs, mean, std):
    # Undo channel-wise normalization (x * std + mean), then convert for plotting.
    return img_to_np(
        inputs
        * Tensor(std)[:, None, None]
        + Tensor(mean)[:, None, None]
    )
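
# A minimal round-trip sketch: normalize a random image, then undo it with
# denormalize_image. The mean/std below are torchvision's usual ImageNet
# statistics, assumed here for illustration only.
def _demo_denormalize():
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    img = torch.rand(3, 224, 224)
    normalized = (img - Tensor(mean)[:, None, None]) / Tensor(std)[:, None, None]
    restored = denormalize_image(normalized, mean=mean, std=std)
    print(restored.shape)  # (224, 224, 3) HWC numpy array, ready for plt.imshow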

def bert_collate_fn(batch, tokenizer=None):
    inputs = tokenizer(
        [d[0] for d in batch],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    labels = torch.tensor([d[1] for d in batch])
    # Relies on the tokenizer's key order (input_ids, token_type_ids,
    # attention_mask), which is the positional order Bert.forward expects.
    return tuple(inputs.values()), labels


def get_bert_tokenizer(model_name):
    return BertTokenizer.from_pretrained(model_name)
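
# A minimal sketch (hypothetical): run the BERT collate function on a tiny
# hand-made batch, avoiding the IMDB download. The checkpoint name matches the
# one used in load_model_and_dataloader_for_tutorial below.
def _demo_text_batch():
    tokenizer = get_bert_tokenizer('fabriceyhc/bert-base-uncased-imdb')
    batch = [('a great movie', 1), ('a dull movie', 0)]
    (input_ids, token_type_ids, attention_mask), labels = bert_collate_fn(
        batch, tokenizer=tokenizer
    )
    print(input_ids.shape, labels)  # e.g. torch.Size([2, 5]) tensor([1, 0])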

def get_vilt_processor(model_name):
    return ViltProcessor.from_pretrained(model_name)


def vilt_collate_fn(batch, processor=None, label2id=None):
    imgs = [d[0] for d in batch]
    qsts = [d[1] for d in batch]
    inputs = processor(
        images=imgs,
        text=qsts,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    # Map answer strings to class ids via the model's label2id table.
    labels = torch.tensor([label2id[d[2]] for d in batch])
    # Ordered to match Vilt.forward's positional signature.
    return (
        inputs['pixel_values'],
        inputs['input_ids'],
        inputs['token_type_ids'],
        inputs['attention_mask'],
        inputs['pixel_mask'],
        labels,
    )
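
# A minimal sketch (hypothetical): push one in-memory example through the ViLT
# collate function and model instead of downloading VQA images. 'yes' is
# assumed to be in this checkpoint's answer vocabulary.
def _demo_vqa_batch():
    model = get_vilt_model('dandelin/vilt-b32-finetuned-vqa')
    processor = get_vilt_processor('dandelin/vilt-b32-finetuned-vqa')
    batch = [(Image.new('RGB', (384, 384)), 'is this a photo?', 'yes')]
    *inputs, labels = vilt_collate_fn(batch, processor=processor,
                                      label2id=model.config.label2id)
    with torch.no_grad():
        logits = model(*inputs)  # positional order matches Vilt.forward
    print(logits.shape, labels)  # (1, num_answers)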

def load_model_and_dataloader_for_tutorial(modality, device):
    if modality == 'image':
        model, transform = get_torchvision_model('resnet18')
        model = model.to(device)
        model.eval()
        dataset = get_imagenet_dataset(transform)
        loader = DataLoader(dataset, batch_size=8, shuffle=False)
        return model, loader, transform
    elif modality == 'text':
        model = get_bert_model('fabriceyhc/bert-base-uncased-imdb', num_labels=2)
        model = model.to(device)
        model.eval()
        dataset = get_imdb_dataset(split='test')
        tokenizer = get_bert_tokenizer('fabriceyhc/bert-base-uncased-imdb')
        loader = DataLoader(
            dataset,
            batch_size=8,
            shuffle=False,
            collate_fn=functools.partial(bert_collate_fn, tokenizer=tokenizer),
        )
        return model, loader, tokenizer
    elif modality == ('image', 'text'):  # multimodal: pass the tuple ('image', 'text')
        model = get_vilt_model('dandelin/vilt-b32-finetuned-vqa')
        model.to(device)
        model.eval()
        dataset = get_vqa_dataset()
        processor = get_vilt_processor('dandelin/vilt-b32-finetuned-vqa')
        loader = DataLoader(
            dataset,
            batch_size=2,
            shuffle=False,
            collate_fn=functools.partial(
                vilt_collate_fn,
                processor=processor,
                label2id=model.config.label2id,
            ),
        )
        return model, loader, processor
    raise ValueError(f"Unsupported modality: {modality!r}")
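
if __name__ == '__main__':
    # Smoke test (a minimal sketch): build the image pipeline on CPU and push
    # one batch through the model. Downloads resnet18 weights on first run and
    # assumes the local ImageNet samples described above are present.
    model, loader, transform = load_model_and_dataloader_for_tutorial('image', 'cpu')
    images, labels = next(iter(loader))
    with torch.no_grad():
        logits = model(images)
    print(logits.shape)  # torch.Size([8, 1000])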