File size: 1,974 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import PIL
from PIL import Image
from dataclasses import dataclass, field
from datasets import load_dataset
import torch
from .data_processing import pil_to_tensor
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    # Per-dataset-name parameter overrides. These always win over
    # caller-supplied kwargs (merged last in __init__).
    _DATASET_OVERRIDES = {
        "DIV2K": {
            "target_image_size": -1,
            "lock_ratio": True,
            "center_crop": False,
            "padding": False,
        },
        "imagenet": {"center_crop": True, "padding": False},
        "movie_posters": {"center_crop": True, "padding": False},
        "high_quality_1024": {"target_image_size": (1024, 1024)},
    }

    def __init__(self, dataset_name, **kwargs):
        # Unknown dataset names receive no overrides (empty dict),
        # matching the original fall-through behavior.
        overrides = self._DATASET_OVERRIDES.get(dataset_name, {})
        self.data_params = {**kwargs, **overrides}

    def __call__(self, instances):
        # Convert every PIL image to a tensor with the resolved
        # preprocessing params, then batch along a new leading dim.
        tensors = [
            pil_to_tensor(item["image"], **self.data_params)
            for item in instances
        ]
        indices = [item["idx"] for item in instances]
        return dict(image=torch.stack(tensors, dim=0), idx=indices)
class ImagenetDataset(torch.utils.data.Dataset):
    """Map-style dataset over the images of one VTBench subset.

    Loads the named configuration of ``huaweilin/VTBench`` from the
    Hugging Face Hub and materializes its ``image`` column.
    """

    def __init__(self, dataset_name, split_name="test", n_take=None):
        print(dataset_name, split_name)
        # n_take restricts the split to its first n_take rows using the
        # HF datasets slice syntax, e.g. "test[:100]".
        split_spec = split_name if n_take is None else f"{split_name}[:{n_take}]"
        dataset = load_dataset("huaweilin/VTBench", name=dataset_name, split=split_spec)
        self.image_list = dataset["image"]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        # Keys match what the collator reads from each instance.
        return {"image": self.image_list[idx], "idx": idx}
def get_dataset(dataset_name, split_name, n_take):
    """Return an ImagenetDataset for the given VTBench subset and split."""
    return ImagenetDataset(dataset_name, split_name, n_take)
|