import glob
import os
import shutil
import tarfile

import numpy as np
import yaml
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

import custom_albumentations as albumentations
from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths
from custom_controlnet_aux.diffusion_edge.taming.util import download, retrieve
# NOTE: assumed location of the vendored taming.data.utils helpers
# (is_prepared / mark_prepared), consistent with the package imports above.
import custom_controlnet_aux.diffusion_edge.taming.data.utils as bdu


def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
    synsets = []
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    for idx in indices:
        synsets.append(str(di2s[idx]))
    print("Using {} different synsets for construction of Restricted Imagenet.".format(len(synsets)))
    return synsets


def str_to_indices(string):
    """Expects a string in the format '32-123, 256, 280-321'"""
    assert not string.endswith(","), "provided string '{}' ends with a comma, please remove it".format(string)
    subs = string.split(",")
    indices = []
    for sub in subs:
        subsubs = sub.split("-")
        assert len(subsubs) > 0
        if len(subsubs) == 1:
            indices.append(int(subsubs[0]))
        else:
            # ranges are end-exclusive: "30-32" yields [30, 31]
            rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
            indices.extend(rang)
    return sorted(indices)
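# Illustrative only (not from the original module): how the two parsing helpers
# compose. Ranges are end-exclusive, so "30-32" covers class indices 30 and 31.
#
#     str_to_indices("30-32, 40")
#     # -> [30, 31, 40]
#     give_synsets_from_indices([30, 31, 40], path_to_yaml="data/imagenet_idx_to_synset.yaml")
#     # -> list of WordNet synset id strings such as "n01440764"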
class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        self.class_labels = [class_dict[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
        self.data = ImagePaths(self.abspaths,
                               labels=labels,
                               size=retrieve(self.config, "size", default=0),
                               random_crop=self.random_crop)


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)
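# Usage sketch (my assumption, not from the original module): both splits are
# driven by a plain dict / OmegaConf node. "size" is forwarded to ImagePaths,
# "sub_indices" restricts the class set, and "ImageNetTrain/random_crop" is
# resolved via retrieve(). The first instantiation downloads and extracts the
# archive into $XDG_CACHE_HOME/autoencoders/data.
#
#     train = ImageNetTrain({"size": 256, "sub_indices": "0-20"})
#     ex = train[0]
#     ex["image"].shape                      # (256, 256, 3), float32 in [-1, 1]
#     ex["class_label"], ex["human_label"]   # contiguous label id, readable name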
class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        if not bdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)


def get_preprocessor(size=None, random_crop=False, additional_targets=None,
                     crop_size=None):
    if size is not None and size > 0:
        transforms = list()
        rescaler = albumentations.SmallestMaxSize(max_size=size)
        transforms.append(rescaler)
        if not random_crop:
            cropper = albumentations.CenterCrop(height=size, width=size)
            transforms.append(cropper)
        else:
            cropper = albumentations.RandomCrop(height=size, width=size)
            transforms.append(cropper)
            flipper = albumentations.HorizontalFlip()
            transforms.append(flipper)
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    elif crop_size is not None and crop_size > 0:
        if not random_crop:
            cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
        else:
            cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
        transforms = [cropper]
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    else:
        preprocessor = lambda **kwargs: kwargs
    return preprocessor


def rgba_to_depth(x):
    """Reinterpret the four uint8 channels of an RGBA image as one float32 depth map."""
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    y = x.copy()
    y.dtype = np.float32  # reinterprets the raw bytes, no value conversion
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
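# Round-trip sketch for the byte-level packing above (illustrative, my own
# example): a float32 depth map viewed as four uint8 channels is exactly what
# rgba_to_depth reverses.
#
#     depth = np.random.rand(4, 4).astype(np.float32)
#     rgba = depth.copy().view(np.uint8).reshape(4, 4, 4)
#     assert np.allclose(rgba_to_depth(rgba), depth)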
class BaseWithDepth(Dataset):
    DEFAULT_DEPTH_ROOT = "data/imagenet_depth"

    def __init__(self, config=None, size=None, random_crop=False,
                 crop_size=None, root=None):
        self.config = config
        self.base_dset = self.get_base_dset()
        self.preprocessor = get_preprocessor(
            size=size,
            crop_size=crop_size,
            random_crop=random_crop,
            additional_targets={"depth": "image"})
        self.crop_size = crop_size
        if self.crop_size is not None:
            self.rescaler = albumentations.Compose(
                [albumentations.SmallestMaxSize(max_size=self.crop_size)],
                additional_targets={"depth": "image"})
        if root is not None:
            self.DEFAULT_DEPTH_ROOT = root

    def __len__(self):
        return len(self.base_dset)

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        depth = (depth - depth.min()) / max(1e-8, depth.max() - depth.min())
        depth = 2.0 * depth - 1.0
        return depth

    def __getitem__(self, i):
        e = self.base_dset[i]
        e["depth"] = self.preprocess_depth(self.get_depth_path(e))
        # upscale if necessary
        h, w, c = e["image"].shape
        if self.crop_size and min(h, w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            out = self.rescaler(image=e["image"], depth=e["depth"])
            e["image"] = out["image"]
            e["depth"] = out["depth"]
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


class ImageNetTrainWithDepth(BaseWithDepth):
    # default to random_crop=True
    def __init__(self, random_crop=True, sub_indices=None, **kwargs):
        self.sub_indices = sub_indices
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetTrain()
        else:
            return ImageNetTrain({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0] + ".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
        return fid


class ImageNetValidationWithDepth(BaseWithDepth):
    def __init__(self, sub_indices=None, **kwargs):
        self.sub_indices = sub_indices
        super().__init__(**kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetValidation()
        else:
            return ImageNetValidation({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0] + ".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
        return fid


class RINTrainWithDepth(ImageNetTrainWithDepth):
    def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class RINValidationWithDepth(ImageNetValidationWithDepth):
    def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class DRINExamples(Dataset):
    def __init__(self):
        self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
        with open("data/drin_examples.txt", "r") as f:
            relpaths = f.read().splitlines()
        self.image_paths = [os.path.join("data/drin_images", relpath)
                            for relpath in relpaths]
        self.depth_paths = [os.path.join("data/drin_depth", relpath.replace(".JPEG", ".png"))
                            for relpath in relpaths]

    def __len__(self):
        return len(self.image_paths)

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image / 127.5 - 1.0).astype(np.float32)
        return image

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        depth = (depth - depth.min()) / max(1e-8, depth.max() - depth.min())
        depth = 2.0 * depth - 1.0
        return depth

    def __getitem__(self, i):
        e = dict()
        e["image"] = self.preprocess_image(self.image_paths[i])
        e["depth"] = self.preprocess_depth(self.depth_paths[i])
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
    if factor is None or factor == 1:
        return x

    dtype = x.dtype
    assert dtype in [np.float32, np.float64]
    assert x.min() >= -1
    assert x.max() <= 1

    keepmode = {"nearest": Image.NEAREST,
                "bilinear": Image.BILINEAR,
                "bicubic": Image.BICUBIC}[keepmode]

    # convert from [-1, 1] floats to a uint8 PIL image for resizing
    lr = (x + 1.0) * 127.5
    lr = lr.clip(0, 255).astype(np.uint8)
    lr = Image.fromarray(lr)

    h, w, _ = x.shape
    nh = h // factor
    nw = w // factor
    assert nh > 0 and nw > 0, (nh, nw)

    lr = lr.resize((nw, nh), Image.BICUBIC)
    if keepshapes:
        lr = lr.resize((w, h), keepmode)
    lr = np.array(lr) / 127.5 - 1.0
    lr = lr.astype(dtype)

    return lr
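# Shape behavior of imscale, illustrated (not from the original module): with
# keepshapes=False the image shrinks by `factor`; with keepshapes=True it is
# downscaled and upscaled back, simulating a low-resolution version at the
# original size.
#
#     x = np.zeros((256, 256, 3), dtype=np.float32)   # values in [-1, 1]
#     imscale(x, 4).shape                    # (64, 64, 3)
#     imscale(x, 4, keepshapes=True).shape   # (256, 256, 3)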
class ImageNetScale(Dataset):
    def __init__(self, size=None, crop_size=None, random_crop=False,
                 up_factor=None, hr_factor=None, keep_mode="bicubic"):
        self.base = self.get_base()

        self.size = size
        self.crop_size = crop_size if crop_size is not None else self.size
        self.random_crop = random_crop
        self.up_factor = up_factor
        self.hr_factor = hr_factor
        self.keep_mode = keep_mode

        transforms = list()

        if self.size is not None and self.size > 0:
            rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            self.rescaler = rescaler
            transforms.append(rescaler)

        if self.crop_size is not None and self.crop_size > 0:
            if len(transforms) == 0:
                self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size)

            if not self.random_crop:
                cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            transforms.append(cropper)

        if len(transforms) > 0:
            if self.up_factor is not None:
                additional_targets = {"lr": "image"}
            else:
                additional_targets = None
            self.preprocessor = albumentations.Compose(transforms,
                                                       additional_targets=additional_targets)
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        # adjust resolution
        image = imscale(image, self.hr_factor, keepshapes=False)
        h, w, c = image.shape
        if self.crop_size and min(h, w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            image = self.rescaler(image=image)["image"]

        if self.up_factor is None:
            image = self.preprocessor(image=image)["image"]
            example["image"] = image
        else:
            lr = imscale(image, self.up_factor, keepshapes=True,
                         keepmode=self.keep_mode)

            out = self.preprocessor(image=image, lr=lr)
            example["image"] = out["image"]
            example["lr"] = out["lr"]

        return example


class ImageNetScaleTrain(ImageNetScale):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetScaleValidation(ImageNetScale):
    def get_base(self):
        return ImageNetValidation()


from skimage.feature import canny
from skimage.color import rgb2gray


class ImageNetEdges(ImageNetScale):
    def __init__(self, up_factor=1, **kwargs):
        super().__init__(up_factor=up_factor, **kwargs)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        h, w, c = image.shape
        if self.crop_size and min(h, w) < self.crop_size:
            # have to upscale to be able to crop - this just uses bilinear
            image = self.rescaler(image=image)["image"]

        lr = canny(rgb2gray(image), sigma=2)
        lr = lr.astype(np.float32)
        lr = lr[:, :, None][:, :, [0, 0, 0]]  # replicate the edge map to three channels

        out = self.preprocessor(image=image, lr=lr)
        example["image"] = out["image"]
        example["lr"] = out["lr"]

        return example


class ImageNetEdgesTrain(ImageNetEdges):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetEdgesValidation(ImageNetEdges):
    def get_base(self):
        return ImageNetValidation()
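# End-to-end sketch (illustrative, not from the original module): the edge
# datasets pair each image with a Canny edge map under the key "lr", mirroring
# the low-resolution conditioning of the scale datasets. Instantiation builds
# the underlying ImageNet split, so that data must already be cached or
# downloadable.
#
#     dset = ImageNetEdgesValidation(crop_size=256)
#     ex = dset[0]
#     ex["image"].shape   # (256, 256, 3)
#     ex["lr"].shape      # (256, 256, 3), binary edge map replicated to RGB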