from torchvision.datasets import CIFAR10, CelebA
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Lambda, CenterCrop, Resize, RandomHorizontalFlip
import os
import torch
import json
from PIL import Image as im
from helper.tokenizer import Tokenizer
from transformers import AutoProcessor

def center_crop_and_resize(img, crop_size, resize_size):
    """Center-crop a PIL image to a `crop_size` square, then resize it to
    `resize_size` x `resize_size`."""
    width, height = img.size

    # 1. Center crop: integer box coordinates around the image centre
    #    (PIL rounds float coordinates, so computing ints is more predictable).
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    right = left + crop_size
    bottom = top + crop_size

    img_cropped = img.crop((left, top, right, bottom))

    # 2. Resize with bicubic resampling (im.Resampling needs Pillow >= 9.1).
    img_resized = img_cropped.resize((resize_size, resize_size), im.Resampling.BICUBIC)

    return img_resized
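
# Usage sketch (illustrative): crop a hypothetical 512x512 image to its
# central 400x400 patch, then resize to 256x256. 'sample.png' is a
# placeholder path, not a file shipped with this repo.
#
#   img = im.open('sample.png')
#   out = center_crop_and_resize(img, crop_size=400, resize_size=256)
#   assert out.size == (256, 256)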

class UnlabelDataset(Dataset):
    """Image-only dataset that serves every file found in a directory."""

    def __init__(self, path, transform):
        self.path = path
        self.file_list = os.listdir(path)
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        # os.path.join works whether or not `path` ends with a separator,
        # unlike plain string concatenation.
        img_path = os.path.join(self.path, self.file_list[index])
        # Convert to RGB so RGBA or greyscale files still yield 3 channels.
        image = im.open(img_path).convert('RGB')
        image = self.transform(image)
        return image
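
# Usage sketch (hedged): './images/' is a placeholder directory. The
# transform reuses the same [-1, 1] scaling convention as DataGenerator below.
#
#   transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])
#   dataset = UnlabelDataset('./images/', transform)
#   loader = DataLoader(dataset, batch_size=16, shuffle=True)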
    
class CompositeDataset(Dataset):
    """Paired image/caption dataset: '<stem>.png' files under `path` matched
    with '<stem>.json' caption files under `text_path`.

    If a HuggingFace `processor` is supplied it preprocesses both image and
    text; otherwise a torchvision transform and the project's Tokenizer are
    used.
    """

    def __init__(self, path, text_path, processor: AutoProcessor = None):
        self.path = path
        self.text_path = text_path
        self.tokenizer = Tokenizer()
        self.processor = processor

        # Keep only the bare file stems so each image/JSON pair can be rebuilt.
        self.file_numbers = [os.path.splitext(filename)[0]
                             for filename in os.listdir(path)]

        self.transform = Compose([
            ToTensor(),
            CenterCrop(400),
            Resize(256, antialias=True),
            RandomHorizontalFlip(),
            Lambda(lambda x: (x - 0.5) * 2),  # scale [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.file_numbers)

    def get_text(self, text_path):
        # Caption JSONs are Korean text encoded as CP949.
        with open(text_path, encoding='CP949') as f:
            text = json.load(f)['description']['impression']['description']
        return text

    def __getitem__(self, idx):
        img_path = os.path.join(self.path, self.file_numbers[idx] + '.png')
        text_path = os.path.join(self.text_path, self.file_numbers[idx] + '.json')
        image = im.open(img_path).convert('RGB')
        text = self.get_text(text_path)
        if self.processor is not None:
            image = center_crop_and_resize(image, 400, 256)
            inputs = self.processor(
                text=text,
                images=image,
                return_tensors="pt",
                padding='max_length',
                max_length=77,
                truncation=True,
            )
            # The processor returns tensors with a leading batch dimension of
            # size 1; squeeze it out so the DataLoader can collate samples.
            for key in inputs:
                inputs[key] = inputs[key].squeeze(0)
            return inputs
        else:
            image = self.transform(image)
            text = self.tokenizer.tokenize(text)
            for key in text:
                text[key] = text[key].squeeze(0)
            return image, text
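
# Usage sketch (hedged; both paths are placeholders):
#
#   dataset = CompositeDataset('./data/images/', './data/labels/')
#   image, tokens = dataset[0]   # tensor in [-1, 1] plus tokenized caption
#
# With a CLIP processor supplied instead, each item is a single dict of
# processor outputs (e.g. 'pixel_values', 'input_ids'), ready for batching.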

class DataGenerator():
    """Factory for DataLoaders over the datasets used in this project."""

    def __init__(self, num_workers: int = 4, pin_memory: bool = True):
        self.transform = Compose([
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2),  # scale [0, 1] -> [-1, 1]
        ])
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def cifar10(self, path='./datasets', batch_size: int = 64, train: bool = True):
        train_data = CIFAR10(path, download=True, train=train, transform=self.transform)
        return DataLoader(train_data, batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def celeba(self, path='./datasets', batch_size: int = 16):
        # CelebA faces: crop the 178x178 centre, then downsample to 128x128.
        train_data = CelebA(path, transform=Compose([
            ToTensor(),
            CenterCrop(178),
            Resize(128, antialias=True),
            Lambda(lambda x: (x - 0.5) * 2),
        ]))
        return DataLoader(train_data, batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def composite(self, path, text_path, batch_size: int = 16, is_process: bool = False):
        processor = None
        if is_process:
            # Korean CLIP checkpoint whose processor handles both image
            # normalisation and caption tokenisation.
            model_name = "Bingsu/clip-vit-base-patch32-ko"
            processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
        dataset = CompositeDataset(path, text_path, processor)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def random_data(self, size, batch_size: int = 4):
        # A plain tensor implements __len__ and __getitem__, so DataLoader
        # can iterate over it directly; handy for smoke tests.
        train_data = torch.randn(size)
        return DataLoader(train_data, batch_size)
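

if __name__ == '__main__':
    # Minimal smoke test (hedged): only the synthetic-data path is exercised,
    # since the image/caption loaders require files on disk.
    gen = DataGenerator(num_workers=0, pin_memory=False)
    loader = gen.random_data((32, 3, 16, 16), batch_size=4)
    batch = next(iter(loader))
    print(batch.shape)  # expected: torch.Size([4, 3, 16, 16])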