import os

# Request PyOpenGL's EGL (headless) backend before any rendering-related
# package is imported. mesh_to_sdf renders its surface scans through
# pyrender/PyOpenGL, and PyOpenGL chooses its platform when it is first
# imported, so setting this after those imports would not take effect.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

import io
import time
import glob
import json
import math
import argparse

import yaml
import torch
import trimesh
import mesh2sdf.core
import numpy as np
import skimage.measure
import seaborn as sns
from scipy.spatial.transform import Rotation
from mesh_to_sdf import get_surface_point_cloud
from accelerate.utils import set_seed
from accelerate import Accelerator
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import list_repo_files

# Animation-related imports
import trimesh.transformations as tf
import pyglet

from primitive_anything.utils import path_mkdir, count_parameters
from primitive_anything.utils.logger import print_log

import spaces


#### checkpoints
repo_id = "hyz317/PrimitiveAnything"
CKPT_DIR = "./ckpt"

all_files = list_repo_files(repo_id, revision="main")
for file in all_files:
    # Files are downloaded into CKPT_DIR, so that is where an existing copy
    # lives; checking the bare repo-relative name never matched.
    if os.path.exists(os.path.join(CKPT_DIR, file)):
        continue
    hf_hub_download(repo_id, file, local_dir=CKPT_DIR)
hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir=CKPT_DIR)


def parse_args():
    """Parse the command-line options of the demo."""
    parser = argparse.ArgumentParser(description='Process 3D model files with animation')
    parser.add_argument(
        '--input',
        type=str,
        default='./data/demo_glb/',
        help='Input file or directory path (default: ./data/demo_glb/)'
    )
    parser.add_argument(
        '--log_path',
        type=str,
        default='./results/demo',
        help='Output directory path (default: results/demo)'
    )
    parser.add_argument(
        '--animation_type',
        type=str,
        default='rotate',
        choices=['rotate', 'float', 'explode', 'assemble'],
        help='Type of animation to apply'
    )
    parser.add_argument(
        '--animation_duration',
        type=float,
        default=3.0,
        help='Duration of animation in seconds'
    )
    parser.add_argument(
        '--fps',
        type=int,
        default=30,
        help='Frames per second for animation'
    )
    return parser.parse_args()


