import os
import time
import glob
import json
import yaml
import torch
import trimesh
import argparse
import mesh2sdf.core
import numpy as np
import skimage.measure
import seaborn as sns
from scipy.spatial.transform import Rotation
from mesh_to_sdf import get_surface_point_cloud
from accelerate.utils import set_seed
from accelerate import Accelerator
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import list_repo_files

from primitive_anything.utils import path_mkdir, count_parameters
from primitive_anything.utils.logger import print_log

# Render offscreen through EGL so the demo also runs on headless servers.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import spaces

# Download the pretrained PrimitiveAnything checkpoints and the Michelangelo shape VAE.
repo_id = "hyz317/PrimitiveAnything"
all_files = list_repo_files(repo_id, revision="main")
for file in all_files:
    if os.path.exists(file):
        continue
    hf_hub_download(repo_id, file, local_dir="./ckpt")
hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")


def parse_args():
    parser = argparse.ArgumentParser(description='Process 3D model files')

    parser.add_argument(
        '--input',
        type=str,
        default='./data/demo_glb/',
        help='Input file or directory path (default: ./data/demo_glb/)'
    )

    parser.add_argument(
        '--log_path',
        type=str,
        default='./results/demo',
        help='Output directory path (default: ./results/demo)'
    )

    return parser.parse_args()


def get_input_files(input_path):
    if os.path.isfile(input_path):
        return [input_path]
    elif os.path.isdir(input_path):
        return glob.glob(os.path.join(input_path, '*'))
    else:
        raise ValueError(f"Input path {input_path} is neither a file nor a directory")


args = parse_args()
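
# Example CLI invocation (illustrative; the script filename is an assumption):
#   python app.py --input ./data/demo_glb/ --log_path ./results/demo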

LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)

print(f"Output directory: {LOG_PATH}")

# Predicted primitive type code -> basic-shape mesh filename.
CODE_SHAPE = {
    0: 'SM_GR_BS_CubeBevel_001.ply',
    1: 'SM_GR_BS_SphereSharp_001.ply',
    2: 'SM_GR_BS_CylinderSharp_001.ply',
}

# Basic-shape mesh filename -> type_id written to the output JSON.
shapename_map = {
    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}

bs_dir = 'data/basic_shapes_norm'
config_path = './configs/infer.yml'
AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
temperature = 0.0

# Load the basic primitive meshes and bake their textures into vertex colors.
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
    bs_name = os.path.basename(bs_path)
    bs = trimesh.load(bs_path)
    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
    bs.visual = bs.visual.to_color()
    mesh_bs[bs_name] = bs


def create_model(cfg_model):
    kwargs = cfg_model
    name = kwargs.pop('name')
    model = get_model(name)(**kwargs)
    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
    return model


from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete


def get_model(name):
    return {
        'discrete': PrimitiveTransformerDiscrete,
    }[name]


with open(config_path, mode='r') as fp:
    AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)

AR_checkpoint = torch.load(AR_checkpoint_path)

transformer = create_model(AR_train_cfg['model'])
transformer.load_state_dict(AR_checkpoint)

device = torch.device('cuda')
accelerator = Accelerator(
    mixed_precision='fp16',
)
transformer = accelerator.prepare(transformer)
transformer.eval()
transformer.bs_pc = transformer.bs_pc.cuda()
transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
print('model loaded to device')


def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
                          return_surface_pc_normals=False, normalized=False):
    sample_start = time.time()
    if surface_point_method == 'sample' and sign_method == 'depth':
        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
        sign_method = 'normal'

    surface_start = time.time()
    bound_radius = 1 if normalized else None
    surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
                                                  sample_point_count,
                                                  calculate_normals=sign_method == 'normal' or return_gradients)

    surface_end = time.time()
    print('surface point cloud time cost :', surface_end - surface_start)

    normal_start = time.time()
    if return_surface_pc_normals:
        rng = np.random.default_rng()
        assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
        indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
        points = surface_point_cloud.points[indices]
        normals = surface_point_cloud.normals[indices]
        surface_points = np.concatenate([points, normals], axis=-1)
    else:
        surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
    normal_end = time.time()
    print('normal time cost :', normal_end - normal_start)
    sample_end = time.time()
    print('sample surface point time cost :', sample_end - sample_start)
    return surface_points
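
# Illustrative usage (assumes `some_mesh` is an already-loaded trimesh.Trimesh):
#   pts = sample_surface_points(some_mesh, number_of_points=10000, return_surface_pc_normals=True)
#   pts.shape  # -> (10000, 6): xyz coordinates followed by the unit normal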


def normalize_vertices(vertices, scale=0.9):
    bbmin, bbmax = vertices.min(0), vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * scale / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    return vertices, center, scale
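
# Worked example: for vertices spanning [0, 2] on every axis, center = 1 and
# scale = 2 * 0.9 / 2 = 0.9, so the normalized vertices span [-0.9, 0.9] on each axis.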


def export_to_watertight(normalized_mesh, octree_depth: int = 7):
    """
    Convert a non-watertight mesh into a watertight one via SDF extraction.

    Args:
        normalized_mesh (trimesh.Trimesh): input mesh
        octree_depth (int): the SDF is sampled on a grid of size 2 ** octree_depth

    Returns:
        mesh (trimesh.Trimesh): watertight mesh
    """
    size = 2 ** octree_depth
    level = 2 / size

    scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
    sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
    vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)

    # Map marching-cubes coordinates back to the original mesh frame.
    vertices = vertices / size * 2 - 1
    vertices = vertices / to_orig_scale + to_orig_center
    mesh = trimesh.Trimesh(vertices, faces, normals=normals)

    return mesh
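
# Illustrative usage (the input path is hypothetical):
#   wt_mesh = export_to_watertight(trimesh.load('./data/demo_glb/example.glb', force='mesh'))
#   wt_mesh.export(os.path.join(LOG_PATH, 'example_watertight.obj'))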


def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
    pc_normal_list = []
    return_mesh_list = []
    for mesh in mesh_list:
        if marching_cubes:
            mesh = export_to_watertight(mesh)
            print("MC over!")
        if dilated_offset > 0:
            # Push vertices outward along their normals to thicken thin structures.
            new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
            mesh.vertices = new_vertices
            print("dilate over!")

        mesh.merge_vertices()
        mesh.update_faces(mesh.unique_faces())
        mesh.fix_normals()

        return_mesh_list.append(mesh)

        pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
        pc_normal_list.append(pc_normal)
    print("process mesh success")
    return pc_normal_list, return_mesh_list


def euler_to_quat(euler):
    return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()


def SRT_quat_to_matrix(scale, quat, translation):
    rotation_matrix = Rotation.from_quat(quat).as_matrix()
    transform_matrix = np.eye(4)
    # Broadcasting scales the columns of R, i.e. R @ diag(scale): the scale is applied in the primitive's local frame.
    transform_matrix[:3, :3] = rotation_matrix * scale
    transform_matrix[:3, 3] = translation
    return transform_matrix
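
# Sanity check (illustrative): with the identity quaternion the rotation block reduces to diag(scale).
#   M = SRT_quat_to_matrix(np.array([1., 2., 3.]), np.array([0., 0., 0., 1.]), np.array([0.5, 0., 0.]))
#   M[:3, :3] -> diag([1., 2., 3.]),  M[:3, 3] -> [0.5, 0., 0.]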


def write_output(primitives, name):
    out_json = {}
    out_json['operation'] = 0
    out_json['type'] = 1
    out_json['scene_id'] = None

    new_group = []
    model_scene = trimesh.Scene()
    color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
    color_map = (np.array(color_map) * 255).astype("uint8")
    for idx, (scale, rotation, translation, type_code) in enumerate(zip(
        primitives['scale'].squeeze().cpu().numpy(),
        primitives['rotation'].squeeze().cpu().numpy(),
        primitives['translation'].squeeze().cpu().numpy(),
        primitives['type_code'].squeeze().cpu().numpy()
    )):
        if type_code == -1:
            break
        bs_name = CODE_SHAPE[type_code]
        new_block = {}
        new_block['type_id'] = shapename_map[bs_name]
        new_block['data'] = {}
        new_block['data']['location'] = translation.tolist()
        new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
        new_block['data']['scale'] = scale.tolist()
        new_block['data']['color'] = ['808080']
        new_group.append(new_block)

        trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
        bs = mesh_bs[bs_name].copy().apply_transform(trans)
        new_vertex_colors = np.repeat(color_map[idx:idx + 1], bs.visual.vertex_colors.shape[0], axis=0)
        bs.visual.vertex_colors[:, :3] = new_vertex_colors
        # Swap axes (y <- z, z <- -y) to convert from a z-up to the y-up convention of the exported GLB.
        vertices = bs.vertices.copy()
        vertices[:, 1] = bs.vertices[:, 2]
        vertices[:, 2] = -bs.vertices[:, 1]
        bs.vertices = vertices
        model_scene.add_geometry(bs)
    out_json['group'] = new_group

    json_path = os.path.join(LOG_PATH, f'output_{name}.json')
    with open(json_path, 'w') as json_file:
        json.dump(out_json, json_file, indent=4)

    glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
    model_scene.export(glb_path)

    return glb_path, out_json
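
# Each entry appended to out_json['group'] has this shape (values illustrative):
#   {"type_id": 1101002001034001,
#    "data": {"location": [x, y, z], "rotation": [qx, qy, qz, qw], "scale": [sx, sy, sz], "color": ["808080"]}}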


@torch.no_grad()
def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
    t1 = time.time()
    set_seed(sample_seed)
    input_mesh = trimesh.load(input_3d, force='mesh')

    # Normalize the input mesh so its largest bounding-box extent is 1.6, centered at the origin.
    vertices = input_mesh.vertices
    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
    vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
    vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
    input_mesh.vertices = vertices

    pc_list, mesh_list = process_mesh_to_surface_pc(
        [input_mesh],
        marching_cubes=do_marching_cubes,
        dilated_offset=dilated_offset
    )
    pc_normal = pc_list[0]  # (N, 6): xyz + normal
    mesh = mesh_list[0]

    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]

    if dilated_offset > 0:
        # Re-normalize the dilated mesh and apply the same transform to the sampled point cloud.
        vertices = mesh.vertices
        bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
        vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
        vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
        mesh.vertices = vertices
        pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
        pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6

    input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
    mesh.export(input_save_name)

    assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]

    with accelerator.autocast():
        if postprocess == 'postprocess1':
            recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
        else:
            recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)

    output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4])

    return input_save_name, output_glb, output_json
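
# Optional batch mode (illustrative sketch; not wired into the Gradio demo below): the CLI helpers
# defined above could drive do_inference over every file passed via --input, e.g.:
#   for path in get_input_files(args.input):
#       do_inference(path, dilated_offset=0.015, do_marching_cubes=True)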


import gradio as gr


@spaces.GPU  # request a GPU when running on Hugging Face Spaces (ZeroGPU)
def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"):
    print(f"processing: {input_3d}")

    preprocess_model_obj, output_model_obj, output_model_json = do_inference(
        input_3d,
        dilated_offset=dilated_offset,
        do_marching_cubes=do_marching_cubes,
        postprocess=postprocess_method
    )
    return output_model_obj


title = "3D Model Processing Demo"
reminder = "Please upload your 3D model file and adjust parameters as needed."

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(reminder)

    with gr.Row():
        with gr.Column():
            input_3d = gr.Model3D(label="Upload 3D Model File")
            dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
            do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
            submit_btn = gr.Button("Process Model")

        with gr.Column():
            output = gr.Model3D(label="Primitive Assembly Prediction")

    submit_btn.click(
        fn=process_3d_model,
        inputs=[input_3d, dilated_offset, do_marching_cubes],
        outputs=output
    )

    example_files = [[f] for f in glob.glob('./data/demo_glb/*.glb')]

    example = gr.Examples(
        examples=example_files,
        inputs=[input_3d],
        examples_per_page=14,
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)