Spaces:
Running
Running
File size: 4,703 Bytes
5ab5cab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from torchvision.datasets import CIFAR10, CelebA
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Lambda, CenterCrop, Resize, RandomHorizontalFlip
import os
import torch
import json
from PIL import Image as im
from helper.tokenizer import Tokenizer
from transformers import AutoProcessor
def center_crop_and_resize(img, crop_size, resize_size):
    """Center-crop ``img`` to a ``crop_size`` x ``crop_size`` square, then
    resize it to ``resize_size`` x ``resize_size`` with bicubic resampling.

    Args:
        img: a PIL Image (must expose ``.size``, ``.crop`` and ``.resize``).
        crop_size: side length of the square taken from the image center.
        resize_size: side length of the returned image.

    Returns:
        A new PIL Image of size (resize_size, resize_size).
    """
    width, height = img.size
    # 1. Center crop. Use integer coordinates: (width - crop_size) / 2 is
    # fractional when the difference is odd, which makes PIL's crop box
    # ambiguous. Anchoring left/top with floor division and deriving
    # right/bottom from them guarantees an exactly crop_size x crop_size box.
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    right = left + crop_size
    bottom = top + crop_size
    img_cropped = img.crop((left, top, right, bottom))
    # 2. Resize to the target resolution.
    img_resized = img_cropped.resize((resize_size, resize_size), im.Resampling.BICUBIC)
    return img_resized
class UnlabelDataset(Dataset):
    """Dataset over a flat directory of image files, yielding only images
    (no labels), each passed through ``transform``.
    """
    def __init__(self, path, transform):
        # path: directory containing the image files.
        self.path = path
        self.file_list = os.listdir(path)
        # transform: callable applied to each PIL image before returning.
        self.transform = transform
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, index):
        # os.path.join instead of string concatenation: works whether or not
        # `path` was given with a trailing separator.
        img_path = os.path.join(self.path, self.file_list[index])
        # Context manager ensures the underlying file handle is closed
        # (PIL opens files lazily and the original leaked the handle).
        with im.open(img_path) as image:
            return self.transform(image)
class CompositeDataset(Dataset):
    """Paired image/text dataset.

    Images live in ``path`` as ``<id>.png``; the matching text lives in
    ``text_path`` as ``<id>.json`` under
    ``['description']['impression']['description']``.

    If ``processor`` is given (a HuggingFace ``AutoProcessor``), ``__getitem__``
    returns the processor's tensor dict (batch dim squeezed); otherwise it
    returns ``(image_tensor, token_dict)`` using the project tokenizer.
    """
    def __init__(self, path, text_path, processor: AutoProcessor = None):
        self.path = path
        self.text_path = text_path
        self.tokenizer = Tokenizer()
        self.processor = processor
        # Sample ids = image filenames with their extension stripped.
        self.file_numbers = [os.path.splitext(filename)[0]
                             for filename in os.listdir(path)]
        # Crop/resize + flip augmentation, then map [0, 1] -> [-1, 1].
        self.transform = Compose([
            ToTensor(),
            CenterCrop(400),
            Resize(256, antialias=True),
            RandomHorizontalFlip(),
            Lambda(lambda x: (x - 0.5) * 2)
        ])
    def __len__(self):
        return len(self.file_numbers)
    def get_text(self, text_path):
        """Read the description string from one JSON annotation file."""
        # NOTE(review): files are assumed to be CP949-encoded (Korean) —
        # confirm against the dataset; UTF-8 files would fail to decode here.
        with open(text_path, encoding='CP949') as f:
            return json.load(f)['description']['impression']['description']
    def __getitem__(self, idx):
        # os.path.join instead of string concatenation: robust to `path`
        # values given without a trailing separator.
        img_path = os.path.join(self.path, self.file_numbers[idx] + '.png')
        text_path = os.path.join(self.text_path, self.file_numbers[idx] + '.json')
        text = self.get_text(text_path)
        # Context manager closes the image file handle (the original leaked it).
        with im.open(img_path) as image:
            if self.processor is not None:
                # Deterministic preprocessing; the processor handles tensorization.
                image = center_crop_and_resize(image, 400, 256)
                inputs = self.processor(
                    text=text,
                    images=image,
                    return_tensors="pt",
                    padding='max_length',
                    max_length=77,
                    truncation=True,
                )
                # Drop the singleton batch dim the processor adds.
                for key in inputs:
                    inputs[key] = inputs[key].squeeze(0)
                return inputs
            image = self.transform(image)
        tokens = self.tokenizer.tokenize(text)
        # Drop the singleton batch dim from each tokenizer output tensor.
        for key in tokens:
            tokens[key] = tokens[key].squeeze(0)
        return image, tokens
class DataGenerator():
    """Factory producing ready-to-train DataLoaders for several datasets.

    Worker count and memory pinning are configured once and reused by every
    dataset-specific loader.
    """
    def __init__(self, num_workers: int = 4, pin_memory: bool = True):
        # Default preprocessing: tensorize, then map [0, 1] -> [-1, 1].
        self.transform = Compose([
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2)
        ])
        self.num_workers = num_workers
        self.pin_memory = pin_memory
    def _make_loader(self, dataset, batch_size):
        # Shared shuffled loader honoring the configured worker/pinning settings.
        return DataLoader(dataset, batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)
    def cifar10(self, path = './datasets', batch_size : int = 64, train : bool = True):
        """Shuffled CIFAR-10 loader (downloads the split if missing)."""
        dataset = CIFAR10(path, download=True, train=train, transform=self.transform)
        return self._make_loader(dataset, batch_size)
    def celeba(self, path = './datasets', batch_size : int = 16):
        """Shuffled CelebA loader: 178-center-crop, resize to 128, [-1, 1] range."""
        celeba_transform = Compose([
            ToTensor(),
            CenterCrop(178),
            Resize(128),
            Lambda(lambda x: (x - 0.5) * 2)
        ])
        dataset = CelebA(path, transform=celeba_transform)
        return self._make_loader(dataset, batch_size)
    def composite(self, path, text_path, batch_size : int = 16, is_process: bool = False):
        """Shuffled loader over CompositeDataset; with ``is_process`` the
        Korean CLIP processor is loaded and used for preprocessing."""
        processor = None
        if is_process:
            model_name = "Bingsu/clip-vit-base-patch32-ko"
            processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
        return self._make_loader(CompositeDataset(path, text_path, processor), batch_size)
    def random_data(self, size, batch_size : int = 4):
        """Loader over random normal tensors of the given size (no shuffling,
        default worker settings) — handy for smoke tests."""
        return DataLoader(torch.randn(size), batch_size)
|