File size: 1,974 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import PIL
from PIL import Image
from dataclasses import dataclass, field
from datasets import load_dataset
import torch
from .data_processing import pil_to_tensor


class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning.

    The keyword arguments forwarded to ``pil_to_tensor`` are computed once at
    construction time: caller-supplied ``kwargs`` are merged with the
    per-dataset entry of ``_DATASET_OVERRIDES``, and the per-dataset values
    win whenever both define the same key.

    Args:
        dataset_name: Name used to look up the override set. A name with no
            entry in ``_DATASET_OVERRIDES`` gets no overrides.
        **kwargs: Baseline keyword arguments for ``pil_to_tensor``.

    Attributes:
        data_params: The merged keyword arguments passed to ``pil_to_tensor``
            for every image.
    """

    # Deliberately a plain class rather than a dataclass: with no declared
    # fields, @dataclass would generate an __eq__ under which every instance
    # compares equal and would set __hash__ to None (unhashable instances).

    # Per-dataset keyword overrides; values are immutable, and __init__ copies
    # them into a fresh dict, so sharing this table across instances is safe.
    _DATASET_OVERRIDES = {
        "DIV2K": {
            "target_image_size": -1,
            "lock_ratio": True,
            "center_crop": False,
            "padding": False,
        },
        "imagenet": {"center_crop": True, "padding": False},
        "movie_posters": {"center_crop": True, "padding": False},
        "high_quality_1024": {"target_image_size": (1024, 1024)},
    }

    def __init__(self, dataset_name, **kwargs):
        overrides = self._DATASET_OVERRIDES.get(dataset_name, {})
        # Overrides are unpacked last so they take precedence over kwargs.
        self.data_params = {**kwargs, **overrides}

    def __call__(self, instances):
        """Convert a list of examples into one batch.

        Args:
            instances: Iterable of mappings, each providing an ``"image"``
                entry (handed to ``pil_to_tensor``) and an ``"idx"`` entry.

        Returns:
            A dict with ``image`` (the per-example tensors stacked along a new
            leading dimension) and ``idx`` (the list of per-example ``idx``
            values, in input order).
        """
        # torch.stack needs every per-example tensor to have the same shape.
        images = torch.stack(
            [
                pil_to_tensor(instance["image"], **self.data_params)
                for instance in instances
            ],
            dim=0,
        )
        idx = [instance["idx"] for instance in instances]
        return dict(image=images, idx=idx)


class ImagenetDataset(torch.utils.data.Dataset):
    """Map-style dataset over the ``image`` column of ``huaweilin/VTBench``.

    Args:
        dataset_name: Configuration name passed to ``load_dataset`` as
            ``name``.
        split_name: Split to load; defaults to ``"test"``.
        n_take: When not ``None``, the split is requested as
            ``"<split_name>[:<n_take>]"``; otherwise the whole split is
            loaded.
    """

    def __init__(self, dataset_name, split_name="test", n_take=None):
        print(dataset_name, split_name)
        if n_take is None:
            split_spec = split_name
        else:
            split_spec = f"{split_name}[:{n_take}]"
        hf_dataset = load_dataset(
            "huaweilin/VTBench",
            name=dataset_name,
            split=split_spec,
        )
        # Only the image column is kept; the dataset's other columns are dropped.
        self.image_list = hf_dataset["image"]

    def __len__(self):
        """Return the number of images held."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return the image at ``idx`` together with ``idx`` itself."""
        return {"image": self.image_list[idx], "idx": idx}


def get_dataset(dataset_name, split_name, n_take):
    """Build an ``ImagenetDataset`` for the given name, split and take count.

    All three arguments are forwarded positionally to ``ImagenetDataset``.
    """
    return ImagenetDataset(dataset_name, split_name, n_take)