import io
import json
import os
import random
import zipfile

import numpy
import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import AutoImageProcessor, CLIPImageProcessor

from DiT_VAE.diffusion.data.builder import DATASETS


def to_rgb_image(maybe_rgba: Image.Image) -> Image.Image:
    """Return an RGB version of *maybe_rgba*.

    RGB images are returned unchanged. RGBA images are composited onto a
    flat mid-gray (127) background using their alpha channel as the paste
    mask. Any other mode raises ValueError.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # randint(127, 128) always yields 127 — a constant mid-gray canvas.
        img = np.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    raise ValueError("Unsupported image type.", maybe_rgba.mode)


@DATASETS.register_module()
class TriplaneData(Dataset):
    """Dataset of GAN triplane samples stored in per-model zip archives.

    Each entry in *data_json_file* maps a sample name to a record with
    ``model_name``, ``z_dir``, ``vert_dir`` and ``img_dir`` members that
    locate a latent tensor, a vertex tensor and an image inside the zip
    archive ``<data_base_dir>/<model_name>.zip``.
    """

    def __init__(self, data_base_dir, model_names, data_json_file, dino_path,
                 i_drop_rate=0.1, image_size=256, **kwargs):
        """
        Args:
            data_base_dir: Directory holding one ``<model_name>.zip`` per GAN model.
            model_names: Path to an OmegaConf YAML with a ``gan_models`` mapping;
                its keys name the zip archives to open.
            data_json_file: JSON file mapping sample name -> sample record.
            dino_path: HuggingFace path/name for the DINO image processor.
            i_drop_rate: Probability of flagging a sample for image-embed dropout
                (classifier-free guidance style).
            image_size: Side length for the resize/center-crop transform.
        """
        self.dict_data_image = json.load(open(data_json_file))  # {'image_name': record}
        self.data_base_dir = data_base_dir
        self.dino_img_processor = AutoImageProcessor.from_pretrained(dino_path)
        self.size = image_size
        self.data_list = list(self.dict_data_image.keys())

        # Open every model archive once up front and keep the handles.
        # NOTE(review): zipfile handles are not picklable — presumably this
        # dataset is used with num_workers=0 or worker-local re-init; verify.
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        for model_name in config_gan_model['gan_models'].keys():
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            self.zip_file_dict[model_name] = zipfile.ZipFile(zipfile_path)

        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.clip_image_processor = CLIPImageProcessor()
        self.i_drop_rate = i_drop_rate

    def getdata(self, idx):
        """Load one sample (tensors + image in three processed forms) by index.

        Raises whatever the zip/torch/PIL machinery raises on corrupt data;
        __getitem__ catches and resamples.
        """
        data_name = self.data_list[idx]
        data_model_name = self.dict_data_image[data_name]['model_name']
        zipfile_loaded = self.zip_file_dict[data_model_name]

        with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
            data_z = torch.load(io.BytesIO(f.read()))
        with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as f:
            data_vert = torch.load(io.BytesIO(f.read()))
        with zipfile_loaded.open(self.dict_data_image[data_name]['img_dir'], 'r') as f:
            # Buffer the bytes first: PIL decodes lazily, and an already-RGB
            # image would otherwise be read after the zip member is closed.
            image_buffer = io.BytesIO(f.read())
        raw_image = to_rgb_image(Image.open(image_buffer))

        dino_img = self.dino_img_processor(images=raw_image, return_tensors="pt").pixel_values
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # Randomly flag the image embedding for dropout (CFG-style training).
        drop_image_embed = 1 if random.random() < self.i_drop_rate else 0

        return {
            "raw_image": raw_image,
            "dino_img": dino_img,
            "image": image,
            "clip_image": clip_image.clone(),
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name,
            "drop_image_embed": drop_image_embed,
        }

    def __getitem__(self, idx):
        # Best-effort loading: on any failure, resample a random index.
        # Give up after 20 consecutive bad samples.
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def __len__(self):
        return len(self.data_list)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: provide a no-op
        # set_epoch so samplers/trainers that call it don't crash.
        if name == "set_epoch":
            return lambda epoch: None
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")