# common utils
import os
import argparse
import time

# pytorch3d
from pytorch3d.renderer import TexturesUV

# torch
import torch
from torchvision import transforms

# numpy
import numpy as np

# image
from PIL import Image

# customized
import sys
# sys.path.append(".")
sys.path.append("./text2tex")

from lib.mesh_helper import (
    init_mesh,
    apply_offsets_to_mesh,
    adjust_uv_map
)
from lib.render_helper import render
from lib.io_helper import (
    save_backproject_obj,
    save_args,
    save_viewpoints
)
from lib.vis_helper import (
    visualize_outputs,
    visualize_principle_viewpoints,
    visualize_refinement_viewpoints
)
from lib.diffusion_helper import (
    get_controlnet_depth,
    get_inpainting,
    apply_controlnet_depth,
    apply_inpainting_postprocess
)
from lib.projection_helper import (
    backproject_from_image,
    render_one_view_and_build_masks,
    select_viewpoint,
    build_similarity_texture_cache_for_all_views
)
from lib.camera_helper import init_viewpoints
# Setup
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
    torch.cuda.set_device(DEVICE)
else:
    print("no GPU available")
    exit()
""" | |
Use Diffusion Models conditioned on depth input to back-project textures on 3D mesh. | |
The inputs should be constructed as follows: | |
- <input_dir>/ | |
|- <obj_file> # name of the input OBJ file | |
The outputs of this script would be stored under `outputs/`, with the | |
configuration parameters as the folder name. Specifically, there should be following files in such | |
folder: | |
- outputs/ | |
|- <configs>/ # configurations of the run | |
|- generate/ # assets generated in generation stage | |
|- depth/ # depth map | |
|- inpainted/ # images generated by diffusion models | |
|- intermediate/ # renderings of textured mesh after each step | |
|- mask/ # generation mask | |
|- mesh/ # textured mesh | |
|- normal/ # normal map | |
|- rendering/ # input renderings | |
|- similarity/ # simiarity map | |
|- update/ # assets generated in refinement stage | |
|- ... # the structure is the same as generate/ | |
|- args.json # all arguments for the run | |
|- viewpoints.json # all viewpoints | |
|- principle_viewpoints.png # principle viewpoints | |
|- refinement_viewpoints.png # refinement viewpoints | |
""" | |
def init_args():
    print("=> initializing input arguments...")
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, default="./inputs")
    parser.add_argument("--output_dir", type=str, default="./outputs")
    parser.add_argument("--obj_name", type=str, default="mesh")
    parser.add_argument("--obj_file", type=str, default="mesh.obj")
    parser.add_argument("--prompt", type=str, default="a 3D object")
    parser.add_argument("--a_prompt", type=str, default="best quality, high quality, extremely detailed, good geometry")
    parser.add_argument("--n_prompt", type=str, default="deformed, extra digit, fewer digits, cropped, worst quality, low quality, smoke")
    parser.add_argument("--new_strength", type=float, default=1)
    parser.add_argument("--update_strength", type=float, default=0.5)
    parser.add_argument("--ddim_steps", type=int, default=20)
    parser.add_argument("--guidance_scale", type=float, default=10)
    parser.add_argument("--output_scale", type=float, default=1)
    parser.add_argument("--view_threshold", type=float, default=0.1)
    parser.add_argument("--num_viewpoints", type=int, default=8)
    parser.add_argument("--viewpoint_mode", type=str, default="predefined", choices=["predefined", "hemisphere"])
    parser.add_argument("--update_steps", type=int, default=8)
    parser.add_argument("--update_mode", type=str, default="heuristic", choices=["sequential", "heuristic", "random"])
    parser.add_argument("--blend", type=float, default=0.5)
    parser.add_argument("--eta", type=float, default=0.0)
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--use_patch", action="store_true", help="apply repaint during refinement to patch up the missing regions")
    parser.add_argument("--use_multiple_objects", action="store_true", help="operate on multiple objects")
    parser.add_argument("--use_principle", action="store_true", help="use the principle viewpoints for generation")
    parser.add_argument("--use_shapenet", action="store_true", help="operate on ShapeNet objects")
    parser.add_argument("--use_objaverse", action="store_true", help="operate on Objaverse objects")
    parser.add_argument("--use_unnormalized", action="store_true", help="save the unnormalized mesh")

    parser.add_argument("--add_view_to_prompt", action="store_true", help="add view information to the prompt")
    parser.add_argument("--post_process", action="store_true", help="post-process the texture")
    parser.add_argument("--smooth_mask", action="store_true", help="smooth the diffusion mask")
    parser.add_argument("--force", action="store_true", help="forcefully generate more images")

    # negative options
    parser.add_argument("--no_repaint", action="store_true", help="do NOT apply repaint")
    parser.add_argument("--no_update", action="store_true", help="do NOT apply update")

    # device parameters
    parser.add_argument("--device", type=str, choices=["a6000", "2080"], default="a6000")

    # camera parameters NOTE need careful tuning!!!
    parser.add_argument("--test_camera", action="store_true")
    parser.add_argument("--dist", type=float, default=1,
        help="distance from the camera to the object")
    parser.add_argument("--elev", type=float, default=0,
        help="angle between the vector from the object to the camera and the horizontal plane")
    parser.add_argument("--azim", type=float, default=180,
        help="angle between the vector from the object to the camera and the vertical plane")
    parser.add_argument("--uv_size", type=int, default=1000)
    parser.add_argument("--image_size", type=int, default=768)

    args = parser.parse_args()

    if args.device == "a6000":
        # NOTE this preset overrides any --image_size / --uv_size given on the command line
        setattr(args, "render_simple_factor", 12)
        setattr(args, "fragment_k", 1)
        setattr(args, "image_size", 768)
        setattr(args, "uv_size", 3000)
    else:
        setattr(args, "render_simple_factor", 4)
        setattr(args, "fragment_k", 1)
        setattr(args, "image_size", args.image_size)
        setattr(args, "uv_size", args.uv_size)

    return args
def text2tex_call(args):
    tik = time.time()

    # name the output directory after the run configuration
    output_dir = os.path.join(
        args.output_dir,
        "{}-{}-{}-{}-{}-{}-{}-{}-{}".format(
            str(args.seed),
            args.viewpoint_mode[0] + str(args.num_viewpoints),
            args.update_mode[0] + str(args.update_steps),
            'd' + str(args.ddim_steps),
            str(args.new_strength),
            str(args.update_strength),
            str(args.view_threshold),
            'uv' + str(args.uv_size),
            'img' + str(args.image_size),
        ),
    )

    if args.no_repaint: output_dir += "-norepaint"
    if args.no_update: output_dir += "-noupdate"

    os.makedirs(output_dir, exist_ok=True)
    print("=> OUTPUT_DIR:", output_dir)
    # init resources
    # init mesh
    mesh, _, faces, aux, principle_directions, mesh_center, mesh_scale = init_mesh(
        os.path.join(args.input_dir, args.obj_file),
        os.path.join(output_dir, args.obj_file),
        DEVICE
    )

    # gradient texture
    init_texture = Image.open("./text2tex/samples/textures/dummy.png").convert("RGB").resize((args.uv_size, args.uv_size))

    # HACK adjust UVs for multiple materials
    if args.use_multiple_objects:
        new_verts_uvs, init_texture = adjust_uv_map(faces, aux, init_texture, args.uv_size)
    else:
        new_verts_uvs = aux.verts_uvs

    # update the mesh
    mesh.textures = TexturesUV(
        maps=transforms.ToTensor()(init_texture)[None, ...].permute(0, 2, 3, 1).to(DEVICE),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=new_verts_uvs[None, ...]
    )

    # back-projected faces
    exist_texture = torch.from_numpy(np.zeros([args.uv_size, args.uv_size]).astype(np.float32)).to(DEVICE)
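    # NOTE exist_texture is a UV-space coverage mask: a texel is set to 1 once some
    # view has back-projected color onto it, so later steps can tell painted texels
    # from untouched ones (it is also saved as *_texture_mask.png below)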
    # initialize viewpoints
    # including: principle viewpoints for generation + refinement viewpoints for updating
    (
        dist_list,
        elev_list,
        azim_list,
        sector_list,
        view_punishments
    ) = init_viewpoints(args.viewpoint_mode, args.num_viewpoints, args.dist, args.elev, principle_directions,
                        use_principle=True,
                        use_shapenet=args.use_shapenet,
                        use_objaverse=args.use_objaverse)

    # save args
    save_args(args, output_dir)

    # initialize depth2image model
    controlnet, ddim_sampler = get_controlnet_depth()
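    # NOTE get_controlnet_depth (see lib/diffusion_helper.py) is assumed to return a
    # depth-conditioned ControlNet model plus the DDIM sampler that drives denoising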
    # ------------------- OPERATION ZONE BELOW ------------------------
    # 1. generate texture with RePaint
    # NOTE no update / refinement
    generate_dir = os.path.join(output_dir, "generate")
    os.makedirs(generate_dir, exist_ok=True)
    update_dir = os.path.join(output_dir, "update")
    os.makedirs(update_dir, exist_ok=True)

    init_image_dir = os.path.join(generate_dir, "rendering")
    os.makedirs(init_image_dir, exist_ok=True)
    normal_map_dir = os.path.join(generate_dir, "normal")
    os.makedirs(normal_map_dir, exist_ok=True)
    mask_image_dir = os.path.join(generate_dir, "mask")
    os.makedirs(mask_image_dir, exist_ok=True)
    depth_map_dir = os.path.join(generate_dir, "depth")
    os.makedirs(depth_map_dir, exist_ok=True)
    similarity_map_dir = os.path.join(generate_dir, "similarity")
    os.makedirs(similarity_map_dir, exist_ok=True)
    inpainted_image_dir = os.path.join(generate_dir, "inpainted")
    os.makedirs(inpainted_image_dir, exist_ok=True)
    mesh_dir = os.path.join(generate_dir, "mesh")
    os.makedirs(mesh_dir, exist_ok=True)
    interm_dir = os.path.join(generate_dir, "intermediate")
    os.makedirs(interm_dir, exist_ok=True)
    # prepare viewpoints and cache
    NUM_PRINCIPLE = 10 if args.use_shapenet or args.use_objaverse else 6
    pre_dist_list = dist_list[:NUM_PRINCIPLE]
    pre_elev_list = elev_list[:NUM_PRINCIPLE]
    pre_azim_list = azim_list[:NUM_PRINCIPLE]
    pre_sector_list = sector_list[:NUM_PRINCIPLE]
    pre_view_punishments = view_punishments[:NUM_PRINCIPLE]

    pre_similarity_texture_cache = build_similarity_texture_cache_for_all_views(mesh, faces, new_verts_uvs,
        pre_dist_list, pre_elev_list, pre_azim_list,
        args.image_size, args.image_size * args.render_simple_factor, args.uv_size, args.fragment_k,
        DEVICE
    )
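    # NOTE the similarity cache stores, per candidate view, a UV-space map of how
    # directly each texel faces that view; the mask building and the heuristic view
    # selection below consult it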
    # start generation
    print("=> start generating texture...")
    start_time = time.time()
    for view_idx in range(NUM_PRINCIPLE):
        print("=> processing view {}...".format(view_idx))

        # sequentially pop the viewpoints
        dist, elev, azim, sector = pre_dist_list[view_idx], pre_elev_list[view_idx], pre_azim_list[view_idx], pre_sector_list[view_idx]
        prompt = " the {} view of {}".format(sector, args.prompt) if args.add_view_to_prompt else args.prompt
        print("=> generating image for prompt: {}...".format(prompt))

        # 1.1. render and build masks
        (
            view_score,
            renderer, cameras, fragments,
            init_image, normal_map, depth_map,
            init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
            keep_mask_image, update_mask_image, generate_mask_image,
            keep_mask_tensor, update_mask_tensor, generate_mask_tensor, all_mask_tensor, quad_mask_tensor,
        ) = render_one_view_and_build_masks(dist, elev, azim,
            view_idx, view_idx, view_punishments, # => actual view idx and the sequence idx
            pre_similarity_texture_cache, exist_texture,
            mesh, faces, new_verts_uvs,
            args.image_size, args.fragment_k,
            init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
            DEVICE, save_intermediate=True, smooth_mask=args.smooth_mask, view_threshold=args.view_threshold
        )
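        # NOTE the three masks partition the rendered view: generate_mask covers
        # texels no view has painted yet, update_mask covers texels painted before
        # but now seen at a better angle (view_threshold decides), and keep_mask
        # covers texels whose existing texture should be left alone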
        # 1.2. generate missing region
        # NOTE first view still gets the mask for consistent ablations
        if args.no_repaint and view_idx != 0:
            actual_generate_mask_image = Image.fromarray((np.ones_like(np.array(generate_mask_image)) * 255.).astype(np.uint8))
        else:
            actual_generate_mask_image = generate_mask_image

        print("=> generate for view {}".format(view_idx))
        generate_image, generate_image_before, generate_image_after = apply_controlnet_depth(controlnet, ddim_sampler,
            init_image.convert("RGBA"), prompt, args.new_strength, args.ddim_steps,
            actual_generate_mask_image, keep_mask_image, depth_maps_tensor.permute(1, 2, 0).repeat(1, 1, 3).cpu().numpy(),
            args.a_prompt, args.n_prompt, args.guidance_scale, args.seed, args.eta, 1, DEVICE, args.blend)

        generate_image.save(os.path.join(inpainted_image_dir, "{}.png".format(view_idx)))
        generate_image_before.save(os.path.join(inpainted_image_dir, "{}_before.png".format(view_idx)))
        generate_image_after.save(os.path.join(inpainted_image_dir, "{}_after.png".format(view_idx)))
        # 1.2.2 back-project and create texture
        # NOTE projection mask = generate mask
        init_texture, project_mask_image, exist_texture = backproject_from_image(
            mesh, faces, new_verts_uvs, cameras,
            generate_image, generate_mask_image, generate_mask_image, init_texture, exist_texture,
            args.image_size * args.render_simple_factor, args.uv_size, args.fragment_k,
            DEVICE
        )

        project_mask_image.save(os.path.join(mask_image_dir, "{}_project.png".format(view_idx)))

        # update the mesh
        mesh.textures = TexturesUV(
            maps=transforms.ToTensor()(init_texture)[None, ...].permute(0, 2, 3, 1).to(DEVICE),
            faces_uvs=faces.textures_idx[None, ...],
            verts_uvs=new_verts_uvs[None, ...]
        )
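        # NOTE TexturesUV expects maps of shape (N, H, W, C) while ToTensor yields
        # (C, H, W) -- hence the [None, ...] and the permute; rebuilding the textures
        # here is what lets the next render see the freshly back-projected texels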
        # 1.2.3. re-render
        # NOTE only the rendered image is needed - masks should be re-used
        (
            view_score,
            renderer, cameras, fragments,
            init_image, *_,
        ) = render_one_view_and_build_masks(dist, elev, azim,
            view_idx, view_idx, view_punishments, # => actual view idx and the sequence idx
            pre_similarity_texture_cache, exist_texture,
            mesh, faces, new_verts_uvs,
            args.image_size, args.fragment_k,
            init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
            DEVICE, save_intermediate=False, smooth_mask=args.smooth_mask, view_threshold=args.view_threshold
        )
        # 1.3. update blurry region
        # only when: 1) the update flag is set; 2) there is content to update; 3) there is enough context
        if not args.no_update and update_mask_tensor.sum() > 0 and update_mask_tensor.sum() / (all_mask_tensor.sum()) > 0.05:
            print("=> update {} pixels for view {}".format(update_mask_tensor.sum().int(), view_idx))
            diffused_image, diffused_image_before, diffused_image_after = apply_controlnet_depth(controlnet, ddim_sampler,
                init_image.convert("RGBA"), prompt, args.update_strength, args.ddim_steps,
                update_mask_image, keep_mask_image, depth_maps_tensor.permute(1, 2, 0).repeat(1, 1, 3).cpu().numpy(),
                args.a_prompt, args.n_prompt, args.guidance_scale, args.seed, args.eta, 1, DEVICE, args.blend)

            diffused_image.save(os.path.join(inpainted_image_dir, "{}_update.png".format(view_idx)))
            diffused_image_before.save(os.path.join(inpainted_image_dir, "{}_update_before.png".format(view_idx)))
            diffused_image_after.save(os.path.join(inpainted_image_dir, "{}_update_after.png".format(view_idx)))
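            # NOTE the update pass runs at update_strength (default 0.5) rather than
            # new_strength (default 1.0), so existing content is refined in place
            # instead of being repainted from scratch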
            # 1.3.2. back-project and create texture
            # NOTE projection mask = update mask
            init_texture, project_mask_image, exist_texture = backproject_from_image(
                mesh, faces, new_verts_uvs, cameras,
                diffused_image, update_mask_image, update_mask_image, init_texture, exist_texture,
                args.image_size * args.render_simple_factor, args.uv_size, args.fragment_k,
                DEVICE
            )

            # update the mesh
            mesh.textures = TexturesUV(
                maps=transforms.ToTensor()(init_texture)[None, ...].permute(0, 2, 3, 1).to(DEVICE),
                faces_uvs=faces.textures_idx[None, ...],
                verts_uvs=new_verts_uvs[None, ...]
            )
        # 1.4. save generated assets
        # save backprojected OBJ file
        save_backproject_obj(
            mesh_dir, "{}.obj".format(view_idx),
            mesh_scale * mesh.verts_packed() + mesh_center if args.use_unnormalized else mesh.verts_packed(),
            faces.verts_idx, new_verts_uvs, faces.textures_idx, init_texture,
            DEVICE
        )

        # save the intermediate view
        inter_images_tensor, *_ = render(mesh, renderer)
        inter_image = inter_images_tensor[0].cpu()
        inter_image = inter_image.permute(2, 0, 1)
        inter_image = transforms.ToPILImage()(inter_image).convert("RGB")
        inter_image.save(os.path.join(interm_dir, "{}.png".format(view_idx)))

        # save texture mask
        exist_texture_image = exist_texture * 255.
        exist_texture_image = Image.fromarray(exist_texture_image.cpu().numpy().astype(np.uint8)).convert("L")
        exist_texture_image.save(os.path.join(mesh_dir, "{}_texture_mask.png".format(view_idx)))

    print("=> total generate time: {} s".format(time.time() - start_time))
    # visualize viewpoints
    visualize_principle_viewpoints(output_dir, pre_dist_list, pre_elev_list, pre_azim_list)
    # 2. update texture with RePaint
    selected_view_ids = []  # initialized here so the bookkeeping below also works when update_steps == 0
    if args.update_steps > 0:
        update_dir = os.path.join(output_dir, "update")
        os.makedirs(update_dir, exist_ok=True)

        init_image_dir = os.path.join(update_dir, "rendering")
        os.makedirs(init_image_dir, exist_ok=True)
        normal_map_dir = os.path.join(update_dir, "normal")
        os.makedirs(normal_map_dir, exist_ok=True)
        mask_image_dir = os.path.join(update_dir, "mask")
        os.makedirs(mask_image_dir, exist_ok=True)
        depth_map_dir = os.path.join(update_dir, "depth")
        os.makedirs(depth_map_dir, exist_ok=True)
        similarity_map_dir = os.path.join(update_dir, "similarity")
        os.makedirs(similarity_map_dir, exist_ok=True)
        inpainted_image_dir = os.path.join(update_dir, "inpainted")
        os.makedirs(inpainted_image_dir, exist_ok=True)
        mesh_dir = os.path.join(update_dir, "mesh")
        os.makedirs(mesh_dir, exist_ok=True)
        interm_dir = os.path.join(update_dir, "intermediate")
        os.makedirs(interm_dir, exist_ok=True)
        dist_list = dist_list[NUM_PRINCIPLE:]
        elev_list = elev_list[NUM_PRINCIPLE:]
        azim_list = azim_list[NUM_PRINCIPLE:]
        sector_list = sector_list[NUM_PRINCIPLE:]
        view_punishments = view_punishments[NUM_PRINCIPLE:]

        similarity_texture_cache = build_similarity_texture_cache_for_all_views(mesh, faces, new_verts_uvs,
            dist_list, elev_list, azim_list,
            args.image_size, args.image_size * args.render_simple_factor, args.uv_size, args.fragment_k,
            DEVICE
        )
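        # NOTE the first NUM_PRINCIPLE entries were consumed by the generation stage;
        # the remaining viewpoints form the candidate pool that select_viewpoint
        # draws from on every refinement step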
        print("=> start updating...")
        start_time = time.time()
        for view_idx in range(args.update_steps):
            print("=> processing view {}...".format(view_idx))

            # 2.1. render and build masks
            # heuristically select the viewpoints
            dist, elev, azim, sector, selected_view_ids, view_punishments = select_viewpoint(
                selected_view_ids, view_punishments,
                args.update_mode, dist_list, elev_list, azim_list, sector_list, view_idx,
                similarity_texture_cache, exist_texture,
                mesh, faces, new_verts_uvs,
                args.image_size, args.fragment_k,
                init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
                DEVICE, False
            )
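            # NOTE with --update_mode heuristic this is assumed to score the candidate
            # views against the similarity cache and the current texture coverage and
            # pick the most promising one; sequential and random are simple baselines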
            (
                view_score,
                renderer, cameras, fragments,
                init_image, normal_map, depth_map,
                init_images_tensor, normal_maps_tensor, depth_maps_tensor, similarity_tensor,
                old_mask_image, update_mask_image, generate_mask_image,
                old_mask_tensor, update_mask_tensor, generate_mask_tensor, all_mask_tensor, quad_mask_tensor,
            ) = render_one_view_and_build_masks(dist, elev, azim,
                selected_view_ids[-1], view_idx, view_punishments, # => actual view idx and the sequence idx
                similarity_texture_cache, exist_texture,
                mesh, faces, new_verts_uvs,
                args.image_size, args.fragment_k,
                init_image_dir, mask_image_dir, normal_map_dir, depth_map_dir, similarity_map_dir,
                DEVICE, save_intermediate=True, smooth_mask=args.smooth_mask, view_threshold=args.view_threshold
            )
            # # -------------------- OPTION ZONE ------------------------
            # # still generate for missing regions during refinement
            # # NOTE this could take significantly more time to complete.
            # if args.use_patch:
            #     # 2.2.1 generate missing region
            #     prompt = " the {} view of {}".format(sector, args.prompt) if args.add_view_to_prompt else args.prompt
            #     print("=> generating image for prompt: {}...".format(prompt))
            #     if args.no_repaint:
            #         generate_mask_image = Image.fromarray((np.ones_like(np.array(generate_mask_image)) * 255.).astype(np.uint8))
            #     print("=> generate {} pixels for view {}".format(generate_mask_tensor.sum().int(), view_idx))
            #     generate_image, generate_image_before, generate_image_after = apply_controlnet_depth(controlnet, ddim_sampler,
            #         init_image.convert("RGBA"), prompt, args.new_strength, args.ddim_steps,
            #         generate_mask_image, keep_mask_image, depth_maps_tensor.permute(1, 2, 0).repeat(1, 1, 3).cpu().numpy(),
            #         args.a_prompt, args.n_prompt, args.guidance_scale, args.seed, args.eta, 1, DEVICE, args.blend)
            #     generate_image.save(os.path.join(inpainted_image_dir, "{}_new.png".format(view_idx)))
            #     generate_image_before.save(os.path.join(inpainted_image_dir, "{}_new_before.png".format(view_idx)))
            #     generate_image_after.save(os.path.join(inpainted_image_dir, "{}_new_after.png".format(view_idx)))
            #     # 2.2.2. back-project and create texture
            #     # NOTE projection mask = generate mask
            #     init_texture, project_mask_image, exist_texture = backproject_from_image(
            #         mesh, faces, new_verts_uvs, cameras,
            #         generate_image, generate_mask_image, generate_mask_image, init_texture, exist_texture,
            #         args.image_size * args.render_simple_factor, args.uv_size, args.fragment_k,
            #         DEVICE
            #     )
            #     project_mask_image.save(os.path.join(mask_image_dir, "{}_new_project.png".format(view_idx)))
            #     # update the mesh
            #     mesh.textures = TexturesUV(
            #         maps=transforms.ToTensor()(init_texture)[None, ...].permute(0, 2, 3, 1).to(DEVICE),
            #         faces_uvs=faces.textures_idx[None, ...],
            #         verts_uvs=new_verts_uvs[None, ...]
            #     )
            #     # 2.2.4. save generated assets
            #     # save backprojected OBJ file
            #     save_backproject_obj(
            #         mesh_dir, "{}_new.obj".format(view_idx),
            #         mesh.verts_packed(), faces.verts_idx, new_verts_uvs, faces.textures_idx, init_texture,
            #         DEVICE
            #     )
            # # -------------------- OPTION ZONE ------------------------
            # 2.2. update existing region
            prompt = " the {} view of {}".format(sector, args.prompt) if args.add_view_to_prompt else args.prompt
            print("=> updating image for prompt: {}...".format(prompt))

            if not args.no_update and update_mask_tensor.sum() > 0 and update_mask_tensor.sum() / (all_mask_tensor.sum()) > 0.05:
                print("=> update {} pixels for view {}".format(update_mask_tensor.sum().int(), view_idx))
                update_image, update_image_before, update_image_after = apply_controlnet_depth(controlnet, ddim_sampler,
                    init_image.convert("RGBA"), prompt, args.update_strength, args.ddim_steps,
                    update_mask_image, old_mask_image, depth_maps_tensor.permute(1, 2, 0).repeat(1, 1, 3).cpu().numpy(),
                    args.a_prompt, args.n_prompt, args.guidance_scale, args.seed, args.eta, 1, DEVICE, args.blend)

                update_image.save(os.path.join(inpainted_image_dir, "{}.png".format(view_idx)))
                update_image_before.save(os.path.join(inpainted_image_dir, "{}_before.png".format(view_idx)))
                update_image_after.save(os.path.join(inpainted_image_dir, "{}_after.png".format(view_idx)))
            else:
                print("=> nothing to update for view {}".format(view_idx))
                update_image = init_image

                old_mask_tensor += update_mask_tensor
                update_mask_tensor[update_mask_tensor == 1] = 0 # HACK nothing to update

                old_mask_image = transforms.ToPILImage()(old_mask_tensor)
                update_mask_image = transforms.ToPILImage()(update_mask_tensor)
            # 2.3. back-project and create texture
            # NOTE projection mask = update mask
            init_texture, project_mask_image, exist_texture = backproject_from_image(
                mesh, faces, new_verts_uvs, cameras,
                update_image, update_mask_image, update_mask_image, init_texture, exist_texture,
                args.image_size * args.render_simple_factor, args.uv_size, args.fragment_k,
                DEVICE
            )

            project_mask_image.save(os.path.join(mask_image_dir, "{}_project.png".format(view_idx)))

            # update the mesh
            mesh.textures = TexturesUV(
                maps=transforms.ToTensor()(init_texture)[None, ...].permute(0, 2, 3, 1).to(DEVICE),
                faces_uvs=faces.textures_idx[None, ...],
                verts_uvs=new_verts_uvs[None, ...]
            )
            # 2.4. save generated assets
            # save backprojected OBJ file
            save_backproject_obj(
                mesh_dir, "{}.obj".format(view_idx),
                mesh_scale * mesh.verts_packed() + mesh_center if args.use_unnormalized else mesh.verts_packed(),
                faces.verts_idx, new_verts_uvs, faces.textures_idx, init_texture,
                DEVICE
            )

            # save the intermediate view
            inter_images_tensor, *_ = render(mesh, renderer)
            inter_image = inter_images_tensor[0].cpu()
            inter_image = inter_image.permute(2, 0, 1)
            inter_image = transforms.ToPILImage()(inter_image).convert("RGB")
            inter_image.save(os.path.join(interm_dir, "{}.png".format(view_idx)))

            # save texture mask
            exist_texture_image = exist_texture * 255.
            exist_texture_image = Image.fromarray(exist_texture_image.cpu().numpy().astype(np.uint8)).convert("L")
            exist_texture_image.save(os.path.join(mesh_dir, "{}_texture_mask.png".format(view_idx)))

        print("=> total update time: {} s".format(time.time() - start_time))
    # post-process
    if args.post_process:
        del controlnet
        del ddim_sampler

        inpainting = get_inpainting(DEVICE)
        post_texture = apply_inpainting_postprocess(inpainting,
            init_texture, 1 - exist_texture[None, :, :, None], "", args.uv_size, args.uv_size, DEVICE)

        save_backproject_obj(
            mesh_dir, "{}_post.obj".format(view_idx),
            mesh_scale * mesh.verts_packed() + mesh_center if args.use_unnormalized else mesh.verts_packed(),
            faces.verts_idx, new_verts_uvs, faces.textures_idx, post_texture,
            DEVICE
        )
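        # NOTE the mask passed above is 1 - exist_texture, i.e. the post-process
        # inpaints only the UV texels that no viewpoint ever covered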
    # save viewpoints
    save_viewpoints(args, output_dir, dist_list, elev_list, azim_list, selected_view_ids)

    # visualize viewpoints
    visualize_refinement_viewpoints(output_dir, selected_view_ids, dist_list, elev_list, azim_list)

    # report the total time used and save it to the output directory
    print("=> total time used: {} s".format(time.time() - tik))
    with open(os.path.join(output_dir, "time.txt"), "w") as f:
        f.write("total time used: {} s".format(time.time() - tik))

    return output_dir
if __name__ == "__main__":
    args = init_args()
    text2tex_call(args)