from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional, Sequence

import PIL
import torch
from datasets import load_dataset
from PIL import Image

from .data_processing import pil_to_tensor

# Preprocessing settings pinned per dataset. These win over any keyword
# arguments the caller passes to ``DataCollatorForSupervisedDataset``.
# NOTE(review): the meaning of each key (e.g. ``target_image_size=-1``) is
# defined by ``pil_to_tensor``; presumably -1 means "keep the native size" --
# confirm against ``data_processing``.
_DATASET_OVERRIDES: Dict[str, Dict[str, Any]] = {
    "DIV2K": {
        "target_image_size": -1,
        "lock_ratio": True,
        "center_crop": False,
        "padding": False,
    },
    "imagenet": {"center_crop": True, "padding": False},
    "movie_posters": {"center_crop": True, "padding": False},
    "high_quality_1024": {"target_image_size": (1024, 1024)},
}


class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning.

    Converts the ``"image"`` entry of every example with ``pil_to_tensor`` and
    stacks the results into one batch tensor.

    This is deliberately a plain class: it has a hand-written ``__init__`` and
    no declared fields, so ``@dataclass`` would only contribute a field-less
    ``__eq__`` (making every instance compare equal) and disable hashing.

    Attributes:
        data_params: Keyword arguments forwarded to ``pil_to_tensor`` for
            every image.
    """

    def __init__(self, dataset_name: str, **kwargs: Any) -> None:
        """Build the preprocessing parameters for ``dataset_name``.

        Args:
            dataset_name: Dataset identifier. Names listed in
                ``_DATASET_OVERRIDES`` get their pinned settings; any other
                name uses ``kwargs`` unchanged.
            **kwargs: Caller-supplied ``pil_to_tensor`` keyword arguments.
                Per-dataset overrides replace entries with the same key.
        """
        overrides = _DATASET_OVERRIDES.get(dataset_name, {})
        # Merge into a fresh dict so the shared override table is never
        # mutated and overrides take precedence over caller kwargs.
        self.data_params: Dict[str, Any] = {**kwargs, **overrides}

    def __call__(self, instances: Sequence[Mapping[str, Any]]) -> Dict[str, Any]:
        """Turn a list of examples into a batch.

        Args:
            instances: Examples, each providing an ``"image"`` and an
                ``"idx"`` entry.

        Returns:
            A dict with ``"image"`` (the per-example tensors stacked along a
            new leading dimension) and ``"idx"`` (the list of example
            indices, in input order).
        """
        # torch.stack requires every tensor to have the same shape, so
        # pil_to_tensor must yield uniformly sized outputs within a batch.
        images = torch.stack(
            [
                pil_to_tensor(instance["image"], **self.data_params)
                for instance in instances
            ],
            dim=0,
        )
        idx: List[Any] = [instance["idx"] for instance in instances]
        return {"image": images, "idx": idx}


class ImagenetDataset(torch.utils.data.Dataset):
    """Map-style dataset over the ``image`` column of a VTBench configuration.

    Attributes:
        image_list: The ``image`` column of the loaded split.
    """

    def __init__(
        self,
        dataset_name: str,
        split_name: str = "test",
        n_take: Optional[int] = None,
    ) -> None:
        """Load ``huaweilin/VTBench`` and keep its image column.

        Args:
            dataset_name: Configuration name passed to ``load_dataset``.
            split_name: Split to load.
            n_take: When given, only the first ``n_take`` rows of the split
                are requested (via the ``"split[:n]"`` slice syntax).
        """
        print(dataset_name, split_name)
        split = split_name if n_take is None else f"{split_name}[:{n_take}]"
        ds = load_dataset("huaweilin/VTBench", name=dataset_name, split=split)
        # NOTE(review): depending on the installed `datasets` version this may
        # decode every image up front -- confirm memory use for large splits.
        self.image_list = ds["image"]

    def __len__(self) -> int:
        """Return the number of images in the loaded split."""
        return len(self.image_list)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the image at ``idx`` together with its index."""
        return {"image": self.image_list[idx], "idx": idx}


def get_dataset(
    dataset_name: str,
    split_name: str = "test",
    n_take: Optional[int] = None,
) -> ImagenetDataset:
    """Create an ``ImagenetDataset`` for the given configuration.

    The defaults mirror ``ImagenetDataset.__init__`` so both entry points can
    be called the same way.

    Args:
        dataset_name: Configuration name of the VTBench dataset.
        split_name: Split to load.
        n_take: Optional cap on the number of rows to load.

    Returns:
        The constructed dataset.
    """
    return ImagenetDataset(dataset_name, split_name, n_take)