File size: 3,793 Bytes
8ed2f16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os

import numpy
import json
import zipfile
import torch
from PIL import Image
# from transformers import CLIPImageProcessor
from torch.utils.data import Dataset
import io
from omegaconf import OmegaConf
import numpy as np
# from torchvision import transforms
# from einops import rearrange
# import random
# import os
# from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDIMScheduler
# import time
# import io
# import array
# import numpy as np
#
# from training.triplane import TriPlaneGenerator


def to_rgb_image(maybe_rgba: Image.Image, background=(127, 127, 127)):
    """Return ``maybe_rgba`` as an RGB image, compositing any alpha onto a flat colour.

    Args:
        maybe_rgba: Input PIL image.
            - 'RGB' images are returned unchanged (the same object).
            - 'L' images are converted to 'RGB'.
            - 'RGBA' images are pasted onto a solid ``background`` canvas, using
              their alpha channel as the paste mask.
            - 'LA' and 'P' images are first converted to 'RGBA' and then handled
              the same way.
        background: RGB fill colour for transparent regions. Defaults to
            mid-grey (127, 127, 127).

    Returns:
        A PIL image in mode 'RGB'.

    Raises:
        ValueError: If the image mode is not one of the modes listed above.
    """
    mode = maybe_rgba.mode
    if mode == 'RGB':
        return maybe_rgba
    if mode == 'L':
        # No alpha channel, so a plain conversion is enough.
        return maybe_rgba.convert('RGB')
    if mode in ('LA', 'P'):
        # Normalise to RGBA so the alpha-compositing path below applies
        # (for 'P', this also honours any palette transparency).
        maybe_rgba = maybe_rgba.convert('RGBA')
    elif mode != 'RGBA':
        raise ValueError("Unsupported image type.", mode)

    # Use a constant fill rather than a random draw. This keeps the function
    # deterministic and leaves the global NumPy RNG state untouched.
    canvas = Image.new('RGB', maybe_rgba.size, tuple(background))
    canvas.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
    return canvas



# image(contain style),z,pose,text
class TriplaneDataset(Dataset):
    """Dataset that reads per-sample tensors from one zip archive per model.

    Each sample is described by an entry in a JSON index file. The index is
    a mapping ``{sample_name: entry}``, where each ``entry`` is a dict with:

    - ``'model_name'``: which archive the sample lives in.
    - ``'z_dir'``: member path of the ``z`` tensor inside that archive.
    - ``'vert_dir'``: member path of the vertex tensor inside that archive.

    Archives live at ``<data_base_dir>/<model_name>.zip``. The set of model
    names comes from the ``gan_models`` mapping of an OmegaConf config file.

    Archives are opened eagerly in ``__init__``, so missing or corrupt files
    fail fast. Open handles are never shared across processes: a forked
    DataLoader worker reopens them on first use, and pickling (spawn /
    forkserver workers) drops them so they are reopened lazily.
    """

    # How many samples ``__getitem__`` tries before giving up.
    max_retries = 20

    def __init__(self, json_file, data_base_dir, model_names):
        """
        Args:
            json_file: Path to the JSON index described in the class docstring.
            data_base_dir: Directory containing the ``<model_name>.zip`` archives.
            model_names: Path to an OmegaConf config whose ``gan_models`` keys
                name the archives to open.
        """
        super().__init__()
        # Binary mode lets json detect UTF-8/16/32 instead of relying on the
        # locale encoding; ``with`` makes sure the handle is closed.
        with open(json_file, 'rb') as index_fp:
            self.dict_data_image = json.load(index_fp)
        self.data_base_dir = data_base_dir
        self.data_list = list(self.dict_data_image.keys())

        config_gan_model = OmegaConf.load(model_names)
        self.zip_path_dict = {
            model_name: os.path.join(self.data_base_dir, model_name + '.zip')
            for model_name in config_gan_model['gan_models'].keys()
        }
        self.zip_file_dict = {}
        self._zip_owner_pid = None
        self._open_zipfiles()

    def _open_zipfiles(self):
        """Open every archive in ``zip_path_dict`` for the current process."""
        opened = {}
        try:
            for model_name, zip_path in self.zip_path_dict.items():
                opened[model_name] = zipfile.ZipFile(zip_path)
        except Exception:
            # Don't leak the archives that did open before the failure.
            for archive in opened.values():
                archive.close()
            raise
        self.zip_file_dict = opened
        self._zip_owner_pid = os.getpid()

    def _get_zipfile(self, model_name):
        """Return this process's open archive for ``model_name``."""
        if self._zip_owner_pid != os.getpid():
            # Handles inherited through fork() share the OS-level file offset
            # with the parent and with sibling workers. Concurrent seek+read
            # would then interleave and return corrupt bytes, so each process
            # gets its own handles.
            self._open_zipfiles()
        return self.zip_file_dict[model_name]

    @staticmethod
    def _load_tensor(archive, member):
        """Read ``member`` from ``archive`` and deserialize it with ``torch.load``."""
        with archive.open(member, 'r') as member_fp:
            buffer = io.BytesIO(member_fp.read())
        # NOTE: torch.load unpickles its input; only use with trusted archives.
        return torch.load(buffer)

    def __getstate__(self):
        # Open ZipFile handles cannot be pickled. Drop them; _get_zipfile
        # reopens them in the receiving process.
        state = self.__dict__.copy()
        state['zip_file_dict'] = {}
        state['_zip_owner_pid'] = None
        return state

    def close(self):
        """Close all archives opened by the current process."""
        if self._zip_owner_pid == os.getpid():
            for archive in self.zip_file_dict.values():
                archive.close()
        self.zip_file_dict = {}
        self._zip_owner_pid = None

    def getdata(self, idx):
        """Load sample ``idx`` without any retry logic.

        Returns:
            dict with keys ``'data_z'`` and ``'data_vert'`` (whatever
            ``torch.load`` returns for the ``z_dir`` / ``vert_dir`` members)
            and ``'data_model_name'`` (the entry's ``model_name``).

        Raises:
            IndexError, KeyError, zipfile/torch errors: on a missing index,
            an unknown model, a missing member or an undecodable payload.
        """
        data_name = self.data_list[idx]
        entry = self.dict_data_image[data_name]
        data_model_name = entry['model_name']
        archive = self._get_zipfile(data_model_name)
        data_z = self._load_tensor(archive, entry['z_dir'])
        data_vert = self._load_tensor(archive, entry['vert_dir'])
        return {
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name
        }

    def __getitem__(self, idx):
        """Return ``getdata(idx)``, substituting a random sample if loading fails.

        Up to ``max_retries`` attempts are made. Each failure is printed and
        replaced by a uniformly random index.

        Raises:
            RuntimeError: If every attempt fails.
        """
        for _ in range(self.max_retries):
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.')

    def __len__(self):
        return len(self.data_list)

# for zip files