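"""Dataset for triplane diffusion training.

Each sample pairs a GAN latent (`data_z`) and mesh vertices (`data_vert`), read from a
per-model zip archive, with DINO- and CLIP-preprocessed views of the corresponding
rendered image plus a [-1, 1]-normalized tensor of the same image.
"""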
import io
import json
import os
import random
import zipfile

import numpy as np
import torch
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import AutoImageProcessor, CLIPImageProcessor

from DiT_VAE.diffusion.data.builder import DATASETS


def to_rgb_image(maybe_rgba: Image.Image) -> Image.Image:
    """Return an RGB image, compositing RGBA inputs onto a neutral gray (127) background."""
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # Gray canvas; the alpha channel masks the paste so transparent pixels stay gray.
        img = np.full((rgba.size[1], rgba.size[0], 3), 127, dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError(f"Unsupported image mode: {maybe_rgba.mode}")


@DATASETS.register_module()
class TriplaneData(Dataset):
    def __init__(self,
                 data_base_dir,
                 model_names,
                 data_json_file,
                 dino_path,
                 i_drop_rate=0.1,
                 image_size=256,
                 **kwargs):
        # Index maps a sample name to {'model_name', 'z_dir', 'vert_dir', 'img_dir'}.
        with open(data_json_file) as f:
            self.dict_data_image = json.load(f)
        self.data_base_dir = data_base_dir
        self.dino_img_processor = AutoImageProcessor.from_pretrained(dino_path)
        self.size = image_size
        self.data_list = list(self.dict_data_image.keys())
        # Keep one open handle per GAN-model archive so samples can be read lazily.
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        for model_name in config_gan_model['gan_models'].keys():
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            self.zip_file_dict[model_name] = zipfile.ZipFile(zipfile_path)
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.clip_image_processor = CLIPImageProcessor()
        self.i_drop_rate = i_drop_rate

    def getdata(self, idx):
        data_name = self.data_list[idx]
        data_model_name = self.dict_data_image[data_name]['model_name']
        zipfile_loaded = self.zip_file_dict[data_model_name]

        # Latent code and vertices are stored as torch tensors inside the model's archive.
        # map_location='cpu' keeps DataLoader workers off the GPU even if the tensors were
        # saved from CUDA.
        with zipfile_loaded.open(self.dict_data_image[data_name]['z_dir'], 'r') as f:
            data_z = torch.load(io.BytesIO(f.read()), map_location='cpu')

        with zipfile_loaded.open(self.dict_data_image[data_name]['vert_dir'], 'r') as f:
            data_vert = torch.load(io.BytesIO(f.read()), map_location='cpu')

        with zipfile_loaded.open(self.dict_data_image[data_name]['img_dir'], 'r') as f:
            raw_image = to_rgb_image(Image.open(f))
            dino_img = self.dino_img_processor(images=raw_image, return_tensors="pt").pixel_values
            image = self.transform(raw_image.convert("RGB"))
            clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # Randomly flag this sample to drop its image embedding (e.g. for classifier-free guidance).
        drop_image_embed = 1 if random.random() < self.i_drop_rate else 0
        return {
            "raw_image": raw_image,
            "dino_img": dino_img,
            "image": image,
            "clip_image": clip_image.clone(),
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name,
            "drop_image_embed": drop_image_embed,
        }

    def __getitem__(self, idx):
        # Retry with a random index on corrupt or missing samples; give up after 20 attempts.
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Failed to load sample {idx}: {e}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data samples.')

    def __len__(self):
        return len(self.data_list)

    def __getattr__(self, name):
        # Some trainers call dataset.set_epoch(epoch); expose a no-op so they work with this dataset.
        if name == "set_epoch":
            return lambda epoch: None
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
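

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. All paths below are
    # hypothetical placeholders: per-model zip archives are expected under data_base_dir,
    # an OmegaConf YAML with a 'gan_models' mapping, a JSON index as described in __init__,
    # and any AutoImageProcessor-compatible checkpoint id or local path for dino_path.
    dataset = TriplaneData(
        data_base_dir="data/triplanes",
        model_names="configs/gan_models.yaml",
        data_json_file="data/index.json",
        dino_path="facebook/dinov2-base",
    )
    print(f"{len(dataset)} samples")
    sample = dataset[0]
    print(sample["dino_img"].shape, sample["clip_image"].shape, sample["image"].shape)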