# lib/projection_helper.py
import os
import torch
import cv2
import random
import numpy as np

from torchvision import transforms
from pytorch3d.renderer import TexturesUV
from pytorch3d.ops import interpolate_face_attributes
from PIL import Image
from tqdm import tqdm

# customized
import sys
sys.path.append(".")

from lib.camera_helper import init_camera
from lib.render_helper import init_renderer, render
from lib.shading_helper import (
    BlendParams,
    init_soft_phong_shader,
    init_flat_texel_shader,
)
from lib.vis_helper import visualize_outputs, visualize_quad_mask
from lib.constants import *  # provides QUAD_WEIGHTS used below

def get_all_4_locations(values_y, values_x):
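    """
    Return the four integer (y, x) neighbor locations that bracket each
    fractional texel coordinate, i.e. the floor/ceil combinations
    (y0, x0), (y0, x1), (y1, x0), (y1, x1), flattened into two long tensors.
    """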
    y_0 = torch.floor(values_y)
    y_1 = torch.ceil(values_y)
    x_0 = torch.floor(values_x)
    x_1 = torch.ceil(values_x)

    return torch.cat([y_0, y_0, y_1, y_1], 0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()

def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device):
"""
compose quad mask:
-> 0: background
-> 1: old
-> 2: update
-> 3: new
"""
    new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
    update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
    old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)

    all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor

    quad_mask_tensor = torch.zeros_like(all_mask_tensor)
    # later assignments take precedence where masks overlap: new > update > old
    quad_mask_tensor[old_mask_tensor == 1] = 1
    quad_mask_tensor[update_mask_tensor == 1] = 2
    quad_mask_tensor[new_mask_tensor == 1] = 3

    return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor

def compute_view_heat(similarity_tensor, quad_mask_tensor):
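    """
    Score a view as the area-weighted sum of its quad-mask regions, using
    the per-region weights QUAD_WEIGHTS from lib.constants.
    `similarity_tensor` is currently unused here.
    """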
    num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]

    heat = 0
    for idx in QUAD_WEIGHTS:
        heat += (quad_mask_tensor == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels

    return heat

def select_viewpoint(selected_view_ids, view_punishments,
    mode, dist_list, elev_list, azim_list, sector_list, view_idx,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, use_principle=False
):
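    """
    Pick the camera pose for the next painting step.

    "sequential" cycles through the candidate poses in order; "heuristic"
    scores every candidate with render_one_view_and_build_masks and keeps
    the one with the highest view heat (the first 6 principle views are
    forced when `use_principle` is set); "random" samples uniformly. In
    heuristic mode the chosen view's punishment factor is scaled down so it
    is unlikely to be picked again.
    """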
    if mode == "sequential":
        num_views = len(dist_list)

        dist = dist_list[view_idx % num_views]
        elev = elev_list[view_idx % num_views]
        azim = azim_list[view_idx % num_views]
        sector = sector_list[view_idx % num_views]

        selected_view_ids.append(view_idx % num_views)
    elif mode == "heuristic":
        if use_principle and view_idx < 6:
            # the first 6 principle views are always used as-is
            selected_view_idx = view_idx
        else:
            selected_view_idx = None
            max_heat = 0

            print("=> selecting next view...")
            view_heat_list = []
            for sample_idx in tqdm(range(len(dist_list))):
                view_heat, *_ = render_one_view_and_build_masks(
                    dist_list[sample_idx], elev_list[sample_idx], azim_list[sample_idx],
                    sample_idx, sample_idx, view_punishments,
                    similarity_texture_cache, exist_texture,
                    mesh, faces, verts_uvs,
                    image_size, faces_per_pixel,
                    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
                    device
                )

                if view_heat > max_heat:
                    selected_view_idx = sample_idx
                    max_heat = view_heat

                view_heat_list.append(view_heat.item())

            print(view_heat_list)
            print("select view {} with heat {}".format(selected_view_idx, max_heat))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

        # punish the selected view so it is unlikely to be chosen again
        view_punishments[selected_view_idx] *= 0.01
    elif mode == "random":
        selected_view_idx = random.choice(range(len(dist_list)))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)
    else:
        raise NotImplementedError()

    return dist, elev, azim, sector, selected_view_ids, view_punishments

@torch.no_grad()
def build_backproject_mask(mesh, faces, verts_uvs,
    cameras, reference_image, faces_per_pixel,
    image_size, uv_size, device):
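    """
    Rasterize the mesh from `cameras`, look up each visible pixel's UV
    coordinate, and splat the resized `reference_image` into a
    (uv_size, uv_size) texture. Returns a single channel, since the
    reference here (a similarity map) is grayscale replicated across RGB.
    """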
    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)

    # flip v to convert UV coordinates (origin at the bottom-left) into texture row indices
    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = faces_per_pixel

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size)))).float() / 255.
    # replicate the image values once per neighbor location (4) and per face per pixel (K)
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    # texture
    texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values.reshape(-1, 3)

    return texture_tensor[:, :, 0]

@torch.no_grad()
def build_diffusion_mask(mesh_stuff,
    renderer, exist_texture, similarity_texture_cache, target_value, device, image_size,
    smooth_mask=False, view_threshold=0.01):
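    """
    Render the per-view generation masks:
      - new mask: visible regions with no texture yet
      - update mask: textured regions this view sees better than any other
        view (per `similarity_texture_cache`)
      - old mask: textured regions to keep as-is
      - exist mask: all visible, already-textured regions
    Regions tilted away from the camera beyond `view_threshold` are treated
    as invisible. `smooth_mask` is currently unused.
    """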
    mesh, faces, verts_uvs = mesh_stuff
    mask_mesh = mesh.clone()  # NOTE clone the mesh so the texture swaps below don't mutate the input in place

    # visible mask => the whole region
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=torch.ones_like(exist_texture_expand),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    # visible_mask_tensor, *_ = render(mask_mesh, renderer)
    visible_mask_tensor, _, similarity_map_tensor, *_ = render(mask_mesh, renderer)
    # faces that are too rotated away from the viewpoint will be treated as invisible
    valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
    visible_mask_tensor *= valid_mask_tensor

    # nonexist mask <=> new mask
    mask_mesh.textures = TexturesUV(
        maps=1 - exist_texture_expand,
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    new_mask_tensor, *_ = render(mask_mesh, renderer)
    new_mask_tensor *= valid_mask_tensor

    # exist mask => visible mask - new mask
    exist_mask_tensor = visible_mask_tensor - new_mask_tensor
    exist_mask_tensor[exist_mask_tensor < 0] = 0  # NOTE dilate can lead to overflow

    # all update mask
    mask_mesh.textures = TexturesUV(
        maps=(
            similarity_texture_cache.argmax(0) == target_value
            # # only consider the views that have already appeared before
            # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
        ).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode="nearest"
    )
    all_update_mask_tensor, *_ = render(mask_mesh, renderer)

    # current update mask => intersection between all update mask and exist mask
    update_mask_tensor = exist_mask_tensor * all_update_mask_tensor

    # keep mask => exist mask - update mask
    old_mask_tensor = exist_mask_tensor - update_mask_tensor

    # convert to PIL images
    new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
    new_mask = transforms.ToPILImage()(new_mask).convert("L")

    update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
    update_mask = transforms.ToPILImage()(update_mask).convert("L")

    old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
    old_mask = transforms.ToPILImage()(old_mask).convert("L")

    exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
    exist_mask = transforms.ToPILImage()(exist_mask).convert("L")

    return new_mask, update_mask, old_mask, exist_mask

@torch.no_grad()
def render_one_view(mesh,
    dist, elev, azim,
    image_size, faces_per_pixel,
    device):
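    """
    Set up a camera at (dist, elev, azim) and render the mesh with a soft
    Phong shader, returning the rendered image, normal map, similarity map,
    depth map, and rasterization fragments.
    """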
    # render the view
    cameras = init_camera(
        dist, elev, azim,
        image_size, device
    )

    renderer = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )

    init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(mesh, renderer)

    return (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    )

@torch.no_grad()
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
    dist_list, elev_list, azim_list,
    image_size, image_size_scaled, uv_size, faces_per_pixel,
    device):
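    """
    Back-project every candidate view's similarity map into UV space and
    stack the results into a (num_views, uv_size, uv_size) tensor, so later
    steps can ask which view sees each texel best.
    """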
    num_candidate_views = len(dist_list)
    similarity_texture_cache = torch.zeros(num_candidate_views, uv_size, uv_size).to(device)

    print("=> building similarity texture cache for all views...")
    for i in tqdm(range(num_candidate_views)):
        cameras, _, _, _, similarity_tensor, _, _ = render_one_view(mesh,
            dist_list[i], elev_list[i], azim_list[i],
            image_size, faces_per_pixel, device)

        similarity_texture_cache[i] = build_backproject_mask(mesh, faces, verts_uvs,
            cameras, transforms.ToPILImage()(similarity_tensor[0, :, :, 0]).convert("RGB"), faces_per_pixel,
            image_size_scaled, uv_size, device)

    return similarity_texture_cache

@torch.no_grad()
def render_one_view_and_build_masks(dist, elev, azim,
    selected_view_idx, view_idx, view_punishments,
    similarity_texture_cache, exist_texture,
    mesh, faces, verts_uvs,
    image_size, faces_per_pixel,
    init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
    device, save_intermediate=False, smooth_mask=False, view_threshold=0.01):
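    """
    Render one view, build its new/update/old/exist masks, and score it with
    a punished view heat. `selected_view_idx` indexes the absolute sample
    space (matching `similarity_texture_cache`), while `view_idx` only names
    the output files. Intermediates are saved when `save_intermediate` is set.
    """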
    # render the view
    (
        cameras, renderer,
        init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments
    ) = render_one_view(mesh,
        dist, elev, azim,
        image_size, faces_per_pixel,
        device
    )

    init_image = init_images_tensor[0].cpu()
    init_image = init_image.permute(2, 0, 1)
    init_image = transforms.ToPILImage()(init_image).convert("RGB")

    normal_map = normal_maps_tensor[0].cpu()
    normal_map = normal_map.permute(2, 0, 1)
    normal_map = transforms.ToPILImage()(normal_map).convert("RGB")

    depth_map = depth_maps_tensor[0].cpu().numpy()
    depth_map = Image.fromarray(depth_map).convert("L")

    similarity_map = similarity_tensor[0, :, :, 0].cpu()
    similarity_map = transforms.ToPILImage()(similarity_map).convert("L")

    flat_renderer = init_renderer(cameras,
        shader=init_flat_texel_shader(
            camera=cameras,
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    # NOTE the view idx passed here is the absolute idx in the sample space
    # (i.e. `selected_view_idx`) - it has to match `similarity_texture_cache`
    new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
        (mesh, faces, verts_uvs),
        flat_renderer, exist_texture, similarity_texture_cache, selected_view_idx, device, image_size,
        smooth_mask=smooth_mask, view_threshold=view_threshold
    )

    (
        old_mask_tensor,
        update_mask_tensor,
        new_mask_tensor,
        all_mask_tensor,
        quad_mask_tensor
    ) = compose_quad_mask(new_mask_image, update_mask_image, old_mask_image, device)

    view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
    view_heat *= view_punishments[selected_view_idx]

    # save intermediate results
    if save_intermediate:
        init_image.save(os.path.join(init_image_dir, "{}.png".format(view_idx)))
        normal_map.save(os.path.join(normal_map_dir, "{}.png".format(view_idx)))
        depth_map.save(os.path.join(depth_map_dir, "{}.png".format(view_idx)))
        similarity_map.save(os.path.join(similarity_map_dir, "{}.png".format(view_idx)))

        new_mask_image.save(os.path.join(mask_image_dir, "{}_new.png".format(view_idx)))
        update_mask_image.save(os.path.join(mask_image_dir, "{}_update.png".format(view_idx)))
        old_mask_image.save(os.path.join(mask_image_dir, "{}_old.png".format(view_idx)))
        exist_mask_image.save(os.path.join(mask_image_dir, "{}_exist.png".format(view_idx)))

        visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_heat, device)

    return (
        view_heat,
        renderer, cameras, fragments,
        init_image, normal_map, depth_map,
        init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
        old_mask_image, update_mask_image, new_mask_image,
        old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
    )

@torch.no_grad()
def backproject_from_image(mesh, faces, verts_uvs, cameras,
    reference_image, new_mask_image, update_mask_image,
    init_texture, exist_texture,
    image_size, uv_size, faces_per_pixel,
    device):
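    """
    Project `reference_image` back onto the UV texture wherever the union of
    the new and update masks is active, and mark those texels as painted in
    `exist_texture`. Returns the updated texture, the projection mask, and
    the updated texel cache.
    """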
    # construct pixel UVs
    renderer_scaled = init_renderer(cameras,
        shader=init_soft_phong_shader(
            camera=cameras,
            blend_params=BlendParams(),
            device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel
    )
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(
        fragments_scaled.pix_to_face, fragments_scaled.bary_coords, faces_verts_uvs
    )  # NxHsxWsxKx2
    # -> (K, Hs, Ws, 2), assuming a batch size of 1
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(pixel_uvs.shape[-2], pixel_uvs.shape[1], pixel_uvs.shape[2], 2)

    # the update mask has to be on top of the diffusion mask
    new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(device).unsqueeze(-1)
    update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(device).unsqueeze(-1)

    project_mask_image_tensor = torch.logical_or(update_mask_image_tensor, new_mask_image_tensor).float()
    project_mask_image = project_mask_image_tensor * 255.
    project_mask_image = Image.fromarray(project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))

    project_mask_image_scaled = project_mask_image.resize(
        (image_size, image_size),
        Image.Resampling.NEAREST
    )
    project_mask_image_tensor_scaled = transforms.ToTensor()(project_mask_image_scaled).to(device)

    # NOTE the boolean indexing below assumes a single face per pixel (K == 1)
    pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1)
    )

    K = pixel_uvs.shape[0]
    project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:, None, :, :, None].repeat(1, 4, 1, 1, 3)

    texture_values = torch.from_numpy(np.array(reference_image.resize((image_size, image_size))))
    texture_values = texture_values.to(device).unsqueeze(0).expand([4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    texture_values_masked = texture_values.reshape(-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(-1, 3)

    # texture
    texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
    texture_tensor[texture_locations_y, texture_locations_x, :] = texture_values_masked

    init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(np.uint8))

    # update texture cache
    exist_texture[texture_locations_y, texture_locations_x] = 1

    return init_texture, project_mask_image, exist_texture
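
# A rough driver sketch (hypothetical variable names; the actual loop lives in
# the project's generation script): bake the similarity cache once, then
# alternate viewpoint selection, mask building, inpainting, and back-projection.
#
#   similarity_texture_cache = build_similarity_texture_cache_for_all_views(
#       mesh, faces, verts_uvs, dist_list, elev_list, azim_list,
#       image_size, image_size_scaled, uv_size, faces_per_pixel, device)
#   exist_texture = torch.zeros(uv_size, uv_size, device=device)
#   for view_idx in range(num_steps):
#       dist, elev, azim, sector, selected_view_ids, view_punishments = select_viewpoint(...)
#       view_heat, renderer, cameras, fragments, *rest = render_one_view_and_build_masks(...)
#       # ...inpaint the rendered view using the masks, then:
#       init_texture, project_mask_image, exist_texture = backproject_from_image(...)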