File size: 3,747 Bytes
239ee43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from pathlib import Path
from functools import partial
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, utils
import torch.nn.functional as F
from imagen_pytorch import t5
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
from datasets.utils.file_utils import get_datasets_user_agent
import io
import urllib
USER_AGENT = get_datasets_user_agent()
# helper functions
def exists(val):
    """Return True for any value other than ``None`` (falsy values such as 0 or '' count as existing)."""
    is_missing = val is None
    return not is_missing
def cycle(dl):
    """Yield items from ``dl`` endlessly, restarting iteration after every full pass.

    Args:
        dl: Any iterable; it is re-iterated from the start each time a pass ends.

    Yields:
        The items of ``dl``, in order, pass after pass.

    The generator finishes (instead of spinning forever without yielding) as
    soon as a complete pass over ``dl`` produces no items, which happens for an
    empty iterable or for a one-shot iterator that is already exhausted.
    """
    while True:
        produced_any = False
        for data in dl:
            produced_any = True
            yield data
        if not produced_any:
            # Nothing came out of this pass; another pass would only busy-loop.
            return
def convert_image_to(img_type, image):
    """Return ``image`` converted to ``img_type``, or ``image`` itself when its ``mode`` already matches."""
    already_matches = image.mode == img_type
    return image if already_matches else image.convert(img_type)
# dataset, dataloader, collator
class Collator:
    """Collate function that turns raw dataset items into (image, text-embedding) batches.

    For every item the image is either downloaded from a URL or read directly
    from the item, resized / center-cropped / converted to a tensor, and the
    item's text is encoded with ``t5.t5_encode_text``. Items whose image cannot
    be obtained or transformed are dropped from the batch.

    Args:
        image_size: Size handed to ``T.Resize`` and ``T.CenterCrop``.
        url_label: Item key holding the image URL; ``None`` disables downloading.
        text_label: Item key holding the text to encode.
        image_label: Item key holding the image object (used when not downloading).
        name: Forwarded as ``name=`` to ``t5.t5_encode_text``
            (presumably the T5 model name -- confirm against imagen_pytorch).
        channels: Forwarded to ``image.convert(...)``
            (presumably a PIL mode string such as 'RGB' -- confirm with callers).
    """

    def __init__(self, image_size, url_label, text_label, image_label, name, channels):
        self.url_label = url_label
        self.text_label = text_label
        self.image_label = image_label
        # Images are downloaded only when a URL column was configured.
        self.download = url_label is not None
        self.name = name
        self.channels = channels
        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor(),
        ])

    def __call__(self, batch):
        """Build a collated batch from ``batch``.

        Args:
            batch: Iterable of mapping-like items indexed by the configured labels.

        Returns:
            The result of ``default_collate`` over ``(image, text)`` pairs, or
            ``None`` when no item in the batch yielded a usable image.
        """
        texts = []
        images = []
        for item in batch:
            try:
                if self.download:
                    image = self.fetch_single_image(item[self.url_label])
                else:
                    image = item[self.image_label]
                if image is None:
                    # Download failed (or the item carries no image): skip it explicitly.
                    continue
                image = self.transform(image.convert(self.channels))
            except Exception:
                # Best-effort: drop items with missing keys or undecodable images,
                # but let KeyboardInterrupt / SystemExit propagate.
                continue
            text = t5.t5_encode_text([item[self.text_label]], name=self.name)
            # Drop only the leading batch axis; a plain squeeze() would also
            # collapse a length-1 sequence axis and break padding below.
            # NOTE(review): assumes t5_encode_text returns (1, seq, dim) for one text -- confirm.
            texts.append(text.squeeze(0))
            images.append(image)

        if not texts:
            return None

        # Pad along the sequence axis so every text tensor has the same length.
        texts = pad_sequence(texts, batch_first=True)
        pairs = list(zip(images, texts))
        return torch.utils.data.dataloader.default_collate(pairs)

    def fetch_single_image(self, image_url, timeout=1):
        """Download ``image_url`` and return it as an RGB PIL image.

        Args:
            image_url: URL to fetch.
            timeout: Timeout passed to ``urllib.request.urlopen``.

        Returns:
            The decoded image converted to 'RGB', or ``None`` on any failure.
        """
        # `import urllib` alone does not guarantee the `request` submodule is
        # loaded; without it every download would fail silently below.
        import urllib.request

        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as response:
                image = Image.open(io.BytesIO(response.read())).convert('RGB')
        except Exception:
            # Deliberate best-effort: network and decode errors mean "no image".
            return None
        return image
class Dataset(Dataset):
    """Image dataset over every file with a matching extension under ``folder``.

    Args:
        folder: Root directory; it is searched recursively (``**/*.<ext>``).
        image_size: Size handed to ``T.Resize`` and ``T.CenterCrop``.
        exts: File extensions (without the dot) to include.
        convert_image_to_type: When given, each image is converted to this
            type via ``convert_image_to`` before the other transforms;
            otherwise images pass through unchanged.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ('jpg', 'jpeg', 'png', 'tiff'),  # tuple: avoids a shared mutable default
        convert_image_to_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        root = Path(str(folder))
        self.paths = [p for ext in exts for p in root.glob(f'**/*.{ext}')]

        if exists(convert_image_to_type):
            convert_fn = partial(convert_image_to, convert_image_to_type)
        else:
            convert_fn = nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of image files found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return the transformed result."""
        path = self.paths[index]
        # Close the underlying file once the pipeline (ending in ToTensor) has
        # consumed the pixels, so file handles are not leaked across samples.
        with Image.open(path) as img:
            return self.transform(img)
def get_images_dataloader(
    folder,
    *,
    batch_size,
    image_size,
    shuffle = True,
    cycle_dl = False,
    pin_memory = True
):
    """Create a ``DataLoader`` over the images found under ``folder``.

    Args:
        folder: Directory passed to ``Dataset``.
        batch_size: Batch size for the loader.
        image_size: Image size passed to ``Dataset``.
        shuffle: Forwarded to ``DataLoader``.
        cycle_dl: When true, the loader is wrapped with ``cycle`` and a
            generator is returned instead of the ``DataLoader`` itself.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        The ``DataLoader``, or the ``cycle`` generator wrapping it.
    """
    loader = DataLoader(
        Dataset(folder, image_size),
        batch_size = batch_size,
        shuffle = shuffle,
        pin_memory = pin_memory,
    )
    return cycle(loader) if cycle_dl else loader
|