# KoFace-AI — helper/data_generator.py
from torchvision.datasets import CIFAR10, CelebA
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Lambda, CenterCrop, Resize, RandomHorizontalFlip
import os
import torch
import json
from PIL import Image as im
from helper.tokenizer import Tokenizer
from transformers import AutoProcessor
def center_crop_and_resize(img, crop_size, resize_size):
    """Center-crop a PIL image to a square, then resize it.

    Args:
        img: PIL Image to process.
        crop_size: side length (pixels) of the square center crop.
        resize_size: side length (pixels) of the final resized image.

    Returns:
        A new PIL Image of size (resize_size, resize_size).
    """
    width, height = img.size
    # 1. Center crop. Use integer arithmetic: the original float division
    # produced fractional box coordinates for odd (dimension - crop_size),
    # which PIL rounds silently and can yield an off-by-one crop size.
    # Anchoring right/bottom at left/top + crop_size guarantees an exact
    # crop_size x crop_size box.
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    img_cropped = img.crop((left, top, left + crop_size, top + crop_size))
    # 2. Resize with bicubic resampling (matches the quality expected for
    # downstream CLIP-style processors).
    img_resized = img_cropped.resize((resize_size, resize_size), im.Resampling.BICUBIC)
    return img_resized
class UnlabelDataset(Dataset):
    """Image-only dataset: yields one transformed image per file in `path`.

    Args:
        path: directory containing the image files.
        transform: callable applied to each opened PIL image.
    """

    def __init__(self, path, transform):
        self.path = path
        # NOTE: os.listdir order is OS-dependent; indices are not stable
        # across filesystems.
        self.file_list = os.listdir(path)
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        # os.path.join tolerates `path` with or without a trailing
        # separator; the original string concatenation silently produced a
        # wrong path whenever the separator was missing.
        img_path = os.path.join(self.path, self.file_list[index])
        # Context manager releases the file handle once the transform has
        # consumed the pixel data (the original leaked one fd per item
        # until garbage collection).
        with im.open(img_path) as image:
            return self.transform(image)
class CompositeDataset(Dataset):
    """Pairs face images (<stem>.png in `path`) with Korean caption JSON
    (<stem>.json in `text_path`) sharing the same file stem.

    If a Hugging Face `processor` is supplied, ``__getitem__`` returns the
    processor's encoding dict (batch dim squeezed out); otherwise it returns
    ``(image_tensor, tokenized_text)`` using the project tokenizer.
    """

    def __init__(self, path, text_path, processor: AutoProcessor = None):
        self.path = path
        self.text_path = text_path
        self.tokenizer = Tokenizer()
        self.processor = processor
        # File stems shared between the image and annotation directories.
        self.file_numbers = [os.path.splitext(fn)[0] for fn in os.listdir(path)]
        self.transform = Compose([
            ToTensor(),
            CenterCrop(400),
            Resize(256, antialias=True),
            RandomHorizontalFlip(),
            Lambda(lambda x: (x - 0.5) * 2),  # map [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.file_numbers)

    def get_text(self, text_path):
        """Return the caption string from one annotation JSON file."""
        # Annotations are Korean text stored in CP949 (EUC-KR) encoding.
        with open(text_path, encoding='CP949') as f:
            return json.load(f)['description']['impression']['description']

    def __getitem__(self, idx):
        stem = self.file_numbers[idx]
        # os.path.join works whether or not the configured directories end
        # with a separator; the original `+` concatenation silently built a
        # wrong path when the trailing slash was missing.
        img_path = os.path.join(self.path, stem + '.png')
        text_path = os.path.join(self.text_path, stem + '.json')
        text = self.get_text(text_path)
        # Context manager closes the image file handle after the pixel data
        # has been consumed (the original leaked one fd per item).
        with im.open(img_path) as image:
            if self.processor is not None:
                image = center_crop_and_resize(image, 400, 256)
                inputs = self.processor(
                    text=text,
                    images=image,
                    return_tensors="pt",
                    padding='max_length',
                    max_length=77,
                    truncation=True,
                )
                # Drop the batch dim the processor adds; the DataLoader
                # re-batches items itself.
                for key in inputs:
                    inputs[key] = inputs[key].squeeze(0)
                return inputs
            image = self.transform(image)
        text = self.tokenizer.tokenize(text)
        for key in text:
            text[key] = text[key].squeeze(0)
        return image, text
class DataGenerator():
    """Factory producing DataLoaders for the datasets used in this project.

    Args:
        num_workers: worker processes per DataLoader.
        pin_memory: whether DataLoaders pin host memory for faster GPU copies.
    """

    def __init__(self, num_workers: int = 4, pin_memory: bool = True):
        # Default transform: tensor conversion, then map pixels [0,1] -> [-1,1].
        self.transform = Compose([
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2)
        ])
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _loader(self, dataset, batch_size):
        # Shared shuffled-DataLoader construction with this factory's
        # worker/pinning settings.
        return DataLoader(dataset, batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def cifar10(self, path='./datasets', batch_size: int = 64, train: bool = True):
        """Shuffled CIFAR-10 loader; downloads the split on first use."""
        dataset = CIFAR10(path, download=True, train=train, transform=self.transform)
        return self._loader(dataset, batch_size)

    def celeba(self, path='./datasets', batch_size: int = 16):
        """Shuffled CelebA loader: center-crop 178, resize 128, scale to [-1,1]."""
        celeba_transform = Compose([
            ToTensor(),
            CenterCrop(178),
            Resize(128),
            Lambda(lambda x: (x - 0.5) * 2)
        ])
        dataset = CelebA(path, transform=celeba_transform)
        return self._loader(dataset, batch_size)

    def composite(self, path, text_path, batch_size: int = 16, is_process: bool = False):
        """Loader over CompositeDataset; `is_process` enables the Korean CLIP processor."""
        processor = None
        if is_process:
            processor = AutoProcessor.from_pretrained(
                "Bingsu/clip-vit-base-patch32-ko", use_fast=False)
        dataset = CompositeDataset(path, text_path, processor)
        return self._loader(dataset, batch_size)

    def random_data(self, size, batch_size: int = 4):
        """Unshuffled loader over a freshly sampled random tensor of `size`."""
        return DataLoader(torch.randn(size), batch_size)