import json
import os

import torch
from PIL import Image as im
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10, CelebA
from torchvision.transforms import (CenterCrop, Compose, Lambda,
                                    RandomHorizontalFlip, Resize, ToTensor)
from transformers import AutoProcessor

from helper.tokenizer import Tokenizer

def center_crop_and_resize(img, crop_size, resize_size):
    """Center-crop a PIL image to crop_size x crop_size, then resize to resize_size x resize_size."""
    width, height = img.size
    # 1. Center crop: integer box math guarantees an exact crop_size output.
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    right = left + crop_size
    bottom = top + crop_size
    img_cropped = img.crop((left, top, right, bottom))
    # 2. Resize with bicubic resampling (Image.Resampling requires Pillow >= 9.1).
    img_resized = img_cropped.resize((resize_size, resize_size), im.Resampling.BICUBIC)
    return img_resized
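
# A minimal usage sketch ('sample.png' is a hypothetical file): this mirrors
# torchvision's CenterCrop(400) + Resize(256) but stays in PIL space, which is
# what the Hugging Face processor path below expects as input.
#   img = im.open('sample.png')
#   img = center_crop_and_resize(img, crop_size=400, resize_size=256)
#   assert img.size == (256, 256)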

class UnlabelDataset(Dataset):
    """Flat directory of images with no labels; yields transformed image tensors."""
    def __init__(self, path, transform):
        self.path = path
        self.file_list = os.listdir(path)
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        # os.path.join avoids requiring a trailing slash on `path`;
        # convert('RGB') guards against grayscale or RGBA inputs.
        img_path = os.path.join(self.path, self.file_list[index])
        image = im.open(img_path).convert('RGB')
        image = self.transform(image)
        return image
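
# Usage sketch (the directory path is hypothetical): wrap a flat folder of
# images and batch them without labels, e.g. for unconditional training.
#   ds = UnlabelDataset('./datasets/unlabeled',
#                       Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)]))
#   dl = DataLoader(ds, batch_size=16, shuffle=True)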

class CompositeDataset(Dataset):
    """Paired image/text dataset: <stem>.png in `path` matches <stem>.json in `text_path`."""
    def __init__(self, path, text_path, processor: AutoProcessor = None):
        self.path = path
        self.text_path = text_path
        self.tokenizer = Tokenizer()
        self.processor = processor
        # Keep only the file stems so both the image and text paths can be rebuilt per item.
        self.file_numbers = [os.path.splitext(filename)[0] for filename in os.listdir(path)]
        self.transform = Compose([
            ToTensor(),
            CenterCrop(400),
            Resize(256, antialias=True),
            RandomHorizontalFlip(),
            Lambda(lambda x: (x - 0.5) * 2),  # rescale [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.file_numbers)

    def get_text(self, text_path):
        # The annotation JSON is CP949-encoded (Korean); extract the impression text.
        with open(text_path, encoding='CP949') as f:
            text = json.load(f)['description']['impression']['description']
        return text

    def __getitem__(self, idx):
        img_path = os.path.join(self.path, self.file_numbers[idx] + '.png')
        text_path = os.path.join(self.text_path, self.file_numbers[idx] + '.json')
        image = im.open(img_path).convert('RGB')
        text = self.get_text(text_path)
        if self.processor is not None:
            # The HF processor expects a PIL image, so crop/resize before handing it over.
            image = center_crop_and_resize(image, 400, 256)
            inputs = self.processor(
                text=text,
                images=image,
                return_tensors="pt",
                padding='max_length',
                max_length=77,
                truncation=True,
            )
            # The processor adds a batch dimension of 1; drop it so the
            # DataLoader can collate items itself.
            for key in inputs:
                inputs[key] = inputs[key].squeeze(0)
            return inputs
        else:
            image = self.transform(image)
            text = self.tokenizer.tokenize(text)
            for key in text:
                text[key] = text[key].squeeze(0)
            return image, text
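
# Usage sketch (paths are hypothetical; the two folders must hold matching
# stems, e.g. 0001.png <-> 0001.json). With a processor the dataset yields a
# dict of CLIP-style tensors (pixel_values, input_ids, attention_mask);
# without one it yields (image_tensor, tokenized_text) pairs.
#   processor = AutoProcessor.from_pretrained("Bingsu/clip-vit-base-patch32-ko")
#   ds = CompositeDataset('./datasets/images', './datasets/labels', processor)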

class DataGenerator:
    """Factory for DataLoaders over the supported datasets."""
    def __init__(self, num_workers: int = 4, pin_memory: bool = True):
        self.transform = Compose([
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2),  # rescale [0, 1] -> [-1, 1]
        ])
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def cifar10(self, path='./datasets', batch_size: int = 64, train: bool = True):
        train_data = CIFAR10(path, download=True, train=train, transform=self.transform)
        return DataLoader(train_data, batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def celeba(self, path='./datasets', batch_size: int = 16):
        # CelebA images are 178x218; crop to the central 178x178 square, then downscale.
        train_data = CelebA(path, transform=Compose([
            ToTensor(),
            CenterCrop(178),
            Resize(128),
            Lambda(lambda x: (x - 0.5) * 2),
        ]))
        return DataLoader(train_data, batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def composite(self, path, text_path, batch_size: int = 16, is_process: bool = False):
        processor = None
        if is_process:
            # Korean CLIP processor; handles both tokenization and image preprocessing.
            model_name = "Bingsu/clip-vit-base-patch32-ko"
            processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
        dataset = CompositeDataset(path, text_path, processor)
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def random_data(self, size, batch_size: int = 4):
        # Random tensors for quick shape/smoke tests; no dataset download needed.
        train_data = torch.randn(size)
        return DataLoader(train_data, batch_size)
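

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the training pipeline:
    # random_data needs no downloads, so it exercises the DataLoader plumbing.
    gen = DataGenerator(num_workers=0, pin_memory=False)
    dl = gen.random_data((32, 3, 64, 64), batch_size=4)
    batch = next(iter(dl))
    print(batch.shape)  # expected: torch.Size([4, 3, 64, 64])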