import io
import json
import os
import random
import zipfile

import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import AutoImageProcessor, CLIPImageProcessor

from DiT_VAE.diffusion.data.builder import DATASETS

def to_rgb_image(maybe_rgba: Image.Image) -> Image.Image:
    """Convert an RGBA image to RGB by compositing it onto a mid-gray background.

    RGB images are returned unchanged; any other mode raises ValueError.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # Solid mid-gray canvas (value 127); transparent regions end up gray.
        img = np.full((rgba.size[1], rgba.size[0], 3), 127, dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)
class TriplaneData(Dataset):
    """Dataset pairing rendered images with GAN triplane latents and mesh vertices.

    Samples live inside per-model zip archives; a JSON index maps each sample
    name to its source model and the member paths within that archive.
    """

    def __init__(self,
                 data_base_dir,
                 model_names,
                 data_json_file,
                 dino_path,
                 i_drop_rate=0.1,
                 image_size=256,
                 **kwargs):
        # Index: {'image_name': {'model_name': ..., 'z_dir': ..., 'vert_dir': ..., 'img_dir': ...}}
        with open(data_json_file) as f:
            self.dict_data_image = json.load(f)
        self.data_base_dir = data_base_dir
        self.dino_img_processor = AutoImageProcessor.from_pretrained(dino_path)
        self.size = image_size
        self.data_list = list(self.dict_data_image.keys())

        # Open each per-model zip archive once and cache the handle, so
        # getdata() does not reopen archives for every sample.
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        all_models = config_gan_model['gan_models'].keys()
        for model_name in all_models:
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            self.zip_file_dict[model_name] = zipfile.ZipFile(zipfile_path)

        # Resize the short side to `image_size`, center-crop to a square,
        # and normalize to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.clip_image_processor = CLIPImageProcessor()
        self.i_drop_rate = i_drop_rate  # probability of dropping the image condition
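
    # Example index entry (hypothetical names; structure inferred from getdata):
    #   "car_0001": {
    #       "model_name": "eg3d_cars",
    #       "z_dir": "latents/car_0001.pt",
    #       "vert_dir": "verts/car_0001.pt",
    #       "img_dir": "images/car_0001.png"
    #   }
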
    def getdata(self, idx):
        data_name = self.data_list[idx]
        data_model_name = self.dict_data_image[data_name]['model_name']
        # Archive handles were opened once in __init__ and cached in self.zip_file_dict.
        zipfile_loaded = self.zip_file_dict[data_model_name]

        # GAN latent code and mesh vertices, stored as serialized torch tensors.
        with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
            data_z = torch.load(io.BytesIO(f.read()))
        with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as f:
            data_vert = torch.load(io.BytesIO(f.read()))

        with zipfile_loaded.open(self.dict_data_image[data_name]['img_dir'], 'r') as f:
            raw_image = to_rgb_image(Image.open(f))
            raw_image.load()  # force decode before the archive member is closed

        # Three views of the same render: DINO features, the normalized
        # diffusion input, and the CLIP conditioning image.
        dino_img = self.dino_img_processor(images=raw_image, return_tensors="pt").pixel_values
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # Randomly drop the image condition for classifier-free guidance training.
        drop_image_embed = 1 if random.random() < self.i_drop_rate else 0

        return {
            "raw_image": raw_image,
            "dino_img": dino_img,
            "image": image,
            "clip_image": clip_image.clone(),
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name,
            "drop_image_embed": drop_image_embed,
        }
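
    # Note: "raw_image" is a PIL Image, and "dino_img"/"clip_image" carry a
    # leading batch dim of 1 from return_tensors="pt", so batching this dict
    # with a stock DataLoader typically requires a custom collate_fn.
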
    def __getitem__(self, idx):
        # Retry with a random index on corrupt or missing samples before giving up.
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error loading {self.data_list[idx]}: {e}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data samples.')
    def __len__(self):
        return len(self.data_list)
    def __getattr__(self, name):
        # Some training loops call dataset.set_epoch(); provide a no-op shim.
        if name == "set_epoch":
            return lambda epoch: None
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
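
# Minimal usage sketch (hypothetical paths; the OmegaConf YAML is assumed to
# hold a top-level 'gan_models' mapping whose keys name the zip archives):
if __name__ == "__main__":
    dataset = TriplaneData(
        data_base_dir="data/triplanes",         # holds <model_name>.zip archives
        model_names="configs/gan_models.yaml",  # OmegaConf config listing the GAN models
        data_json_file="data/index.json",       # sample index described above
        dino_path="facebook/dinov2-base",       # any AutoImageProcessor-compatible repo
    )
    sample = dataset[0]
    print(sample["dino_img"].shape, sample["image"].shape, sample["drop_image_embed"])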