|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import time |
|
import glob |
|
import json |
|
import random |
|
import numpy as np |
|
import torch |
|
from .loader_util import BaseDataset |
|
|
|
|
|
class TextureDataset(BaseDataset): |
|
|
|
def __init__( |
|
self, json_path, num_view=6, image_size=512, lighting_suffix_pool=["light_PL", "light_AL", "light_ENVMAP"] |
|
): |
|
self.data = list() |
|
self.num_view = num_view |
|
self.image_size = image_size |
|
self.lighting_suffix_pool = lighting_suffix_pool |
|
if isinstance(json_path, str): |
|
json_path = [json_path] |
|
for jp in json_path: |
|
with open(jp) as f: |
|
self.data.extend(json.load(f)) |
|
print("============= length of dataset %d =============" % len(self.data)) |
|
|
|
def __getitem__(self, index): |
|
try_sleep_interval = 20 |
|
total_try_num = 100 |
|
cnt = try_sleep_interval * total_try_num |
|
|
|
images_ref = list() |
|
images_albedo = list() |
|
images_mr = list() |
|
images_normal = list() |
|
images_position = list() |
|
bg_white = [1.0, 1.0, 1.0] |
|
bg_black = [0.0, 0.0, 0.0] |
|
bg_gray = [127 / 255.0, 127 / 255.0, 127 / 255.0] |
|
dirx = self.data[index] |
|
|
|
condition_dict = {} |
|
|
|
|
|
fix_num_view = self.num_view |
|
available_views = [] |
|
for ext in ["*_albedo.png", "*_albedo.jpg", "*_albedo.jpeg"]: |
|
available_views.extend(glob.glob(os.path.join(dirx, "render_tex", ext))) |
|
cond_images = ( |
|
glob.glob(os.path.join(dirx, "render_cond", "*.png")) |
|
+ glob.glob(os.path.join(dirx, "render_cond", "*.jpg")) |
|
+ glob.glob(os.path.join(dirx, "render_cond", "*.jpeg")) |
|
) |
|
|
|
|
|
if len(available_views) < fix_num_view: |
|
print( |
|
f"Warning: Only {len(available_views)} views available, but {fix_num_view} requested." |
|
"Using all available views." |
|
) |
|
images_gen = available_views |
|
else: |
|
images_gen = random.sample(available_views, fix_num_view) |
|
|
|
if not cond_images: |
|
raise ValueError(f"No condition images found in {os.path.join(dirx, 'render_cond')}") |
|
ref_image_path = random.choice(cond_images) |
|
light_suffix = None |
|
for suffix in self.lighting_suffix_pool: |
|
if suffix in ref_image_path: |
|
light_suffix = suffix |
|
break |
|
if light_suffix is None: |
|
raise ValueError(f"light suffix not found in {ref_image_path}") |
|
ref_image_diff_light_path = random.choice( |
|
[ |
|
ref_image_path.replace(light_suffix, tar_suffix) |
|
for tar_suffix in self.lighting_suffix_pool |
|
if tar_suffix != light_suffix |
|
] |
|
) |
|
images_ref_paths = [ref_image_path, ref_image_diff_light_path] |
|
|
|
|
|
bg_c_record = None |
|
for i, image_ref in enumerate(images_ref_paths): |
|
if random.random() < 0.6: |
|
bg_c = bg_gray |
|
else: |
|
if random.random() < 0.5: |
|
bg_c = bg_black |
|
else: |
|
bg_c = bg_white |
|
if i == 0: |
|
bg_c_record = bg_c |
|
image, alpha = self.load_image(image_ref, bg_c_record) |
|
image = self.augment_image(image, bg_c_record).float() |
|
images_ref.append(image) |
|
condition_dict["images_cond"] = torch.stack(images_ref, dim=0).float() |
|
|
|
for i, image_gen in enumerate(images_gen): |
|
images_albedo.append(self.augment_image(self.load_image(image_gen, bg_gray)[0], bg_gray)) |
|
images_mr.append( |
|
self.augment_image(self.load_image(image_gen.replace("_albedo", "_mr"), bg_gray)[0], bg_gray) |
|
) |
|
images_normal.append( |
|
self.augment_image(self.load_image(image_gen.replace("_albedo", "_normal"), bg_gray)[0], bg_gray) |
|
) |
|
images_position.append( |
|
self.augment_image(self.load_image(image_gen.replace("_albedo", "_pos"), bg_gray)[0], bg_gray) |
|
) |
|
|
|
condition_dict["images_albedo"] = torch.stack(images_albedo, dim=0).float() |
|
condition_dict["images_mr"] = torch.stack(images_mr, dim=0).float() |
|
condition_dict["images_normal"] = torch.stack(images_normal, dim=0).float() |
|
condition_dict["images_position"] = torch.stack(images_position, dim=0).float() |
|
condition_dict["name"] = dirx |
|
return condition_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
dataset = TextureDataset(json_path=["../../../train_examples/examples.json"]) |
|
print("images_cond", dataset[0]["images_cond"].shape) |
|
print("images_albedo", dataset[0]["images_albedo"].shape) |
|
print("images_mr", dataset[0]["images_mr"].shape) |
|
print("images_normal", dataset[0]["images_normal"].shape) |
|
print("images_position", dataset[0]["images_position"].shape) |
|
print("name", dataset[0]["name"]) |
|
|