Spaces:
mashroo
/
Runtime error

CRM / inference.py
YoussefAnso's picture
Enhance color processing in generate3d function by applying clipping and gamma correction
bd713a9
raw
history blame
3.96 kB
import numpy as np
import torch
import time
from util.utils import get_tri
import tempfile
from util.renderer import Renderer
import os
from PIL import Image
import trimesh
from scipy.spatial import cKDTree
def _split_views(images, tile=256):
    """Cut a horizontally tiled image strip into a stack of per-view tiles.

    Args:
        images: tensor indexed as ``images[:, :, x]``, i.e. (C, H, W) with the
            views laid side by side along the last axis, each ``tile`` pixels
            wide (at least six tiles are read).
        tile: width in pixels of one view tile.

    Returns:
        Tensor of shape (6, C, H, tile). Tile 5 comes first, followed by tiles
        0-4 — the ordering the rest of this module was written against.
    """
    order = (5, 0, 1, 2, 3, 4)
    return torch.stack(
        [images[:, :, tile * i:tile * (i + 1)] for i in order], dim=0
    )


def _build_triplanes(rgb, ccm, device):
    """Turn tiled colour / coordinate-map images into network inputs.

    Args:
        rgb: channel-last numpy array of tiled colour views; scaled by 1/255.
        ccm: channel-last numpy array of tiled coordinate maps in the same
            layout; its channel order is reversed (2, 1, 0) and scaled by 1/255.
        device: torch device the returned tensors are moved to.

    Returns:
        (triplane, triplane_color):
        - triplane: colour and xyz triplanes from ``get_tri`` concatenated on
          dim 1, with a leading batch dim.
        - triplane_color: per-view colour images permuted to channel-last
          with a leading batch dim, i.e. (1, 6, H, 256, C).
    """
    color_views = _split_views((torch.from_numpy(rgb) / 255).permute(2, 0, 1))
    xyz_views = _split_views(
        (torch.from_numpy(ccm[:, :, (2, 1, 0)]) / 255).permute(2, 0, 1)
    )

    triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)

    color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
    xyz_tri = get_tri(
        xyz_views, dim=0, blender=True, scale=1, fix=True
    ).unsqueeze(0)
    triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)
    return triplane, triplane_color


def _remesh(data_config, device):
    """Replace ``data_config['verts']`` / ``['faces']`` with a remeshed copy.

    Runs kiui's ``clean_mesh`` (no repair, one remesh iteration at target edge
    size 0.01) on CPU/numpy and stores the result back on ``device``.
    Note that remeshing changes the vertex count, so any per-vertex attribute
    computed for the pre-remesh vertices no longer lines up afterwards.
    """
    from kiui.mesh_utils import clean_mesh

    verts, faces = clean_mesh(
        data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
        data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
        repair=False, remesh=True, remesh_size=0.01, remesh_iters=1,
    )
    data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
    data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()


def generate3d(model, rgb, ccm, device):
    """Reconstruct a mesh from tiled multi-view images and export it as OBJ.

    Args:
        model: model object providing ``tet_grid_size``, ``camera_angle_num``,
            ``input.scale``, ``geo_type``, ``denoising``, ``scheduler``,
            ``unet2``, ``decode`` and ``export_mesh``. ``model.renderer`` is
            replaced on every call.
        rgb: channel-last numpy array of colour views tiled horizontally in
            256-pixel-wide tiles (see ``_build_triplanes``).
        ccm: channel-last numpy array of the matching coordinate maps.
        device: torch device used for inference.

    Returns:
        Path of a ``.obj`` file created in the system temp directory. The file
        is not deleted automatically; the caller owns it.
    """
    model.renderer = Renderer(
        tet_grid_size=model.tet_grid_size,
        camera_angle_num=model.camera_angle_num,
        scale=model.input.scale,
        geo_type=model.geo_type,
    )

    triplane, triplane_color = _build_triplanes(rgb, ccm, device)

    model.eval()
    # All network evaluation is inference-only; keeping it under one no_grad
    # also covers the non-denoising unet2 call, which previously built an
    # autograd graph for nothing.
    with torch.no_grad():
        if model.denoising:
            # Fixed timestep 20. randint is kept (rather than torch.full) so
            # the RNG stream - and therefore the noise below - is unchanged.
            tnew = 20
            tnew = torch.randint(
                tnew, tnew + 1, [triplane.shape[0]],
                dtype=torch.long, device=triplane.device,
            )
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = model.scheduler.add_noise(triplane, noise_new, tnew)
            triplane_feature2 = model.unet2(triplane, tnew)
        else:
            triplane_feature2 = model.unet2(triplane)

        data_config = {
            'resolution': [1024, 1024],
            "triview_color": triplane_color,
        }
        verts, faces = model.decode(data_config, triplane_feature2)

    data_config['verts'] = verts[0]
    data_config['faces'] = faces
    _remesh(data_config, device)

    # NOTE: the exported mesh comes entirely from model.export_mesh below. If
    # per-vertex colours are ever needed, evaluate them at the *remeshed*
    # vertices in data_config['verts'], since _remesh changes the vertex count.

    # Reserve a unique .obj path; close the handle so export_mesh can write it.
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as tmp:
        obj_path = tmp.name
    base_path = obj_path[:-4]  # export_mesh takes the path without ".obj"
    # Presumably writes <base_path>.obj (and any material/texture sidecars)
    # - confirm against model.export_mesh.
    model.export_mesh(data_config, base_path, tri_fea_2=triplane_feature2)
    return obj_path