def get_input_files(input_path):
    """Return ``[input_path]`` for a file, or the sorted files inside a directory.

    Raises:
        ValueError: if ``input_path`` is neither a file nor a directory.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        return sorted(p for p in glob.glob(os.path.join(input_path, '*')) if os.path.isfile(p))
    raise ValueError(f"Input path {input_path} is neither a file nor a directory")


args = parse_args()

LOG_PATH = args.log_path
os.makedirs(LOG_PATH, exist_ok=True)

print(f"Output directory: {LOG_PATH}")

CODE_SHAPE = {
    0: 'SM_GR_BS_CubeBevel_001.ply',
    1: 'SM_GR_BS_SphereSharp_001.ply',
    2: 'SM_GR_BS_CylinderSharp_001.ply',
}

shapename_map = {
    'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
    'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
    'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
}

#### config
bs_dir = 'data/basic_shapes_norm'
config_path = './configs/infer.yml'
AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
temperature = 0.0

#### init model
# Basic-shape meshes keyed by file name; their UV texture is baked into
# per-vertex colours so they can be recoloured per primitive later.
mesh_bs = {}
for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
    bs_name = os.path.basename(bs_path)
    bs = trimesh.load(bs_path)
    bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
    bs.visual = bs.visual.to_color()
    mesh_bs[bs_name] = bs


def create_model(cfg_model):
    """Instantiate the model described by ``cfg_model``.

    ``cfg_model['name']`` selects the class via ``get_model``; all remaining
    keys are passed to its constructor. The mapping is copied first so the
    caller's config keeps its ``'name'`` entry.
    """
    kwargs = dict(cfg_model)
    name = kwargs.pop('name')
    model = get_model(name)(**kwargs)
    print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
    return model


from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete


def get_model(name):
    """Map a config model name to its class."""
    return {
        'discrete': PrimitiveTransformerDiscrete,
    }[name]


# The config is a local, repository-controlled file.
with open(config_path, mode='r') as fp:
    AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)

AR_checkpoint = torch.load(AR_checkpoint_path)

transformer = create_model(AR_train_cfg['model'])
transformer.load_state_dict(AR_checkpoint)

device = torch.device('cuda')
accelerator = Accelerator(
    mixed_precision='fp16',
)
transformer = accelerator.prepare(transformer)
transformer.eval()
transformer.bs_pc = transformer.bs_pc.cuda()
transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
print('model loaded to device')


def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
                          scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
                          return_surface_pc_normals=False, normalized=False):
    """Sample points (optionally with normals) from the surface of ``mesh``.

    A mesh_to_sdf surface point cloud is built first. With
    ``return_surface_pc_normals=True``, ``number_of_points`` entries are drawn
    from it with replacement and returned as a ``(number_of_points, 6)``
    array of ``[xyz, normal]``. Otherwise the point cloud's own
    ``get_random_surface_points`` is used.
    """
    sample_start = time.time()
    if surface_point_method == 'sample' and sign_method == 'depth':
        print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
        sign_method = 'normal'

    surface_start = time.time()
    bound_radius = 1 if normalized else None
    surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
                                                  sample_point_count,
                                                  calculate_normals=sign_method == 'normal' or return_gradients)
    surface_end = time.time()
    print('surface point cloud time cost :', surface_end - surface_start)

    normal_start = time.time()
    if return_surface_pc_normals:
        assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
        # Use the global numpy RNG (seeded by set_seed in do_inference) so a
        # given sample_seed reproduces the same point sample; a fresh
        # default_rng() would draw from OS entropy every call.
        indices = np.random.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
        points = surface_point_cloud.points[indices]
        normals = surface_point_cloud.normals[indices]
        surface_points = np.concatenate([points, normals], axis=-1)
    else:
        surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
    normal_end = time.time()
    print('normal time cost :', normal_end - normal_start)
    sample_end = time.time()
    print('sample surface point time cost :', sample_end - sample_start)
    return surface_points


def normalize_vertices(vertices, scale=0.9):
    """Center vertices on their bbox and scale the longest side to ``2 * scale``.

    Returns ``(normalized_vertices, center, scale_factor)`` so the transform
    can be undone as ``v / scale_factor + center``.
    """
    bbmin, bbmax = vertices.min(0), vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * scale / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale
    return vertices, center, scale


def export_to_watertight(normalized_mesh, octree_depth: int = 7):
    """Convert a (possibly non-watertight) mesh into a watertight one.

    The mesh is rescaled to fit inside [-1, 1], a distance field is computed
    on a ``2 ** octree_depth`` grid with mesh2sdf, and the iso-surface of its
    absolute value at ``level = 2 / size`` is extracted with marching cubes.
    The result is mapped back into the input mesh's coordinate frame.

    Args:
        normalized_mesh (trimesh.Trimesh): input mesh.
        octree_depth (int): grid resolution exponent; the grid has
            ``2 ** octree_depth`` cells per axis.

    Returns:
        trimesh.Trimesh: watertight mesh.
    """
    size = 2 ** octree_depth
    level = 2 / size

    scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
    sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
    vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)

    # watertight mesh
    vertices = vertices / size * 2 - 1  # grid index -> [-1, 1]
    vertices = vertices / to_orig_scale + to_orig_center
    mesh = trimesh.Trimesh(vertices, faces, normals=normals)
    return mesh


def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
    """Clean each mesh and sample a ``(sample_num, 6)`` point cloud with normals from it.

    Optionally makes the mesh watertight (marching cubes) and dilates it along
    its vertex normals by ``dilated_offset``. Note that meshes that are not
    replaced by marching cubes are modified in place.

    Returns:
        (pc_normal_list, processed_mesh_list), one entry per input mesh.
    """
    pc_normal_list = []
    return_mesh_list = []
    for mesh in mesh_list:
        if marching_cubes:
            mesh = export_to_watertight(mesh)
            print("MC over!")
        if dilated_offset > 0:
            new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
            mesh.vertices = new_vertices
            print("dilate over!")

        mesh.merge_vertices()
        mesh.update_faces(mesh.unique_faces())
        mesh.fix_normals()

        return_mesh_list.append(mesh)

        pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
        pc_normal_list.append(pc_normal)
        print("process mesh success")
    return pc_normal_list, return_mesh_list


#### utils
def euler_to_quat(euler):
    """Convert intrinsic XYZ Euler angles in degrees to an (x, y, z, w) quaternion."""
    return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()


def SRT_quat_to_matrix(scale, quat, translation):
    """Build a 4x4 transform from scale, quaternion rotation and translation."""
    rotation_matrix = Rotation.from_quat(quat).as_matrix()
    transform_matrix = np.eye(4)
    transform_matrix[:3, :3] = rotation_matrix * scale
    transform_matrix[:3, 3] = translation
    return transform_matrix


# Animation Functions
def _normalized_times(duration, fps, loop):
    """Return the normalized time of every frame.

    At least one frame is always produced. For looping animations the
    times are ``i / num_frames`` (end point excluded) so the last frame does
    not duplicate the first and the GIF wraps without a stall; otherwise
    they span ``[0, 1]`` inclusive.
    """
    num_frames = max(int(duration * fps), 1)
    if loop:
        return [i / num_frames for i in range(num_frames)]
    denominator = max(num_frames - 1, 1)
    return [i / denominator for i in range(num_frames)]


def _explosion_directions(primitive_list):
    """Unit vectors from the model's mean vertex to each primitive's centroid.

    Computed once per animation: a primitive sitting on the centre gets a
    single random direction instead of a new one every frame (which made it
    jitter).
    """
    if not primitive_list:
        return []
    all_vertices = np.vstack([p.vertices for p, _ in primitive_list])
    center = np.mean(all_vertices, axis=0)
    directions = []
    for primitive, _ in primitive_list:
        direction = primitive.centroid - center
        if np.linalg.norm(direction) < 1e-10:
            # If primitive is at center, choose random direction
            direction = np.random.rand(3) - 0.5
        directions.append(direction / np.linalg.norm(direction))
    return directions


def create_rotation_animation(primitive_list, duration=3.0, fps=30):
    """Spin every primitive one full turn about a Y-parallel axis through its own centroid.

    Returns a list of ``trimesh.Scene`` frames.
    """
    frames = []
    for t in _normalized_times(duration, fps, loop=True):
        angle = t * 2 * math.pi
        frame_scene = trimesh.Scene()
        for idx, (primitive, _color) in enumerate(primitive_list):
            animated_primitive = primitive.copy()
            rotation_matrix = tf.rotation_matrix(angle, [0, 1, 0], animated_primitive.centroid)
            animated_primitive.apply_transform(rotation_matrix)
            frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
        frames.append(frame_scene)
    return frames


def create_float_animation(primitive_list, duration=3.0, fps=30, amplitude=0.1):
    """Bob primitives up and down along Y, each with its own phase offset."""
    frames = []
    primitive_count = len(primitive_list)
    for t in _normalized_times(duration, fps, loop=True):
        frame_scene = trimesh.Scene()
        for idx, (primitive, _color) in enumerate(primitive_list):
            animated_primitive = primitive.copy()
            phase_offset = 2 * math.pi * (idx / primitive_count)  # Different phase for each primitive
            y_offset = amplitude * math.sin(2 * math.pi * t + phase_offset)
            animated_primitive.apply_transform(tf.translation_matrix([0, y_offset, 0]))
            frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
        frames.append(frame_scene)
    return frames


def create_explode_animation(primitive_list, duration=3.0, fps=30, max_distance=0.5):
    """Move primitives outward from the model centre, up to ``max_distance``."""
    directions = _explosion_directions(primitive_list)
    frames = []
    for t in _normalized_times(duration, fps, loop=False):
        frame_scene = trimesh.Scene()
        for idx, ((primitive, _color), direction) in enumerate(zip(primitive_list, directions)):
            animated_primitive = primitive.copy()
            animated_primitive.apply_transform(tf.translation_matrix(direction * t * max_distance))
            frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
        frames.append(frame_scene)
    return frames


def create_assemble_animation(primitive_list, duration=3.0, fps=30, start_distance=1.0):
    """Move primitives inward from ``start_distance`` until they form the model (reverse explode)."""
    directions = _explosion_directions(primitive_list)
    frames = []
    for t in _normalized_times(duration, fps, loop=False):
        frame_scene = trimesh.Scene()
        for idx, ((primitive, _color), direction) in enumerate(zip(primitive_list, directions)):
            animated_primitive = primitive.copy()
            animated_primitive.apply_transform(tf.translation_matrix(direction * (1.0 - t) * start_distance))
            frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
        frames.append(frame_scene)
    return frames


_ANIMATION_BUILDERS = {
    'rotate': create_rotation_animation,
    'float': create_float_animation,
    'explode': create_explode_animation,
    'assemble': create_assemble_animation,
}


def generate_animated_glb(primitive_list, animation_type='rotate', duration=3.0, fps=30, output_path="animated_model.glb"):
    """Render the chosen animation; export its first frame as GLB and all frames as a GIF.

    Returns:
        (glb_path, gif_path); ``gif_path`` is ``None`` if the GIF could not
        be rendered (e.g. no OpenGL context available).

    Raises:
        ValueError: for an unknown ``animation_type``.
    """
    try:
        build_frames = _ANIMATION_BUILDERS[animation_type]
    except KeyError:
        raise ValueError(f"Unknown animation type: {animation_type}") from None
    frames = build_frames(primitive_list, duration, fps)

    # Only the first frame goes into the GLB; a proper keyframe exporter would
    # be needed to embed the full animation.
    frames[0].export(output_path)

    # Also create a GIF for preview
    gif_path = os.path.splitext(output_path)[0] + '.gif'
    try:
        from PIL import Image
        gif_frames = []
        for frame in frames:
            # Scene.save_image returns encoded PNG bytes, not a PIL image.
            png_bytes = frame.save_image(resolution=[640, 480])
            image = Image.open(io.BytesIO(png_bytes))
            image.load()
            gif_frames.append(image)
        gif_frames[0].save(
            gif_path,
            save_all=True,
            append_images=gif_frames[1:],
            optimize=False,
            duration=int(1000 / fps),
            loop=0
        )
    except Exception as e:
        print(f"Error creating GIF: {str(e)}")
        gif_path = None  # don't hand out a path to a file that was never written

    return output_path, gif_path


def _per_primitive(tensor, count):
    """Flatten a batched primitive attribute to ``(count, k)`` without squeezing away N == 1."""
    array = tensor.cpu().numpy()
    if count == 0:
        return array.reshape(0, 0)
    return array.reshape(count, -1)


def write_output(primitives, name, animation_type='rotate', duration=3.0, fps=30):
    """Save the predicted primitives as JSON + static GLB and build the animation.

    Files are written to ``LOG_PATH`` as ``output_{name}.json``,
    ``output_{name}.glb`` and ``animated_{name}.glb/.gif``. If animation
    fails, the static GLB path is returned and the GIF path is ``None``.

    Returns:
        (animated_glb_path, animated_gif_path, out_json)
    """
    out_json = {}
    new_group = []
    model_scene = trimesh.Scene()
    primitive_list = []

    # NOTE(review): assumes a batch of one (type_code (1, N), other fields
    # (1, N, k)). Reshaping instead of squeeze() keeps the primitive axis
    # when N == 1.
    type_codes = primitives['type_code'].reshape(-1).cpu().numpy()
    num_primitives = type_codes.shape[0]
    scales = _per_primitive(primitives['scale'], num_primitives)
    rotations = _per_primitive(primitives['rotation'], num_primitives)
    translations = _per_primitive(primitives['translation'], num_primitives)

    color_map = sns.color_palette("hls", num_primitives)
    color_map = (np.array(color_map) * 255).astype("uint8")

    for idx, (scale, rotation, translation, type_code) in enumerate(zip(scales, rotations, translations, type_codes)):
        if type_code == -1:  # a type code of -1 terminates the primitive list
            break
        bs_name = CODE_SHAPE[int(type_code)]
        quat = euler_to_quat(rotation)
        new_group.append({
            'type_id': shapename_map[bs_name],
            'data': {
                'location': translation.tolist(),
                'rotation': quat.tolist(),
                'scale': scale.tolist(),
            },
        })

        trans = SRT_quat_to_matrix(scale, quat, translation)
        bs = mesh_bs[bs_name].copy().apply_transform(trans)
        new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
        bs.visual.vertex_colors[:, :3] = new_vertex_colors
        # Remap axes (x, y, z) -> (x, z, -y); presumably converts the
        # primitives' Z-up frame to glTF's Y-up convention.
        vertices = bs.vertices.copy()
        vertices[:, 1] = bs.vertices[:, 2]
        vertices[:, 2] = -bs.vertices[:, 1]
        bs.vertices = vertices

        # Add to primitive list for animation
        primitive_list.append((bs, color_map[idx]))
        model_scene.add_geometry(bs)

    out_json['group'] = new_group

    # Save static model
    json_path = os.path.join(LOG_PATH, f'output_{name}.json')
    with open(json_path, 'w') as json_file:
        json.dump(out_json, json_file, indent=4)

    static_glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
    model_scene.export(static_glb_path)

    # Generate animated model
    animated_glb_path = os.path.join(LOG_PATH, f'animated_{name}.glb')
    animated_gif_path = os.path.join(LOG_PATH, f'animated_{name}.gif')

    try:
        animated_glb_path, animated_gif_path = generate_animated_glb(
            primitive_list,
            animation_type=animation_type,
            duration=duration,
            fps=fps,
            output_path=animated_glb_path
        )
    except Exception as e:
        print(f"Error generating animation: {str(e)}")
        animated_glb_path = static_glb_path
        animated_gif_path = None

    return animated_glb_path, animated_gif_path, out_json


def _file_stem(path):
    """File name without directory or extension (works for any extension length)."""
    return os.path.splitext(os.path.basename(path))[0]


def _bbox_center_extent(vertices):
    """Return the axis-aligned bbox centre and its longest side length."""
    bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
    center = (bounds[0] + bounds[1])[None, :] / 2
    extent = (bounds[1] - bounds[0]).max()
    return center, extent


@torch.no_grad()
def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False,
                 postprocess='none', animation_type='rotate', duration=3.0, fps=30):
    """Run primitive decomposition on a mesh file and write all outputs.

    The mesh is fitted into a 1.6-wide box at the origin, optionally made
    watertight and dilated, sampled into a point cloud with normals, and fed
    to the transformer. ``do_sampling`` is accepted but currently unused.

    Returns:
        (processed_mesh_path, animated_glb_path, animated_gif_path, out_json)

    Raises:
        ValueError: if the sampled normals are not unit length.
    """
    set_seed(sample_seed)
    input_mesh = trimesh.load(input_3d, force='mesh')

    # scale mesh
    center, extent = _bbox_center_extent(input_mesh.vertices)
    input_mesh.vertices = (input_mesh.vertices - center) / extent * 1.6

    pc_list, mesh_list = process_mesh_to_surface_pc(
        [input_mesh],
        marching_cubes=do_marching_cubes,
        dilated_offset=dilated_offset
    )
    pc_normal = pc_list[0]  # (sample_num, 6): xyz + normal
    mesh = mesh_list[0]

    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]

    if dilated_offset > 0:
        # Dilation enlarges the mesh: refit mesh and points using the dilated mesh's bbox.
        center, extent = _bbox_center_extent(mesh.vertices)
        mesh.vertices = (mesh.vertices - center) / extent * 1.6
        pc_coor = (pc_coor - center) / extent * 1.6

    input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
    mesh.export(input_save_name)

    if not (np.linalg.norm(normals, axis=-1) > 0.99).all():
        raise ValueError('normals should be unit vectors, something wrong')
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]

    with accelerator.autocast():
        if postprocess == 'postprocess1':
            recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
        else:
            recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)

    output_animated_glb, output_animated_gif, output_json = write_output(
        recon_primitives,
        _file_stem(input_3d),
        animation_type=animation_type,
        duration=duration,
        fps=fps
    )

    return input_save_name, output_animated_glb, output_animated_gif, output_json


import gradio as gr


@spaces.GPU
def process_3d_model(input_3d, dilated_offset, do_marching_cubes, animation_type, animation_duration, fps, postprocess_method="postprocess1"):
    """Gradio callback: returns (animated GLB path, GIF path or None, JSON path).

    Failures are raised as ``gr.Error`` so the UI shows the message instead of
    receiving a non-path string in the Model3D output.
    """
    if input_3d is None:
        raise gr.Error("Please upload a 3D model file first.")
    print(f"Processing: {input_3d} with animation type: {animation_type}")

    try:
        _, output_animated_glb, output_animated_gif, _ = do_inference(
            input_3d,
            dilated_offset=dilated_offset,
            do_marching_cubes=do_marching_cubes,
            postprocess=postprocess_method,
            animation_type=animation_type,
            duration=animation_duration,
            fps=fps
        )
    except Exception as e:
        print(f"Error processing file: {str(e)}")
        raise gr.Error(f"Error processing file: {str(e)}") from e

    # write_output has already saved the JSON under this name.
    json_path = os.path.join(LOG_PATH, f'output_{_file_stem(input_3d)}.json')
    return output_animated_glb, output_animated_gif, json_path


_HEADER_ = '''
# [SIGGRAPH 2025] Animated PrimitiveAnything 🤗 Gradio Demo

This is an enhanced demo for the SIGGRAPH 2025 paper PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer, now with animation capabilities!

Code: [GitHub](https://github.com/PrimitiveAnything/PrimitiveAnything). Paper: [ArXiv](https://arxiv.org/abs/2505.04622).

❗️❗️❗️**Important Notes:**
- This demo supports 3D model animation. Upload your GLB file and see it come to life!
- Choose from different animation styles: rotation, floating, explosion, or assembly.
- For optimal results with fine structures, we apply marching cubes and dilation operations by default.
'''

_CITE_ = r"""
If PrimitiveAnything is helpful, please help to ⭐ the [GitHub Repo](https://github.com/PrimitiveAnything/PrimitiveAnything). Thanks!

[![GitHub Stars](https://img.shields.io/github/stars/PrimitiveAnything/PrimitiveAnything?style=social)](https://github.com/PrimitiveAnything/PrimitiveAnything)

---

📝 **Citation**

If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@misc{ye2025primitiveanything,
      title={PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer},
      author={Jingwen Ye and Yuze He and Yanning Zhou and Yiqin Zhu and Kaiwen Xiao and Yong-Jin Liu and Wei Yang and Xiao Han},
      year={2025},
      eprint={2505.04622},
      archivePrefix={arXiv},
      primaryClass={cs.GR}
}
```

📧 **Contact**

If you have any questions, feel free to open a discussion or contact us at hyz22@mails.tsinghua.edu.cn.
"""

with gr.Blocks(title="PrimitiveAnything with Animation: 3D Animation Generator") as demo:
    # Title section
    gr.Markdown(_HEADER_)

    with gr.Row():
        with gr.Column():
            # Input components
            input_3d = gr.Model3D(label="Upload 3D Model File")

            with gr.Row():
                dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
                do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)

            with gr.Row():
                animation_type = gr.Dropdown(
                    label="Animation Type",
                    choices=["rotate", "float", "explode", "assemble"],
                    value="rotate"
                )

            with gr.Row():
                animation_duration = gr.Slider(
                    label="Animation Duration (seconds)",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.5
                )
                fps = gr.Slider(
                    label="Frames Per Second",
                    minimum=15,
                    maximum=60,
                    value=30,
                    step=1
                )

            submit_btn = gr.Button("Process and Animate Model")

        with gr.Column():
            # Output components
            output_3d = gr.Model3D(label="Animated Primitive Assembly")
            output_gif = gr.Image(label="Animation Preview (GIF)")
            output_json = gr.File(label="Download JSON File")

    submit_btn.click(
        fn=process_3d_model,
        inputs=[input_3d, dilated_offset, do_marching_cubes, animation_type, animation_duration, fps],
        outputs=[output_3d, output_gif, output_json]
    )

    # Example GLB files, each wrapped in a list (one value per input component).
    example_files = [[f] for f in sorted(glob.glob('./data/demo_glb/*.glb'))]
    if example_files:
        example = gr.Examples(
            examples=example_files,
            inputs=[input_3d],  # Only include the Model3D input
            examples_per_page=14,
        )

    gr.Markdown(_CITE_)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)