# lib/projection_helper.py
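# Projection helpers for Text2Tex: screen-space <-> UV-space utilities for
# view selection, diffusion-mask construction, and texture back-projection.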
import os
import torch
import random
import numpy as np
from torchvision import transforms
from pytorch3d.renderer import TexturesUV
from pytorch3d.ops import interpolate_face_attributes
from PIL import Image
from tqdm import tqdm

# customized
import sys
sys.path.append(".")
from lib.camera_helper import init_camera
from lib.render_helper import init_renderer, render
from lib.shading_helper import (
BlendParams,
init_soft_phong_shader,
init_flat_texel_shader,
)
from lib.vis_helper import visualize_quad_mask
from lib.constants import *


def get_all_4_locations(values_y, values_x):
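    """
    Map fractional texel coordinates to the 4 surrounding integer texel
    locations (floor/ceil in y and x), so each projected pixel covers its
    full bilinear footprint in the texture.
    """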
y_0 = torch.floor(values_y)
y_1 = torch.ceil(values_y)
x_0 = torch.floor(values_x)
x_1 = torch.ceil(values_x)
return torch.cat([y_0, y_0, y_1, y_1], 0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()


def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
"""
compose quad mask:
-> 0: background
-> 1: old
-> 2: update
-> 3: new
"""
new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)
all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor
quad_mask_tensor = torch.zeros_like(all_mask_tensor)
quad_mask_tensor[old_mask_tensor == 1] = 1
quad_mask_tensor[update_mask_tensor == 1] = 2
quad_mask_tensor[new_mask_tensor == 1] = 3
return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor


def compute_view_heat(similarity_tensor, quad_mask_tensor):
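    """
    Score a view as the weighted fraction of pixels in each quad-mask region,
    with weights taken from QUAD_WEIGHTS (lib/constants.py).
    NOTE similarity_tensor is currently unused here; it is kept in the
    signature for call-site compatibility.
    """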
num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
heat = 0
for idx in QUAD_WEIGHTS:
heat += (quad_mask_tensor == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels
return heat


def select_viewpoint(selected_view_ids, view_punishments,
mode, dist_list, elev_list, azim_list, sector_list, view_idx,
similarity_texture_cache, exist_texture,
mesh, faces, verts_uvs,
image_size, faces_per_pixel,
init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
device, use_principle=False
):
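    """
    Pick the next rendering viewpoint.
    - "sequential": cycle through the candidate list in order.
    - "heuristic": render every candidate view and pick the one with the
      highest heat (for the first 6 iterations the principle views are used
      as-is when use_principle is set); the chosen view's punishment factor
      is decayed so it is unlikely to be selected again.
    - "random": uniform choice over the candidates.
    """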
if mode == "sequential":
num_views = len(dist_list)
dist = dist_list[view_idx % num_views]
elev = elev_list[view_idx % num_views]
azim = azim_list[view_idx % num_views]
sector = sector_list[view_idx % num_views]
selected_view_ids.append(view_idx % num_views)
elif mode == "heuristic":
if use_principle and view_idx < 6:
selected_view_idx = view_idx
else:
            selected_view_idx = None
            max_heat = 0  # NOTE if every candidate scores 0 heat, selected_view_idx stays None and the indexing below fails
print("=> selecting next view...")
view_heat_list = []
for sample_idx in tqdm(range(len(dist_list))):
view_heat, *_ = render_one_view_and_build_masks(dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
sample_idx, sample_idx, view_punishments,
similarity_texture_cache, exist_texture,
mesh, faces, verts_uvs,
image_size, faces_per_pixel,
init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
device)
if view_heat > max_heat:
selected_view_idx = sample_idx
max_heat = view_heat
view_heat_list.append(view_heat.item())
print(view_heat_list)
print("select view {} with heat {}".format(selected_view_idx, max_heat))
dist = dist_list[selected_view_idx]
elev = elev_list[selected_view_idx]
azim = azim_list[selected_view_idx]
sector = sector_list[selected_view_idx]
selected_view_ids.append(selected_view_idx)
        view_punishments[selected_view_idx] *= 0.01  # strongly discourage re-selecting this view
elif mode == "random":
selected_view_idx = random.choice(range(len(dist_list)))
dist = dist_list[selected_view_idx]
elev = elev_list[selected_view_idx]
azim = azim_list[selected_view_idx]
sector = sector_list[selected_view_idx]
selected_view_ids.append(selected_view_idx)
else:
raise NotImplementedError()
return dist, elev, azim, sector, selected_view_ids, view_punishments


@torch.no_grad()
def build_backproject_mask(mesh, faces, verts_uvs,
cameras, reference_image, faces_per_pixel,
image_size, uv_size, device):
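    """
    Back-project a reference image (in practice the per-view similarity map)
    from screen space into UV texture space and return the first channel as
    a (uv_size, uv_size) tensor.
    """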
# construct pixel UVs
renderer_scaled = init_renderer(cameras,
shader=init_soft_phong_shader(
camera=cameras,
blend_params=BlendParams(),
device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel
)
fragments_scaled = renderer_scaled.rasterizer(mesh)
# get UV coordinates for each pixel
faces_verts_uvs = verts_uvs[faces.textures_idx]
pixel_uvs = interpolate_face_attributes(
fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
) # NxHsxWsxKx2
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)
texture_locations_y, texture_locations_x = get_all_4_locations(
(1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
)
K = faces_per_pixel
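    # tile the normalized reference image over the 4 texel locations and the K face layers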
texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
# texture
texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)
    return texture_tensor[:, :, 0]  # the back-projected reference is grayscale, so one channel suffices


@torch.no_grad()
def build_diffusion_mask(mesh_stuff,
renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
smooth_mask=False, view_threshold=0.01):
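    """
    Build the per-view generation masks:
    - new: visible region that has no texture yet
    - update: textured region that this view sees better than any other view
    - old: textured region to keep untouched
    - exist: union of old and update (visible minus new)
    NOTE smooth_mask is accepted but currently unused.
    """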
mesh, faces, verts_uvs = mesh_stuff
    mask_mesh = mesh.clone()  # clone first: reassigning .textures below would otherwise mutate the input mesh in place
# visible mask => the whole region
exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
mask_mesh.textures = TexturesUV(
maps=torch.ones_like(exist_texture_expand),
faces_uvs=faces.textures_idx[None, ...],
verts_uvs=verts_uvs[None, ...],
sampling_mode="nearest"
)
visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
# faces that are too rotated away from the viewpoint will be treated as invisible
valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
visible_mask_tensor *= valid_mask_tensor
    # nonexist mask <=> new mask (reuses exist_texture_expand from above)
mask_mesh.textures = TexturesUV(
maps=1 - exist_texture_expand,
faces_uvs=faces.textures_idx[None, ...],
verts_uvs=verts_uvs[None, ...],
sampling_mode="nearest"
)
new_mask_tensor, *_ = render(mask_mesh, renderer)
new_mask_tensor *= valid_mask_tensor
# exist mask => visible mask - new mask
exist_mask_tensor = visible_mask_tensor - new_mask_tensor
    exist_mask_tensor[exist_mask_tensor < 0] = 0  # NOTE mask dilation can overflow, so clamp negatives after the subtraction
# all update mask
mask_mesh.textures = TexturesUV(
maps=(
similarity_texture_cache.argmax(0) == target_value
            # only consider the views that have already appeared before:
            # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
faces_uvs=faces.textures_idx[None, ...],
verts_uvs=verts_uvs[None, ...],
sampling_mode="nearest"
)
all_update_mask_tensor, *_ = render(mask_mesh, renderer)
# current update mask => intersection between all update mask and exist mask
update_mask_tensor = exist_mask_tensor * all_update_mask_tensor
# keep mask => exist mask - update mask
old_mask_tensor = exist_mask_tensor - update_mask_tensor
# convert
new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
new_mask = transforms.ToPILImage()(new_mask).convert("L")
update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
update_mask = transforms.ToPILImage()(update_mask).convert("L")
old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
old_mask = transforms.ToPILImage()(old_mask).convert("L")
exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
exist_mask = transforms.ToPILImage()(exist_mask).convert("L")
return new_mask, update_mask, old_mask, exist_mask


@torch.no_grad()
def render_one_view(mesh,
dist, elev, azim,
image_size, faces_per_pixel,
device):
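    """
    Render the mesh from one camera pose and return the camera, renderer,
    rendered RGB / normal / similarity / depth tensors, and the fragments.
    """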
# render the view
cameras = init_camera(
dist, elev, azim,
image_size, device
)
renderer = init_renderer(cameras,
shader=init_soft_phong_shader(
camera=cameras,
blend_params=BlendParams(),
device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel
)
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(mesh, renderer)
return (
cameras, renderer,
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
)


@torch.no_grad()
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
dist_list, elev_list, azim_list,
image_size, image_size_scaled, uv_size, faces_per_pixel,
device):
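    """
    For each candidate view, render its similarity map and back-project it
    into UV space. The resulting (num_views, uv_size, uv_size) cache tells,
    per texel, how well each view observes it.
    """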
num_candidate_views = len(dist_list)
similarity_texture_cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)
print("=> building similarity texture cache for all views...")
for i in tqdm(range(num_candidate_views)):
cameras, _, _, _, similarity_tensor, _, _ = render_one_view(mesh,
dist_list[i], elev_list[i], azim_list[i],
image_size, faces_per_pixel, device)
similarity_texture_cache[i] = build_backproject_mask(mesh, faces, verts_uvs,
cameras, transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB"), faces_per_pixel,
image_size_scaled, uv_size, device)
return similarity_texture_cache


@torch.no_grad()
def render_one_view_and_build_masks(dist, elev, azim,
selected_view_idx, view_idx, view_punishments,
similarity_texture_cache, exist_texture,
mesh, faces, verts_uvs,
image_size, faces_per_pixel,
init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):
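    """
    Render one view, build its diffusion masks, compose the quad mask, and
    compute the punished view heat. Intermediate images are written to the
    given directories when save_intermediate is set.
    """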
# render the view
(
cameras, renderer,
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
) = render_one_view(mesh,
dist, elev, azim,
image_size, faces_per_pixel,
device
)
init_image = init_images_tensor[0].cpu()
init_image = init_image.permute(2, 0, 1)
init_image = transforms.ToPILImage()(init_image).convert("RGB")
normal_map = normal_maps_tensor[0].cpu()
normal_map = normal_map.permute(2, 0, 1)
normal_map = transforms.ToPILImage()(normal_map).convert("RGB")
depth_map = depth_maps_tensor[0].cpu().numpy()
depth_map = Image.fromarray(depth_map).convert("L")
similarity_map = similarity_tensor[0, :, :, 0].cpu()
similarity_map = transforms.ToPILImage()(similarity_map).convert("L")
flat_renderer = init_renderer(cameras,
shader=init_flat_texel_shader(
camera=cameras,
device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel
)
new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
(mesh, faces, verts_uvs),
flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
smooth_mask=smooth_mask, view_threshold=view_threshold
)
# NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
# it should match with `similarity_texture_cache`
(
old_mask_tensor,
update_mask_tensor,
new_mask_tensor,
all_mask_tensor,
quad_mask_tensor
) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)
view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
view_heat *= view_punishments[selected_view_idx]
# save intermediate results
if save_intermediate:
init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))
new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))
visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)
return (
view_heat,
renderer, cameras, fragments,
init_image, normal_map, depth_map,
init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
old_mask_image, update_mask_image, new_mask_image,
old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
)


@torch.no_grad()
def backproject_from_image(mesh, faces, verts_uvs, cameras,
reference_image, new_mask_image, update_mask_image,
init_texture, exist_texture,
image_size, uv_size, faces_per_pixel,
device):
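    """
    Back-project the generated image into the UV texture: texels covered by
    the new or update masks are overwritten with reference_image colors, and
    exist_texture is marked accordingly. Returns the updated texture, the
    projection mask, and the updated exist_texture.
    """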
# construct pixel UVs
renderer_scaled = init_renderer(cameras,
shader=init_soft_phong_shader(
camera=cameras,
blend_params=BlendParams(),
device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel
)
fragments_scaled = renderer_scaled.rasterizer(mesh)
# get UV coordinates for each pixel
faces_verts_uvs = verts_uvs[faces.textures_idx]
pixel_uvs = interpolate_face_attributes(
fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
) # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)  # (N,H,W,K,2) -> (K,H,W,2); assumes a single mesh in the batch (N == 1)
# the update mask has to be on top of the diffusion mask
new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)
project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
project_mask_image = project_mask_image_tensor * 255.
project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))
project_mask_image_scaled = project_mask_image.resize(
(image_size, image_size),
Image.Resampling.NEAREST
)
project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)
    pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]  # NOTE boolean indexing here assumes faces_per_pixel == 1
texture_locations_y, texture_locations_x = get_all_4_locations(
(1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
)
    K = pixel_uvs.shape[0]  # faces_per_pixel
    # broadcast the mask over the 4 texel locations and the RGB channels to match texture_values
    project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)
texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)
# texture
texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked
init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))
# update texture cache
exist_texture[texture_locations_y, texture_locations_x] = 1
return init_texture, project_mask_image, exist_texture