diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..b120390cdc5b5ff246990bfe20414757b71ae7bc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +data/demo_glb/*.glb filter=lfs diff=lfs merge=lfs -text +assets/*.jpg filter=lfs diff=lfs merge=lfs -text +data/basic_shapes_norm_pc10000/*.ply filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..68f186095ba760ff96bd6801e52a1acdc2df3742 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +**/__pycache__/ +ckpt +gradio_cached_examples +results \ No newline at end of file diff --git a/app.py b/app.py new file mode 100755 index 0000000000000000000000000000000000000000..f454e11ea17d6ffe1836fdfa099984f2b0264331 --- /dev/null +++ b/app.py @@ -0,0 +1,390 @@ +import os +import time +import glob +import json +import yaml +import torch +import trimesh +import argparse +import mesh2sdf.core +import numpy as np +import skimage.measure +import seaborn as sns +from scipy.spatial.transform import Rotation +from mesh_to_sdf import get_surface_point_cloud +from accelerate.utils import set_seed +from accelerate import Accelerator +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub import list_repo_files + +from primitive_anything.utils import path_mkdir, count_parameters +from primitive_anything.utils.logger import print_log + +os.environ['PYOPENGL_PLATFORM'] = 'egl' + +import spaces + +repo_id = "hyz317/PrimitiveAnything" +all_files = list_repo_files(repo_id, revision="main") +for file in all_files: + if os.path.exists(file): + continue + hf_hub_download(repo_id, file, local_dir="./ckpt") +hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt") + +def parse_args(): + parser = argparse.ArgumentParser(description='Process 3D model files') + + parser.add_argument( + '--input', + type=str, + default='./data/demo_glb/', + help='Input file or directory path (default: ./data/demo_glb/)' + ) + + parser.add_argument( + '--log_path', + type=str, + default='./results/demo', + help='Output directory path (default: results/demo)' + ) + + return parser.parse_args() + +def get_input_files(input_path): + if os.path.isfile(input_path): + return [input_path] + elif os.path.isdir(input_path): + return glob.glob(os.path.join(input_path, '*')) + else: + raise ValueError(f"Input path {input_path} is neither a file nor a directory") + +args = parse_args() + +# Create output directory (keeping your original variable name) +LOG_PATH = args.log_path +os.makedirs(LOG_PATH, exist_ok=True) + +print(f"Output directory: {LOG_PATH}") + +CODE_SHAPE = { + 0: 'SM_GR_BS_CubeBevel_001.ply', + 1: 'SM_GR_BS_SphereSharp_001.ply', + 2: 'SM_GR_BS_CylinderSharp_001.ply', +} + +shapename_map = { + 'SM_GR_BS_CubeBevel_001.ply': 1101002001034001, + 'SM_GR_BS_SphereSharp_001.ply': 1101002001034010, + 'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002, +} + +#### config +bs_dir = 'data/basic_shapes_norm' +config_path = './configs/infer.yml' +AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt' +temperature= 0.0 +#### init model +mesh_bs = {} +for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')): + bs_name = os.path.basename(bs_path) + bs = trimesh.load(bs_path) + bs.visual.uv = 
np.clip(bs.visual.uv, 0, 1) + bs.visual = bs.visual.to_color() + mesh_bs[bs_name] = bs + +def create_model(cfg_model): + kwargs = cfg_model + name = kwargs.pop('name') + model = get_model(name)(**kwargs) + print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs)) + return model + +from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete +def get_model(name): + return { + 'discrete': PrimitiveTransformerDiscrete, + }[name] + +with open(config_path, mode='r') as fp: + AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader) + +AR_checkpoint = torch.load(AR_checkpoint_path) + +transformer = create_model(AR_train_cfg['model']) +transformer.load_state_dict(AR_checkpoint) + +device = torch.device('cuda') +accelerator = Accelerator( + mixed_precision='fp16', +) +transformer = accelerator.prepare(transformer) +transformer.eval() +transformer.bs_pc = transformer.bs_pc.cuda() +transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda() +print('model loaded to device') + + +def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal', + scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False, + return_surface_pc_normals=False, normalized=False): + sample_start = time.time() + if surface_point_method == 'sample' and sign_method == 'depth': + print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.") + sign_method = 'normal' + + surface_start = time.time() + bound_radius = 1 if normalized else None + surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution, + sample_point_count, + calculate_normals=sign_method == 'normal' or return_gradients) + + surface_end = time.time() + print('surface point cloud time cost :', surface_end - surface_start) + + normal_start = time.time() + if return_surface_pc_normals: + rng = np.random.default_rng() + assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0] + indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True) + points = surface_point_cloud.points[indices] + normals = surface_point_cloud.normals[indices] + surface_points = np.concatenate([points, normals], axis=-1) + else: + surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True) + normal_end = time.time() + print('normal time cost :', normal_end - normal_start) + sample_end = time.time() + print('sample surface point time cost :', sample_end - sample_start) + return surface_points + + +def normalize_vertices(vertices, scale=0.9): + bbmin, bbmax = vertices.min(0), vertices.max(0) + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * scale / (bbmax - bbmin).max() + vertices = (vertices - center) * scale + return vertices, center, scale + + +def export_to_watertight(normalized_mesh, octree_depth: int = 7): + """ + Convert the non-watertight mesh to watertight. 
+ + Args: + input_path (str): normalized path + octree_depth (int): + + Returns: + mesh(trimesh.Trimesh): watertight mesh + + """ + size = 2 ** octree_depth + level = 2 / size + + scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices) + sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size) + vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level) + + # watertight mesh + vertices = vertices / size * 2 - 1 # -1 to 1 + vertices = vertices / to_orig_scale + to_orig_center + mesh = trimesh.Trimesh(vertices, faces, normals=normals) + + return mesh + + +def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000): + # mesh_list : list of trimesh + pc_normal_list = [] + return_mesh_list = [] + for mesh in mesh_list: + if marching_cubes: + mesh = export_to_watertight(mesh) + print("MC over!") + if dilated_offset > 0: + new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset + mesh.vertices = new_vertices + print("dilate over!") + + mesh.merge_vertices() + mesh.update_faces(mesh.unique_faces()) + mesh.fix_normals() + + return_mesh_list.append(mesh) + + pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True)) + pc_normal_list.append(pc_normal) + print("process mesh success") + return pc_normal_list, return_mesh_list + + +#### utils +def euler_to_quat(euler): + return Rotation.from_euler('XYZ', euler, degrees=True).as_quat() + +def SRT_quat_to_matrix(scale, quat, translation): + rotation_matrix = Rotation.from_quat(quat).as_matrix() + transform_matrix = np.eye(4) + transform_matrix[:3, :3] = rotation_matrix * scale + transform_matrix[:3, 3] = translation + return transform_matrix + + +def write_output(primitives, name): + out_json = {} + out_json['operation'] = 0 + out_json['type'] = 1 + out_json['scene_id'] = None + + new_group = [] + model_scene = trimesh.Scene() + color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0]) + color_map = (np.array(color_map) * 255).astype("uint8") + for idx, (scale, rotation, translation, type_code) in enumerate(zip( + primitives['scale'].squeeze().cpu().numpy(), + primitives['rotation'].squeeze().cpu().numpy(), + primitives['translation'].squeeze().cpu().numpy(), + primitives['type_code'].squeeze().cpu().numpy() + )): + if type_code == -1: + break + bs_name = CODE_SHAPE[type_code] + new_block = {} + new_block['type_id'] = shapename_map[bs_name] + new_block['data'] = {} + new_block['data']['location'] = translation.tolist() + new_block['data']['rotation'] = euler_to_quat(rotation).tolist() + new_block['data']['scale'] = scale.tolist() + new_block['data']['color'] = ['808080'] + new_group.append(new_block) + + trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation) + bs = mesh_bs[bs_name].copy().apply_transform(trans) + new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0) + bs.visual.vertex_colors[:, :3] = new_vertex_colors + vertices = bs.vertices.copy() + vertices[:, 1] = bs.vertices[:, 2] + vertices[:, 2] = -bs.vertices[:, 1] + bs.vertices = vertices + model_scene.add_geometry(bs) + out_json['group'] = new_group + + json_path = os.path.join(LOG_PATH, f'output_{name}.json') + with open(json_path, 'w') as json_file: + json.dump(out_json, json_file, indent=4) + + glb_path = os.path.join(LOG_PATH, f'output_{name}.glb') + model_scene.export(glb_path) + + return glb_path, out_json + + +@torch.no_grad() +def 
do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'): + t1 = time.time() + set_seed(sample_seed) + input_mesh = trimesh.load(input_3d, force='mesh') + + # scale mesh + vertices = input_mesh.vertices + bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) + vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 + vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6 + input_mesh.vertices = vertices + + pc_list, mesh_list = process_mesh_to_surface_pc( + [input_mesh], + marching_cubes=do_marching_cubes, + dilated_offset=dilated_offset + ) + pc_normal = pc_list[0] # 10000, 6 + mesh = mesh_list[0] + + pc_coor = pc_normal[:, :3] + normals = pc_normal[:, 3:] + + if dilated_offset > 0: + # scale mesh and pc + vertices = mesh.vertices + bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) + vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 + vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6 + mesh.vertices = vertices + pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2 + pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6 + + input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}') + mesh.export(input_save_name) + + assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong' + normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16) + + input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None] + + with accelerator.autocast(): + if postprocess == 'postprocess1': + recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True) + else: + recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature) + + output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4]) + + return input_save_name, output_glb, output_json + + +import gradio as gr + +@spaces.GPU +def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"): + print(f"processing: {input_3d}") + # try: + preprocess_model_obj, output_model_obj, output_model_json = do_inference( + input_3d, + dilated_offset=dilated_offset, + do_marching_cubes=do_marching_cubes, + postprocess=postprocess_method + ) + return output_model_obj + # except Exception as e: + # return f"Error processing file: {str(e)}" + +# Title and reminder placeholders +title = "3D Model Processing Demo" +reminder = "Please upload your 3D model file and adjust parameters as needed." 
+ +with gr.Blocks(title=title) as demo: + # Title section + gr.Markdown(f"# {title}") + gr.Markdown(reminder) + + with gr.Row(): + with gr.Column(): + # Input components + input_3d = gr.Model3D(label="Upload 3D Model File") + dilated_offset = gr.Number(label="Dilated Offset", value=0.015) + do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True) + submit_btn = gr.Button("Process Model") + + with gr.Column(): + # Output components + output = gr.Model3D(label="Primitive Assembly Prediction") + + submit_btn.click( + fn=process_3d_model, + inputs=[input_3d, dilated_offset, do_marching_cubes], + outputs=output + ) + + + # Prepare examples properly + example_files = [ [f] for f in glob.glob('./data/demo_glb/*.glb') ] # Note: wrapped in list and filtered for GLB + + example = gr.Examples( + examples=example_files, + inputs=[input_3d], # Only include the Model3D input + examples_per_page=14, + ) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/assets/teaser.jpg b/assets/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..030ffc2cf437f32b5dfbfada9714cb2b0ce63458 --- /dev/null +++ b/assets/teaser.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae89c8078e3379126ff0f0723ee4598a99ac7515d8ac70d62f479b243d80b792 +size 1409691 diff --git a/configs/infer.yml b/configs/infer.yml new file mode 100755 index 0000000000000000000000000000000000000000..ed66f9fc2f70e9e9c31697d04e7b4d735ca91897 --- /dev/null +++ b/configs/infer.yml @@ -0,0 +1,52 @@ +dataset: + name: base + pc_dir: ./data/test_pc + bs_dir: data/basic_shapes_norm + max_length: 144 + range_scale: [0, 1] + range_rotation: [-180, 180] + range_translation: [-1, 1] + rotation_type: euler + pc_format: pn +model: + attn_depth: 6 + attn_heads: 6 + bin_smooth_blur_sigma: -1 + bs_pc_dir: data/basic_shapes_norm_pc10000 + coarse_pre_gateloop_depth: 3 + continuous_range_rotation: + - -181 + - 181 + continuous_range_scale: + - 0 + - 1 + continuous_range_translation: + - -1 + - 1 + dim: 768 + dim_rotation_embed: 16 + dim_scale_embed: 16 + dim_translation_embed: 16 + dim_type_embed: 48 + dropout: 0.0 + embed_order: ctrs + gateloop_use_heinsen: false + loss_weight: + eos: 1.0 + reconstruction: 1.0 + rotation: 1.0 + scale: 1.0 + translation: 1.0 + type: 1.0 + max_primitive_len: 144 + name: discrete + num_discrete_rotation: 181 + num_discrete_scale: 128 + num_discrete_translation: 128 + num_type: 3 + shape_cond_with_cat: true + shape_cond_with_cross_attn: false + shape_cond_with_film: false + shape_condition_dim: 768 + shape_condition_len: 77 + shape_condition_model_type: michelangelo diff --git a/data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply b/data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply new file mode 100755 index 0000000000000000000000000000000000000000..2e99e2df88f8c307d453a203469b3280264a623e Binary files /dev/null and b/data/basic_shapes_norm/SM_GR_BS_CubeBevel_001.ply differ diff --git a/data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply b/data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply new file mode 100755 index 0000000000000000000000000000000000000000..476c24e0059d72abc73340731ee20d39e415ae8a Binary files /dev/null and b/data/basic_shapes_norm/SM_GR_BS_CylinderSharp_001.ply differ diff --git a/data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply b/data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply new file mode 100755 index 0000000000000000000000000000000000000000..60573951583ce85eb74da84f597dc461a377fc2b Binary files /dev/null and 
b/data/basic_shapes_norm/SM_GR_BS_SphereSharp_001.ply differ diff --git a/data/basic_shapes_norm/basic_shapes.json b/data/basic_shapes_norm/basic_shapes.json new file mode 100755 index 0000000000000000000000000000000000000000..5b75b9524a1691be3ef6a316e3d9b48f13d883ac --- /dev/null +++ b/data/basic_shapes_norm/basic_shapes.json @@ -0,0 +1,89 @@ +{ + "SM_GR_BS_CubeBevel_001.ply": { + "name": "SM_GR_BS_CubeBevel_001.ply", + "tform_bs_to_normalized": [ + [ + 0.02, + 0.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.02, + 0.0, + 9.701276818911235e-18 + ], + [ + 0.0, + 0.0, + 0.019999999999999997, + -0.9999999999999999 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ] + }, + "SM_GR_BS_CylinderSharp_001.ply": { + "name": "SM_GR_BS_CylinderSharp_001.ply", + "tform_bs_to_normalized": [ + [ + 0.006666668023003748, + 0.0, + 0.0, + -2.0345056221459462e-07 + ], + [ + 0.0, + 0.006666667683919426, + 0.0, + -5.086263794939386e-08 + ], + [ + 0.0, + 0.0, + 0.006666665445429783, + -0.9999998370794186 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ] + }, + "SM_GR_BS_SphereSharp_001.ply": { + "name": "SM_GR_BS_SphereSharp_001.ply", + "tform_bs_to_normalized": [ + [ + 0.006666666666666667, + 0.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.006666666666666667, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 0.006666666666666667, + -1.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ] + } +} \ No newline at end of file diff --git a/data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply b/data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply new file mode 100755 index 0000000000000000000000000000000000000000..ed843a92d125bb018c9cbc32b482f408d1ba77cd --- /dev/null +++ b/data/basic_shapes_norm_pc10000/SM_GR_BS_CubeBevel_001.ply @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba980c1fb389e30783f09b07d35e788e08a97776d933b8bfd346147c9a7e86a0 +size 510265 diff --git a/data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply b/data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply new file mode 100755 index 0000000000000000000000000000000000000000..456003e2f9f3b8cc9ef1910f3a4f96d7008def82 --- /dev/null +++ b/data/basic_shapes_norm_pc10000/SM_GR_BS_CylinderSharp_001.ply @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab8fb7aa7ec39237474d0a6e77da1d7070742f61af21e6b44dc9998fac1913cc +size 510265 diff --git a/data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply b/data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply new file mode 100755 index 0000000000000000000000000000000000000000..71391e250a85ddbf2422bd8f9a8a70c748b83a2f --- /dev/null +++ b/data/basic_shapes_norm_pc10000/SM_GR_BS_SphereSharp_001.ply @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8765da7294292422d077267c1b71b9ea055f831aab3840d869656632ee6e8569 +size 510265 diff --git a/data/demo_glb/barbell.glb b/data/demo_glb/barbell.glb new file mode 100644 index 0000000000000000000000000000000000000000..f7bcceeb9c48286a53daf3d2ffb694cdd60ec61c --- /dev/null +++ b/data/demo_glb/barbell.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a9b9c124c321d6d18342b12407fc7327bdd56c8720d317e7b8694c10c851936 +size 769528 diff --git a/data/demo_glb/book.glb b/data/demo_glb/book.glb new file mode 100644 index 0000000000000000000000000000000000000000..39aebb161e2a34cd91292e73b6a288c371e8faf9 --- /dev/null +++ b/data/demo_glb/book.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be526e9bca2ce3a74387f2dde7f6a25c9502a7c6d4f9fc671b244d09c18a9d94 +size 5369916 diff --git 
a/data/demo_glb/bunny.glb b/data/demo_glb/bunny.glb new file mode 100644 index 0000000000000000000000000000000000000000..32eed742804f7e01336de502662de217aabd14af --- /dev/null +++ b/data/demo_glb/bunny.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f62b2169b7cda3662de660d0b2e8ce2a1ccfe3bd186f243d890440e8cf7a0766 +size 27518016 diff --git a/data/demo_glb/desk.glb b/data/demo_glb/desk.glb new file mode 100644 index 0000000000000000000000000000000000000000..4b7f7d520994fad4e1a15648df9619e45e41c90a --- /dev/null +++ b/data/demo_glb/desk.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c8fd8041f1e870ba285572f3fb3e129107678ab9a311524a8376cc404cc332e +size 33679548 diff --git a/data/demo_glb/man.glb b/data/demo_glb/man.glb new file mode 100644 index 0000000000000000000000000000000000000000..0eae899d243a946a1189847c25647a099ee29d0c --- /dev/null +++ b/data/demo_glb/man.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:063a56d0a56d3866cf36170bbafad93924fd350882ac0f69e727cb43dc203351 +size 31784 diff --git a/data/demo_glb/micky.glb b/data/demo_glb/micky.glb new file mode 100644 index 0000000000000000000000000000000000000000..8899bdaa4d05d34cbfafdb1ee348886e5f9e3227 --- /dev/null +++ b/data/demo_glb/micky.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc6410c2f2a588c5b064f9255a1c1657a9dff061ae6f7342df693c80eef0c69d +size 294576 diff --git a/data/demo_glb/pac.glb b/data/demo_glb/pac.glb new file mode 100644 index 0000000000000000000000000000000000000000..e20f2569c0ca87b59ab1af572d1b7dbaee2bef20 --- /dev/null +++ b/data/demo_glb/pac.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc39816cc71440fbc31d99c24f24ed42c25daf7319ba07b8dd3e34c1ea083578 +size 274004 diff --git a/data/demo_glb/robot.glb b/data/demo_glb/robot.glb new file mode 100644 index 0000000000000000000000000000000000000000..5123b056c244fb87e99e29ca90b53dcc2e3f7286 --- /dev/null +++ b/data/demo_glb/robot.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf201fbe21d73428f88e4a7e428849148ebd500cc4ce6ac3929638a53c5376ae +size 28116940 diff --git a/data/demo_glb/rocket.glb b/data/demo_glb/rocket.glb new file mode 100644 index 0000000000000000000000000000000000000000..6bbf1751b7fbe70a5ff1da2ee7dc8d64a0cc802b --- /dev/null +++ b/data/demo_glb/rocket.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29601d218e0d51a4dced2f1ed2a898e80f0d223b1e04212d45b0dda4ad670d1c +size 1426588 diff --git a/data/demo_glb/sheep.glb b/data/demo_glb/sheep.glb new file mode 100644 index 0000000000000000000000000000000000000000..ed3ee76e49280d2f949892f390c7583698566da8 --- /dev/null +++ b/data/demo_glb/sheep.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:232b1303e56ec1682536c72bc9409585930985492dcbdfa101cdfb96d0b4fbf2 +size 28732 diff --git a/data/demo_glb/shelf.glb b/data/demo_glb/shelf.glb new file mode 100644 index 0000000000000000000000000000000000000000..b19d7f8fbf26bdd68d461083092d32637be56bf7 --- /dev/null +++ b/data/demo_glb/shelf.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d63916c1ef1d5b2fc4d56e20c76316bd977f93819323f73e7f5e1c59df21e284 +size 3091336 diff --git a/data/demo_glb/table.glb b/data/demo_glb/table.glb new file mode 100644 index 0000000000000000000000000000000000000000..787cb3b23bcc7269dab0f31f1e285e87626bd7ea --- /dev/null +++ b/data/demo_glb/table.glb @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:743c96d7aa1bef88576c5f28d5e06144f93e27a2e6ea5ef8bd85669d1213af9f +size 20093692 diff --git a/data/demo_glb/vent.glb b/data/demo_glb/vent.glb new file mode 100644 index 0000000000000000000000000000000000000000..bc5b1cd68ac6440803f1f0b27bcd201fac404c4e --- /dev/null +++ b/data/demo_glb/vent.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98d8d0d2c8d164fc75361d7d1408e9102f9148482c0343e0cc87d21950e20ab1 +size 1785468 diff --git a/data/demo_glb/walkman.glb b/data/demo_glb/walkman.glb new file mode 100644 index 0000000000000000000000000000000000000000..cb8b2c47926b5d02d16940382390846c284d7d88 --- /dev/null +++ b/data/demo_glb/walkman.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:992355cf45609881223561d1081a05483c2bad488ed7148f87243259aa36be1b +size 158156 diff --git a/pre-requirements.txt b/pre-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ec53d12d36eadd28748766a73b18453bcc261f8 --- /dev/null +++ b/pre-requirements.txt @@ -0,0 +1,36 @@ +--extra-index-url https://download.pytorch.org/whl/cu121 +--extra-index-url https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html +torch==2.2.0 +torchvision==0.17.0 +dgl +accelerate +beartype +einops +gateloop_transformer +matplotlib +scikit-learn +pandas +pytorch_custom_utils +gradio +pydantic==2.10.6 +x_transformers +torch_redstone +torchdata==0.9.0 +toolz +environs +jaxtyping +omegaconf +ema_pytorch +local_attention==1.9.15 +taylor_series_linear_attention +transformers +vector_quantize_pytorch +open3d +trimesh +pytorch_lightning +scikit-image +opencv-python +mesh2sdf +seaborn +mesh_to_sdf +point_cloud_utils diff --git a/primitive_anything/__init__.py b/primitive_anything/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/primitive_anything/michelangelo/__init__.py b/primitive_anything/michelangelo/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..7970340f74f682096f3b8b24ea1375c0db6e2690 --- /dev/null +++ b/primitive_anything/michelangelo/__init__.py @@ -0,0 +1,51 @@ +import os + +from omegaconf import OmegaConf +import torch +from torch import nn + +from .utils.misc import instantiate_from_config +from ..utils import default, exists + + +def load_model(): + model_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), "shapevae-256.yaml")) + # print(model_config) + if hasattr(model_config, "model"): + model_config = model_config.model + ckpt_path = "./ckpt/shapevae-256.ckpt" + + model = instantiate_from_config(model_config, ckpt_path=ckpt_path) + # model = model.cuda() + model = model.eval() + + return model + + +class ShapeConditioner(nn.Module): + def __init__( + self, + *, + dim_latent = None + ): + super().__init__() + self.model = load_model() + + self.dim_model_out = 768 + dim_latent = default(dim_latent, self.dim_model_out) + self.dim_latent = dim_latent + + def forward( + self, + shape = None, + shape_embed = None, + ): + assert exists(shape) ^ exists(shape_embed) + + if not exists(shape_embed): + point_feature = self.model.encode_latents(shape) + shape_latents = self.model.to_shape_latents(point_feature[:, 1:]) + shape_head = point_feature[:, 0:1] + shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1) + # shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-2) # cat tmp + return shape_head, shape_embed \ No newline at end of file diff --git 
a/primitive_anything/michelangelo/data/__init__.py b/primitive_anything/michelangelo/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/primitive_anything/michelangelo/data/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/primitive_anything/michelangelo/data/templates.json b/primitive_anything/michelangelo/data/templates.json new file mode 100755 index 0000000000000000000000000000000000000000..f1a355f6bb61a98b4e229e7aece7b0035c3b0592 --- /dev/null +++ b/primitive_anything/michelangelo/data/templates.json @@ -0,0 +1,69 @@ +{ + "shape": [ + "a point cloud model of {}.", + "There is a {} in the scene.", + "There is the {} in the scene.", + "a photo of a {} in the scene.", + "a photo of the {} in the scene.", + "a photo of one {} in the scene.", + "itap of a {}.", + "itap of my {}.", + "itap of the {}.", + "a photo of a {}.", + "a photo of my {}.", + "a photo of the {}.", + "a photo of one {}.", + "a photo of many {}.", + "a good photo of a {}.", + "a good photo of the {}.", + "a bad photo of a {}.", + "a bad photo of the {}.", + "a photo of a nice {}.", + "a photo of the nice {}.", + "a photo of a cool {}.", + "a photo of the cool {}.", + "a photo of a weird {}.", + "a photo of the weird {}.", + "a photo of a small {}.", + "a photo of the small {}.", + "a photo of a large {}.", + "a photo of the large {}.", + "a photo of a clean {}.", + "a photo of the clean {}.", + "a photo of a dirty {}.", + "a photo of the dirty {}.", + "a bright photo of a {}.", + "a bright photo of the {}.", + "a dark photo of a {}.", + "a dark photo of the {}.", + "a photo of a hard to see {}.", + "a photo of the hard to see {}.", + "a low resolution photo of a {}.", + "a low resolution photo of the {}.", + "a cropped photo of a {}.", + "a cropped photo of the {}.", + "a close-up photo of a {}.", + "a close-up photo of the {}.", + "a jpeg corrupted photo of a {}.", + "a jpeg corrupted photo of the {}.", + "a blurry photo of a {}.", + "a blurry photo of the {}.", + "a pixelated photo of a {}.", + "a pixelated photo of the {}.", + "a black and white photo of the {}.", + "a black and white photo of a {}", + "a plastic {}.", + "the plastic {}.", + "a toy {}.", + "the toy {}.", + "a plushie {}.", + "the plushie {}.", + "a cartoon {}.", + "the cartoon {}.", + "an embroidered {}.", + "the embroidered {}.", + "a painting of the {}.", + "a painting of a {}." 
+ ] + +} \ No newline at end of file diff --git a/primitive_anything/michelangelo/data/transforms.py b/primitive_anything/michelangelo/data/transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..5f7034b9c6b86e5151be264ac28b0b50961cab77 --- /dev/null +++ b/primitive_anything/michelangelo/data/transforms.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 -*- +import os +import time +import numpy as np +import warnings +import random +from omegaconf.listconfig import ListConfig +from webdataset import pipelinefilter +import torch +import torchvision.transforms.functional as TVF +from torchvision.transforms import InterpolationMode +from torchvision.transforms.transforms import _interpolation_modes_from_int +from typing import Sequence + +from ..utils import instantiate_from_config + + +def _uid_buffer_pick(buf_dict, rng): + uid_keys = list(buf_dict.keys()) + selected_uid = rng.choice(uid_keys) + buf = buf_dict[selected_uid] + + k = rng.randint(0, len(buf) - 1) + sample = buf[k] + buf[k] = buf[-1] + buf.pop() + + if len(buf) == 0: + del buf_dict[selected_uid] + + return sample + + +def _add_to_buf_dict(buf_dict, sample): + key = sample["__key__"] + uid, uid_sample_id = key.split("_") + if uid not in buf_dict: + buf_dict[uid] = [] + buf_dict[uid].append(sample) + + return buf_dict + + +def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): + """Shuffle the data in the stream. + + This uses a buffer of size `bufsize`. Shuffling at + startup is less random; this is traded off against + yielding samples quickly. + + data: iterator + bufsize: buffer size for shuffling + returns: iterator + rng: either random module or random.Random instance + + """ + if rng is None: + rng = random.Random(int((os.getpid() + time.time()) * 1e9)) + initial = min(initial, bufsize) + buf_dict = dict() + current_samples = 0 + for sample in data: + _add_to_buf_dict(buf_dict, sample) + current_samples += 1 + + if current_samples < bufsize: + try: + _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708 + current_samples += 1 + except StopIteration: + pass + + if current_samples >= initial: + current_samples -= 1 + yield _uid_buffer_pick(buf_dict, rng) + + while current_samples > 0: + current_samples -= 1 + yield _uid_buffer_pick(buf_dict, rng) + + +uid_shuffle = pipelinefilter(_uid_shuffle) + + +class RandomSample(object): + def __init__(self, + num_volume_samples: int = 1024, + num_near_samples: int = 1024): + + super().__init__() + + self.num_volume_samples = num_volume_samples + self.num_near_samples = num_near_samples + + def __call__(self, sample): + rng = np.random.default_rng() + + # 1. sample surface input + total_surface = sample["surface"] + ind = rng.choice(total_surface.shape[0], replace=False) + surface = total_surface[ind] + + # 2. 
sample volume/near geometric points + vol_points = sample["vol_points"] + vol_label = sample["vol_label"] + near_points = sample["near_points"] + near_label = sample["near_label"] + + ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) + vol_points = vol_points[ind] + vol_label = vol_label[ind] + vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) + + ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) + near_points = near_points[ind] + near_label = near_label[ind] + near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) + + # concat sampled volume and near points + geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) + + sample = { + "surface": surface, + "geo_points": geo_points + } + + return sample + + +class SplitRandomSample(object): + def __init__(self, + use_surface_sample: bool = False, + num_surface_samples: int = 4096, + num_volume_samples: int = 1024, + num_near_samples: int = 1024): + + super().__init__() + + self.use_surface_sample = use_surface_sample + self.num_surface_samples = num_surface_samples + self.num_volume_samples = num_volume_samples + self.num_near_samples = num_near_samples + + def __call__(self, sample): + + rng = np.random.default_rng() + + # 1. sample surface input + surface = sample["surface"] + + if self.use_surface_sample: + replace = surface.shape[0] < self.num_surface_samples + ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace) + surface = surface[ind] + + # 2. sample volume/near geometric points + vol_points = sample["vol_points"] + vol_label = sample["vol_label"] + near_points = sample["near_points"] + near_label = sample["near_label"] + + ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) + vol_points = vol_points[ind] + vol_label = vol_label[ind] + vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) + + ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) + near_points = near_points[ind] + near_label = near_label[ind] + near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) + + # concat sampled volume and near points + geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) + + sample = { + "surface": surface, + "geo_points": geo_points + } + + return sample + + +class FeatureSelection(object): + + VALID_SURFACE_FEATURE_DIMS = { + "none": [0, 1, 2], # xyz + "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal + "normal": [0, 1, 2, 6, 7, 8] + } + + def __init__(self, surface_feature_type: str): + + self.surface_feature_type = surface_feature_type + self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type] + + def __call__(self, sample): + sample["surface"] = sample["surface"][:, self.surface_dims] + return sample + + +class AxisScaleTransform(object): + def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): + assert isinstance(interval, (tuple, list, ListConfig)) + self.interval = interval + self.min_val = interval[0] + self.max_val = interval[1] + self.inter_size = interval[1] - interval[0] + self.jitter = jitter + self.jitter_scale = jitter_scale + + def __call__(self, sample): + + surface = sample["surface"][..., 0:3] + geo_points = sample["geo_points"][..., 0:3] + + scaling = torch.rand(1, 3) * self.inter_size + self.min_val + # print(scaling) + surface = surface * scaling + geo_points = geo_points * scaling + + 
scale = (1 / torch.abs(surface).max().item()) * 0.999999 + surface *= scale + geo_points *= scale + + if self.jitter: + surface += self.jitter_scale * torch.randn_like(surface) + surface.clamp_(min=-1.015, max=1.015) + + sample["surface"][..., 0:3] = surface + sample["geo_points"][..., 0:3] = geo_points + + return sample + + +class ToTensor(object): + + def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")): + self.tensor_keys = tensor_keys + + def __call__(self, sample): + for key in self.tensor_keys: + if key not in sample: + continue + + sample[key] = torch.tensor(sample[key], dtype=torch.float32) + + return sample + + +class AxisScale(object): + def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): + assert isinstance(interval, (tuple, list, ListConfig)) + self.interval = interval + self.jitter = jitter + self.jitter_scale = jitter_scale + + def __call__(self, surface, *args): + scaling = torch.rand(1, 3) * 0.5 + 0.75 + # print(scaling) + surface = surface * scaling + scale = (1 / torch.abs(surface).max().item()) * 0.999999 + surface *= scale + + args_outputs = [] + for _arg in args: + _arg = _arg * scaling * scale + args_outputs.append(_arg) + + if self.jitter: + surface += self.jitter_scale * torch.randn_like(surface) + surface.clamp_(min=-1, max=1) + + if len(args) == 0: + return surface + else: + return surface, *args_outputs + + +class RandomResize(torch.nn.Module): + """Apply randomly Resize with a given probability.""" + + def __init__( + self, + size, + resize_radio=(0.5, 1), + allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR), + interpolation=InterpolationMode.BICUBIC, + max_size=None, + antialias=None, + ): + super().__init__() + if not isinstance(size, (int, Sequence)): + raise TypeError(f"Size should be int or sequence. Got {type(size)}") + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError("If size is a sequence, it should have 1 or 2 values") + + self.size = size + self.max_size = max_size + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " + "Please use InterpolationMode enum." 
+ ) + interpolation = _interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + + self.resize_radio = resize_radio + self.allow_resize_interpolations = allow_resize_interpolations + + def random_resize_params(self): + radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0] + + if isinstance(self.size, int): + size = int(self.size * radio) + elif isinstance(self.size, Sequence): + size = list(self.size) + size = (int(size[0] * radio), int(size[1] * radio)) + else: + raise RuntimeError() + + interpolation = self.allow_resize_interpolations[ + torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)) + ] + return size, interpolation + + def forward(self, img): + size, interpolation = self.random_resize_params() + img = TVF.resize(img, size, interpolation, self.max_size, self.antialias) + img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + return img + + def __repr__(self) -> str: + detail = f"(size={self.size}, interpolation={self.interpolation.value}," + detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}" + return f"{self.__class__.__name__}{detail}" + + +class Compose(object): + """Composes several transforms together. This transform does not support torchscript. + Please, see the note below. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. 
+ + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, *args): + for t in self.transforms: + args = t(*args) + return args + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +def identity(*args, **kwargs): + if len(args) == 1: + return args[0] + else: + return args + + +def build_transforms(cfg): + + if cfg is None: + return identity + + transforms = [] + + for transform_name, cfg_instance in cfg.items(): + transform_instance = instantiate_from_config(cfg_instance) + transforms.append(transform_instance) + print(f"Build transform: {transform_instance}") + + transforms = Compose(transforms) + + return transforms + diff --git a/primitive_anything/michelangelo/data/utils.py b/primitive_anything/michelangelo/data/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..af06ed0c8849819a5d2b72ece805e8ec26079ea9 --- /dev/null +++ b/primitive_anything/michelangelo/data/utils.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +import torch +import numpy as np + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id + + # dataset = worker_info.dataset + # split_size = dataset.num_records // worker_info.num_workers + # # reset num_records to the true number to retain reliable length information + # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + # current_id = np.random.choice(len(np.random.get_state()[1]), 1) + # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +def collation_fn(samples, combine_tensors=True, combine_scalars=True): + """ + + Args: + samples (list[dict]): + combine_tensors: + combine_scalars: + + Returns: + + """ + + result = {} + + keys = samples[0].keys() + + for key in keys: + result[key] = [] + + for sample in samples: + for key in keys: + val = sample[key] + result[key].append(val) + + for key in keys: + val_list = result[key] + if isinstance(val_list[0], (int, float)): + if combine_scalars: + result[key] = np.array(result[key]) + + elif isinstance(val_list[0], torch.Tensor): + if combine_tensors: + result[key] = torch.stack(val_list) + + elif isinstance(val_list[0], np.ndarray): + if combine_tensors: + result[key] = np.stack(val_list) + + return result diff --git a/primitive_anything/michelangelo/graphics/__init__.py b/primitive_anything/michelangelo/graphics/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/primitive_anything/michelangelo/graphics/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/primitive_anything/michelangelo/graphics/primitives/__init__.py b/primitive_anything/michelangelo/graphics/primitives/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..cb910878f98a83209a41b562d339d12d39f42e89 --- /dev/null +++ b/primitive_anything/michelangelo/graphics/primitives/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .volume import generate_dense_grid_points + +from .mesh import ( + MeshOutput, + save_obj, + savemeshtes2 +) diff --git a/primitive_anything/michelangelo/graphics/primitives/mesh.py b/primitive_anything/michelangelo/graphics/primitives/mesh.py new file mode 100755 index 
0000000000000000000000000000000000000000..3e5e8a551378b8e86d041967736cacaf904dbf54 --- /dev/null +++ b/primitive_anything/michelangelo/graphics/primitives/mesh.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +import os +import cv2 +import numpy as np +import PIL.Image +from typing import Optional + +import trimesh + + +def save_obj(pointnp_px3, facenp_fx3, fname): + fid = open(fname, "w") + write_str = "" + for pidx, p in enumerate(pointnp_px3): + pp = p + write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) + + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) + fid.write(write_str) + fid.close() + return + + +def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = "%s/%s.mtl" % (fol, na) + fid = open(matname, "w") + fid.write("newmtl material_0\n") + fid.write("Kd 1 1 1\n") + fid.write("Ka 0 0 0\n") + fid.write("Ks 0.4 0.4 0.4\n") + fid.write("Ns 10\n") + fid.write("illum 2\n") + fid.write("map_Kd %s.png\n" % na) + fid.close() + #### + + fid = open(fname, "w") + fid.write("mtllib %s.mtl\n" % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write("vt %f %f\n" % (pp[0], pp[1])) + + fid.write("usemtl material_0\n") + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( + os.path.join(fol, "%s.png" % na)) + + return + + +class MeshOutput(object): + + def __init__(self, + mesh_v: np.ndarray, + mesh_f: np.ndarray, + vertex_colors: Optional[np.ndarray] = None, + uvs: Optional[np.ndarray] = None, + mesh_tex_idx: Optional[np.ndarray] = None, + tex_map: Optional[np.ndarray] = None): + + self.mesh_v = mesh_v + self.mesh_f = mesh_f + self.vertex_colors = vertex_colors + self.uvs = uvs + self.mesh_tex_idx = mesh_tex_idx + self.tex_map = tex_map + + def contain_uv_texture(self): + return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) + + def contain_vertex_colors(self): + return self.vertex_colors is not None + + def export(self, fname): + + if self.contain_uv_texture(): + savemeshtes2( + self.mesh_v, + self.uvs, + self.mesh_f, + self.mesh_tex_idx, + self.tex_map, + fname + ) + + elif self.contain_vertex_colors(): + mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) + mesh_obj.export(fname) + + else: + save_obj( + self.mesh_v, + self.mesh_f, + fname + ) + + + diff --git a/primitive_anything/michelangelo/graphics/primitives/volume.py b/primitive_anything/michelangelo/graphics/primitives/volume.py new file mode 100755 index 0000000000000000000000000000000000000000..e8cb1d3f41fd00d18af5a6c751d49c68770fe04a --- /dev/null +++ b/primitive_anything/michelangelo/graphics/primitives/volume.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +import numpy as np + + +def generate_dense_grid_points(bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_depth: int, + indexing: str = "ij"): + length = bbox_max - bbox_min + num_cells = np.exp2(octree_depth) + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 
1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + xyz = xyz.reshape(-1, 3) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + diff --git a/primitive_anything/michelangelo/models/__init__.py b/primitive_anything/michelangelo/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/primitive_anything/michelangelo/models/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/primitive_anything/michelangelo/models/asl_diffusion/__init__.py b/primitive_anything/michelangelo/models/asl_diffusion/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/primitive_anything/michelangelo/models/asl_diffusion/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py b/primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py new file mode 100755 index 0000000000000000000000000000000000000000..73727356791812cb74bcc4610c345617ef48f04a --- /dev/null +++ b/primitive_anything/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py @@ -0,0 +1,483 @@ +# -*- coding: utf-8 -*- + +from omegaconf import DictConfig +from typing import List, Tuple, Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only + +from einops import rearrange + +from diffusers.schedulers import ( + DDPMScheduler, + DDIMScheduler, + KarrasVeScheduler, + DPMSolverMultistepScheduler +) + +from ...utils import instantiate_from_config +# from ..tsal.tsal_base import ShapeAsLatentPLModule +from ..tsal.tsal_base import AlignedShapeAsLatentPLModule +from .inference_utils import ddim_sample + +SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class ASLDiffuser(pl.LightningModule): + first_stage_model: Optional[AlignedShapeAsLatentPLModule] + # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] + model: nn.Module + + def __init__(self, *, + first_stage_config, + denoiser_cfg, + scheduler_cfg, + optimizer_cfg, + loss_cfg, + first_stage_key: str = "surface", + cond_stage_key: str = "image", + cond_stage_trainable: bool = True, + scale_by_std: bool = False, + z_scale_factor: float = 1.0, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + self.cond_stage_trainable = cond_stage_trainable + + # 1. initialize first stage. + # Note: the condition model contained in the first stage model. + self.first_stage_config = first_stage_config + self.first_stage_model = None + # self.instantiate_first_stage(first_stage_config) + + # 2. 
initialize conditional stage + # self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_model = { + "image": self.encode_image, + "image_unconditional_embedding": self.empty_img_cond, + "text": self.encode_text, + "text_unconditional_embedding": self.empty_text_cond, + "surface": self.encode_surface, + "surface_unconditional_embedding": self.empty_surface_cond, + } + + # 3. diffusion model + self.model = instantiate_from_config( + denoiser_cfg, device=None, dtype=None + ) + + self.optimizer_cfg = optimizer_cfg + + # 4. scheduling strategy + self.scheduler_cfg = scheduler_cfg + + self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) + self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) + + # 5. loss configures + self.loss_cfg = loss_cfg + + self.scale_by_std = scale_by_std + if scale_by_std: + self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) + else: + self.z_scale_factor = z_scale_factor + + self.ckpt_path = ckpt_path + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + self.first_stage_model = self.first_stage_model.to(self.device) + + # def instantiate_cond_stage(self, config): + # if not self.cond_stage_trainable: + # if config == "__is_first_stage__": + # print("Using first stage also as cond stage.") + # self.cond_stage_model = self.first_stage_model + # elif config == "__is_unconditional__": + # print(f"Training {self.__class__.__name__} as an unconditional model.") + # self.cond_stage_model = None + # # self.be_unconditional = True + # else: + # model = instantiate_from_config(config) + # self.cond_stage_model = model.eval() + # self.cond_stage_model.train = disabled_train + # for param in self.cond_stage_model.parameters(): + # param.requires_grad = False + # else: + # assert config != "__is_first_stage__" + # assert config != "__is_unconditional__" + # model = instantiate_from_config(config) + # self.cond_stage_model = model + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def configure_optimizers(self) -> Tuple[List, List]: + + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + # if the conditional encoder is trainable + + # if self.cond_stage_trainable: + # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad] + # trainable_parameters += conditioner_params + # print(f"number of trainable conditional parameters: {len(conditioner_params)}.") + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + 
schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + @torch.no_grad() + def encode_text(self, text): + + b = text.shape[0] + text_tokens = rearrange(text, "b t l -> (b t) l") + text_embed = self.first_stage_model.model.encode_text_embed(text_tokens) + text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) + text_embed = text_embed.mean(dim=1) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + + return text_embed + + @torch.no_grad() + def encode_image(self, img): + + return self.first_stage_model.model.encode_image_embed(img) + + @torch.no_grad() + def encode_surface(self, surface): + + return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False) + + @torch.no_grad() + def empty_text_cond(self, cond): + + return torch.zeros_like(cond, device=cond.device) + + @torch.no_grad() + def empty_img_cond(self, cond): + + return torch.zeros_like(cond, device=cond.device) + + @torch.no_grad() + def empty_surface_cond(self, cond): + + return torch.zeros_like(cond, device=cond.device) + + @torch.no_grad() + def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): + + z_q = self.first_stage_model.encode(surface, sample_posterior) + z_q = self.z_scale_factor * z_q + + return z_q + + @torch.no_grad() + def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): + + z_q = 1. / self.z_scale_factor * z_q + latents = self.first_stage_model.decode(z_q, **kwargs) + return latents + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ + and batch_idx == 0 and self.ckpt_path is None: + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + + z_q = self.encode_first_stage(batch[self.first_stage_key]) + z = z_q.detach() + + del self.z_scale_factor + self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) + print(f"setting self.z_scale_factor to {self.z_scale_factor}") + + print("### USING STD-RESCALING ###") + + def compute_loss(self, model_outputs, split): + """ + + Args: + model_outputs (dict): + - x_0: + - noise: + - noise_prior: + - noise_pred: + - noise_pred_prior: + + split (str): + + Returns: + + """ + + pred = model_outputs["pred"] + + if self.noise_scheduler.prediction_type == "epsilon": + target = model_outputs["noise"] + elif self.noise_scheduler.prediction_type == "sample": + target = model_outputs["x_0"] + else: + raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") + + if self.loss_cfg.loss_type == "l1": + simple = F.l1_loss(pred, target, reduction="mean") + elif self.loss_cfg.loss_type in ["mse", "l2"]: + simple = F.mse_loss(pred, target, reduction="mean") + else: + raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") + + total_loss = simple + + loss_dict = { + f"{split}/total_loss": total_loss.clone().detach(), + f"{split}/simple": simple.detach(), + } + + return total_loss, loss_dict + + def forward(self, batch): + """ + + Args: + batch: + + Returns: + + """ + + if self.first_stage_model is None: + self.instantiate_first_stage(self.first_stage_config) + + latents = self.encode_first_stage(batch[self.first_stage_key]) + + # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) + + conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1) + + mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1 + conditions = conditions * mask.to(conditions) + + # Sample noise that we"ll add to the latents + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bs = latents.shape[0] + # Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bs,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # diffusion model forward + noise_pred = self.model(noisy_z, timesteps, conditions) + + diffusion_outputs = { + "x_0": noisy_z, + "noise": noise, + "pred": noise_pred + } + + return diffusion_outputs + + def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): + - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] + - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] + - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] + - text (list of str): + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "train") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface_pc (torch.FloatTensor): [n_pts, 4] + - surface_feats (torch.FloatTensor): [n_pts, c] + - text (list of str): + + batch_idx (int): + + 
optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "val") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + @torch.no_grad() + def sample(self, + batch: Dict[str, Union[torch.FloatTensor, List[str]]], + sample_times: int = 1, + steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + eta: float = 0.0, + return_intermediates: bool = False, **kwargs): + + if self.first_stage_model is None: + self.instantiate_first_stage(self.first_stage_config) + + if steps is None: + steps = self.scheduler_cfg.num_inference_steps + + if guidance_scale is None: + guidance_scale = self.scheduler_cfg.guidance_scale + do_classifier_free_guidance = guidance_scale > 0 + + # conditional encode + xc = batch[self.cond_stage_key] + # cond = self.cond_stage_model[self.cond_stage_key](xc) + cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1) + + if do_classifier_free_guidance: + """ + Note: There are two kinds of uncond for text. + 1: using "" as uncond text; (in SAL diffusion) + 2: zeros_like(cond) as uncond text; (in MDM) + """ + # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) + un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond) + # un_cond = torch.zeros_like(cond, device=cond.device) + cond = torch.cat([un_cond, cond], dim=0) + + outputs = [] + latents = None + + if not return_intermediates: + for _ in range(sample_times): + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + for sample, t in sample_loop: + latents = sample + outputs.append(self.decode_first_stage(latents, **kwargs)) + else: + + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + + iter_size = steps // sample_times + i = 0 + for sample, t in sample_loop: + latents = sample + if i % iter_size == 0 or i == steps - 1: + outputs.append(self.decode_first_stage(latents, **kwargs)) + i += 1 + + return outputs diff --git a/primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py b/primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py new file mode 100755 index 0000000000000000000000000000000000000000..a89dbc544a56004a4604fae9992d4134c274d392 --- /dev/null +++ b/primitive_anything/michelangelo/models/asl_diffusion/asl_udt.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from typing import Optional +from diffusers.models.embeddings import Timesteps +import math + +from ..modules.transformer_blocks import MLP +from ..modules.diffusion_transformer import UNetDiffusionTransformer + + +class ConditionalASLUDTDenoiser(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + input_channels: int, + output_channels: int, + n_ctx: int, + width: int, + layers: int, + heads: int, + context_dim: int, + context_ln: bool = True, + skip_ln: bool = False, + init_scale: float = 0.25, + flip_sin_to_cos: bool = 
False,
+                 use_checkpoint: bool = False):
+        super().__init__()
+
+        self.use_checkpoint = use_checkpoint
+
+        init_scale = init_scale * math.sqrt(1.0 / width)
+
+        self.backbone = UNetDiffusionTransformer(
+            device=device,
+            dtype=dtype,
+            n_ctx=n_ctx,
+            width=width,
+            layers=layers,
+            heads=heads,
+            skip_ln=skip_ln,
+            init_scale=init_scale,
+            use_checkpoint=use_checkpoint
+        )
+        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
+        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
+
+        # timestep embedding
+        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
+        self.time_proj = MLP(
+            device=device, dtype=dtype, width=width, init_scale=init_scale
+        )
+
+        # context projector (LayerNorm + Linear, or a plain Linear when context_ln is False)
+        if context_ln:
+            self.context_embed = nn.Sequential(
+                nn.LayerNorm(context_dim, device=device, dtype=dtype),
+                nn.Linear(context_dim, width, device=device, dtype=dtype),
+            )
+        else:
+            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
+
+    def forward(self,
+                model_input: torch.FloatTensor,
+                timestep: torch.LongTensor,
+                context: torch.FloatTensor):
+
+        r"""
+        Args:
+            model_input (torch.FloatTensor): [bs, n_data, c]
+            timestep (torch.LongTensor): [bs,]
+            context (torch.FloatTensor): [bs, context_tokens, c]
+
+        Returns:
+            sample (torch.FloatTensor): [bs, n_data, c]
+
+        """
+
+        _, n_data, _ = model_input.shape
+
+        # 1. time
+        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
+
+        # 2. conditions projector
+        context = self.context_embed(context)
+
+        # 3. 
denoiser + x = self.input_proj(model_input) + x = torch.cat([t_emb, context, x], dim=1) + x = self.backbone(x) + x = self.ln_post(x) + x = x[:, -n_data:] + sample = self.output_proj(x) + + return sample + + diff --git a/primitive_anything/michelangelo/models/asl_diffusion/base.py b/primitive_anything/michelangelo/models/asl_diffusion/base.py new file mode 100755 index 0000000000000000000000000000000000000000..a979197ae9990929aecbca42ce081a2b1aa1f465 --- /dev/null +++ b/primitive_anything/michelangelo/models/asl_diffusion/base.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn + + +class BaseDenoiser(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, t, context): + raise NotImplementedError diff --git a/primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py b/primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py new file mode 100755 index 0000000000000000000000000000000000000000..c8bff3fe2d8ad7129e4a58ba36e10318648eca68 --- /dev/null +++ b/primitive_anything/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py @@ -0,0 +1,393 @@ +# -*- coding: utf-8 -*- + +from omegaconf import DictConfig +from typing import List, Tuple, Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only + +from diffusers.schedulers import ( + DDPMScheduler, + DDIMScheduler, + KarrasVeScheduler, + DPMSolverMultistepScheduler +) + +from ...utils import instantiate_from_config +from ..tsal.tsal_base import AlignedShapeAsLatentPLModule +from .inference_utils import ddim_sample + +SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class ClipASLDiffuser(pl.LightningModule): + first_stage_model: Optional[AlignedShapeAsLatentPLModule] + cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] + model: nn.Module + + def __init__(self, *, + first_stage_config, + cond_stage_config, + denoiser_cfg, + scheduler_cfg, + optimizer_cfg, + loss_cfg, + first_stage_key: str = "surface", + cond_stage_key: str = "image", + scale_by_std: bool = False, + z_scale_factor: float = 1.0, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + + # 1. lazy initialize first stage + self.instantiate_first_stage(first_stage_config) + + # 2. initialize conditional stage + self.instantiate_cond_stage(cond_stage_config) + + # 3. diffusion model + self.model = instantiate_from_config( + denoiser_cfg, device=None, dtype=None + ) + + self.optimizer_cfg = optimizer_cfg + + # 4. scheduling strategy + self.scheduler_cfg = scheduler_cfg + + self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) + self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) + + # 5. 
loss configures + self.loss_cfg = loss_cfg + + self.scale_by_std = scale_by_std + if scale_by_std: + self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) + else: + self.z_scale_factor = z_scale_factor + + self.ckpt_path = ckpt_path + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def instantiate_non_trainable_model(self, config): + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + + return model + + def instantiate_first_stage(self, first_stage_config): + self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config) + self.first_stage_model.set_shape_model_only() + + def instantiate_cond_stage(self, cond_stage_config): + self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config) + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def configure_optimizers(self) -> Tuple[List, List]: + + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + @torch.no_grad() + def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): + + z_q = self.first_stage_model.encode(surface, sample_posterior) + z_q = self.z_scale_factor * z_q + + return z_q + + @torch.no_grad() + def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): + + z_q = 1. / self.z_scale_factor * z_q + latents = self.first_stage_model.decode(z_q, **kwargs) + return latents + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ + and batch_idx == 0 and self.ckpt_path is None: + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + + z_q = self.encode_first_stage(batch[self.first_stage_key]) + z = z_q.detach() + + del self.z_scale_factor + self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) + print(f"setting self.z_scale_factor to {self.z_scale_factor}") + + print("### USING STD-RESCALING ###") + + def compute_loss(self, model_outputs, split): + """ + + Args: + model_outputs (dict): + - x_0: + - noise: + - noise_prior: + - noise_pred: + - noise_pred_prior: + + split (str): + + Returns: + + """ + + pred = model_outputs["pred"] + + if self.noise_scheduler.prediction_type == "epsilon": + target = model_outputs["noise"] + elif self.noise_scheduler.prediction_type == "sample": + target = model_outputs["x_0"] + else: + raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") + + if self.loss_cfg.loss_type == "l1": + simple = F.l1_loss(pred, target, reduction="mean") + elif self.loss_cfg.loss_type in ["mse", "l2"]: + simple = F.mse_loss(pred, target, reduction="mean") + else: + raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") + + total_loss = simple + + loss_dict = { + f"{split}/total_loss": total_loss.clone().detach(), + f"{split}/simple": simple.detach(), + } + + return total_loss, loss_dict + + def forward(self, batch): + """ + + Args: + batch: + + Returns: + + """ + + latents = self.encode_first_stage(batch[self.first_stage_key]) + conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) + + # Sample noise that we"ll add to the latents + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bs = latents.shape[0] + # Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bs,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # diffusion model forward + noise_pred = self.model(noisy_z, timesteps, conditions) + + diffusion_outputs = { + "x_0": noisy_z, + "noise": noise, + "pred": noise_pred + } + + return diffusion_outputs + + def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): + - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] + - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] + - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] + - text (list of str): + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "train") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface_pc (torch.FloatTensor): [n_pts, 4] + - surface_feats (torch.FloatTensor): [n_pts, c] + - text (list of str): + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "val") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + @torch.no_grad() + def sample(self, + batch: Dict[str, Union[torch.FloatTensor, 
List[str]]], + sample_times: int = 1, + steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + eta: float = 0.0, + return_intermediates: bool = False, **kwargs): + + if steps is None: + steps = self.scheduler_cfg.num_inference_steps + + if guidance_scale is None: + guidance_scale = self.scheduler_cfg.guidance_scale + do_classifier_free_guidance = guidance_scale > 0 + + # conditional encode + xc = batch[self.cond_stage_key] + + # print(self.first_stage_model.device, self.cond_stage_model.device, self.device) + + cond = self.cond_stage_model(xc) + + if do_classifier_free_guidance: + un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) + cond = torch.cat([un_cond, cond], dim=0) + + outputs = [] + latents = None + + if not return_intermediates: + for _ in range(sample_times): + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + for sample, t in sample_loop: + latents = sample + outputs.append(self.decode_first_stage(latents, **kwargs)) + else: + + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + + iter_size = steps // sample_times + i = 0 + for sample, t in sample_loop: + latents = sample + if i % iter_size == 0 or i == steps - 1: + outputs.append(self.decode_first_stage(latents, **kwargs)) + i += 1 + + return outputs diff --git a/primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py b/primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..967d5c52a8e33a6759d1c4891b0d21d1c9f95442 --- /dev/null +++ b/primitive_anything/michelangelo/models/asl_diffusion/inference_utils.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +import torch +from tqdm import tqdm +from typing import Tuple, List, Union, Optional +from diffusers.schedulers import DDIMScheduler + + +__all__ = ["ddim_sample"] + + +def ddim_sample(ddim_scheduler: DDIMScheduler, + diffusion_model: torch.nn.Module, + shape: Union[List[int], Tuple[int]], + cond: torch.FloatTensor, + steps: int, + eta: float = 0.0, + guidance_scale: float = 3.0, + do_classifier_free_guidance: bool = True, + generator: Optional[torch.Generator] = None, + device: torch.device = "cuda:0", + disable_prog: bool = True): + + assert steps > 0, f"{steps} must > 0." 
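+    # When classifier-free guidance is enabled, `cond` is expected to already hold the
+    # unconditional and conditional embeddings stacked along the batch dimension
+    # ([uncond; cond]), so the true sample batch size is half of cond.shape[0]; the
+    # initial latents are drawn once and duplicated for both halves at every step.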
+ + # init latents + bsz = cond.shape[0] + if do_classifier_free_guidance: + bsz = bsz // 2 + + latents = torch.randn( + (bsz, *shape), + generator=generator, + device=cond.device, + dtype=cond.dtype, + ) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * ddim_scheduler.init_noise_sigma + # set timesteps + ddim_scheduler.set_timesteps(steps) + timesteps = ddim_scheduler.timesteps.to(device) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = { + "eta": eta, + "generator": generator + } + + # reverse + for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if do_classifier_free_guidance + else latents + ) + # latent_model_input = scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) + timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) + noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + # text_embeddings_for_guidance = encoder_hidden_states.chunk( + # 2)[1] if do_classifier_free_guidance else encoder_hidden_states + # compute the previous noisy sample x_t -> x_t-1 + latents = ddim_scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + yield latents, t + + +def karra_sample(): + pass diff --git a/primitive_anything/michelangelo/models/conditional_encoders/__init__.py b/primitive_anything/michelangelo/models/conditional_encoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..f644ce0eac101dbd60ffdb0225a7560a5dc25735 --- /dev/null +++ b/primitive_anything/michelangelo/models/conditional_encoders/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .clip import CLIPEncoder diff --git a/primitive_anything/michelangelo/models/conditional_encoders/clip.py b/primitive_anything/michelangelo/models/conditional_encoders/clip.py new file mode 100755 index 0000000000000000000000000000000000000000..099b237d543981cca70f92ccbbb0c1c560aa0f2a --- /dev/null +++ b/primitive_anything/michelangelo/models/conditional_encoders/clip.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- + +import torch +import numpy as np +from PIL import Image +from dataclasses import dataclass +from torchvision.transforms import Normalize +from transformers import CLIPModel, CLIPTokenizer +from transformers.utils import ModelOutput +from typing import Iterable, Optional, Union, List + + +ImageType = Union[np.ndarray, torch.Tensor, Image.Image] + + +@dataclass +class CLIPEmbedOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + embeds: torch.FloatTensor = None + + +class CLIPEncoder(torch.nn.Module): + + def __init__(self, model_path="openai/clip-vit-base-patch32"): + + super().__init__() + + # Load the CLIP model and processor + self.model: CLIPModel = CLIPModel.from_pretrained(model_path) + self.tokenizer = CLIPTokenizer.from_pretrained(model_path) + self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], 
std=[0.229, 0.224, 0.225]) + + self.model.training = False + for p in self.model.parameters(): + p.requires_grad = False + + @torch.no_grad() + def encode_image(self, images: Iterable[Optional[ImageType]]): + pixel_values = self.image_preprocess(images) + + vision_outputs = self.model.vision_model(pixel_values=pixel_values) + + pooler_output = vision_outputs[1] # pooled_output + image_features = self.model.visual_projection(pooler_output) + + visual_embeds = CLIPEmbedOutput( + last_hidden_state=vision_outputs.last_hidden_state, + pooler_output=pooler_output, + embeds=image_features + ) + + return visual_embeds + + @torch.no_grad() + def encode_text(self, texts: List[str]): + text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt") + + text_outputs = self.model.text_model(input_ids=text_inputs) + + pooler_output = text_outputs[1] # pooled_output + text_features = self.model.text_projection(pooler_output) + + text_embeds = CLIPEmbedOutput( + last_hidden_state=text_outputs.last_hidden_state, + pooler_output=pooler_output, + embeds=text_features + ) + + return text_embeds + + def forward(self, + images: Iterable[Optional[ImageType]], + texts: List[str]): + + visual_embeds = self.encode_image(images) + text_embeds = self.encode_text(texts) + + return visual_embeds, text_embeds + + + + + + + + + + diff --git a/primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py b/primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py new file mode 100755 index 0000000000000000000000000000000000000000..0556f11c04e8c71c2d96be8eb11717d8f669ee7d --- /dev/null +++ b/primitive_anything/michelangelo/models/conditional_encoders/encoder_factory.py @@ -0,0 +1,562 @@ +# -*- coding: utf-8 -*- +import os + +import torch +import torch.nn as nn +from torchvision import transforms +from transformers import CLIPModel, CLIPTokenizer +from collections import OrderedDict + +from ...data.transforms import RandomResize + + +class AbstractEncoder(nn.Module): + embedding_dim: int + + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class"): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class FrozenCLIPTextEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + tokenizer_version=None, + device="cuda", + max_length=77, + zero_embedding_radio: float = 0.1, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version) + + self.device = device + self.max_length = max_length + self.zero_embedding_radio = zero_embedding_radio + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + transformer = CLIPModel.from_pretrained(version).text_model + + for param in transformer.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = transformer + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def 
unconditional_embedding(self, batch_size): + empty_text = [""] * batch_size + empty_z = self.forward(empty_text) + return empty_z + + def forward(self, text): + self.move() + + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.clip(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + batch_size = len(text) + batch_mask = torch.rand((batch_size,)) + for i in range(batch_size): + if batch_mask[i] < self.zero_embedding_radio: + text[i] = "" + + return self(text) + +class FrozenAlignedCLIPTextEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + tokenizer_version=None, + device="cuda", + max_length=77, + zero_embedding_radio: float = 0.1, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version) + + self.device = device + self.max_length = max_length + self.zero_embedding_radio = zero_embedding_radio + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + transformer = CLIPModel.from_pretrained(version).text_model + + for param in transformer.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = transformer + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def unconditional_embedding(self, batch_size): + empty_text = [""] * batch_size + empty_z = self.forward(empty_text) + return empty_z + + def forward(self, text): + self.move() + + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.clip(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + batch_size = len(text) + batch_mask = torch.rand((batch_size,)) + for i in range(batch_size): + if batch_mask[i] < self.zero_embedding_radio: + text[i] = "" + + return self(text) + + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + zero_embedding_radio=0.1, + normalize_embedding=True, + num_projection_vector=0, + linear_mapping_bias=True, + reverse_visual_projection=False, + ): + super().__init__() + + self.device = device + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + clip_model = CLIPModel.from_pretrained(version) + clip_model.text_model = None + clip_model.text_projection = None + clip_model = clip_model.eval() + for param in self.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = clip_model + + self.transform = transforms.Compose( + [ + transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.CenterCrop(224), # crop a (224, 224) square + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + 
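+        # The Resize/CenterCrop/Normalize above reproduce OpenAI CLIP's standard image
+        # preprocessing; forward() first rescales inputs from `value_range` (default
+        # [-1, 1]) to [0, 1] before this transform is applied.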
self.zero_embedding_radio = zero_embedding_radio + + self.num_projection_vector = num_projection_vector + self.reverse_visual_projection = reverse_visual_projection + self.normalize_embedding = normalize_embedding + + embedding_dim = ( + clip_model.visual_projection.in_features + if reverse_visual_projection + else clip_model.visual_projection.out_features + ) + self.embedding_dim = embedding_dim + if self.num_projection_vector > 0: + self.projection = nn.Linear( + embedding_dim, + clip_model.visual_projection.out_features * num_projection_vector, + bias=linear_mapping_bias, + ) + nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5) + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def unconditional_embedding(self, batch_size): + zero = torch.zeros( + batch_size, + 1, + self.embedding_dim, + device=self.device, + dtype=self.clip.visual_projection.weight.dtype, + ) + if self.num_projection_vector > 0: + zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1) + return zero + + def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0): + if value_range is not None: + low, high = value_range + image = (image - low) / (high - low) + + image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype) + + if self.reverse_visual_projection: + z = self.clip.vision_model(self.transform(image))[1] + else: + z = self.clip.get_image_features(self.transform(image)) + + if self.normalize_embedding: + z = z / z.norm(dim=-1, keepdim=True) + if z.ndim == 2: + z = z.unsqueeze(dim=-2) + + if zero_embedding_radio > 0: + mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) < zero_embedding_radio + z = z * mask.to(z) + + if self.num_projection_vector > 0: + z = self.projection(z).view(len(image), self.num_projection_vector, -1) + + return z + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def encode(self, image): + self.move() + return self(image, zero_embedding_radio=self.zero_embedding_radio) + + +class FrozenCLIPImageGridEmbedder(AbstractEncoder): + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + zero_embedding_radio=0.1, + ): + super().__init__() + + self.device = device + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + clip_model: CLIPModel = CLIPModel.from_pretrained(version) + clip_model.text_model = None + clip_model.text_projection = None + clip_model = clip_model.eval() + for param in self.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = clip_model + + self.transform = transforms.Compose( + [ + transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(224), # crop a (224, 224) square + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + self.zero_embedding_radio = zero_embedding_radio + self.embedding_dim = clip_model.vision_embed_dim + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def unconditional_embedding(self, batch_size): + zero = torch.zeros( + batch_size, + self.clip.vision_model.embeddings.num_positions, + 
self.embedding_dim,
+            device=self.device,
+            dtype=self.clip.visual_projection.weight.dtype,
+        )
+        return zero
+
+    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
+        self.move()
+
+        if value_range is not None:
+            low, high = value_range
+            image = (image - low) / (high - low)
+
+        image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
+
+        z = self.clip.vision_model(self.transform(image)).last_hidden_state
+
+        if zero_embedding_radio > 0:
+            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
+            z = z * mask.to(z)
+
+        return z
+
+    def encode(self, image):
+        return self(image, zero_embedding_radio=self.zero_embedding_radio)
+
+
+class MoECLIPImageEncoder(nn.Module):
+    def __init__(
+            self,
+            versions,
+            hidden_state_dim,
+            num_projection_vector=8,
+            zero_embedding_radio=0.1,
+            device="cuda",
+            precision="fp16",
+            normalize=False,
+            clip_max=0,
+            transform_type="base",
+            argument_p=0.2,
+    ):
+        super().__init__()
+
+        self.device = torch.device(device)
+        self.hidden_state_dim = hidden_state_dim
+        self.zero_embedding_radio = zero_embedding_radio
+        self.num_projection_vector = num_projection_vector
+        self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
+        self.normalize = normalize
+        self.clip_max = clip_max
+
+        if transform_type == "base":
+            self.transform = transforms.Compose(
+                [
+                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
+                    transforms.CenterCrop(224),  # crop a (224, 224) square
+                    transforms.Normalize(
+                        mean=[0.48145466, 0.4578275, 0.40821073],
+                        std=[0.26862954, 0.26130258, 0.27577711],
+                    ),
+                ]
+            )
+        elif transform_type == "crop_blur_resize":
+            self.transform = transforms.Compose(
+                [
+                    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
+                    transforms.CenterCrop(224),  # crop a (224, 224) square
+                    transforms.RandomApply(
+                        transforms=[
+                            transforms.RandomResizedCrop(
+                                size=224,
+                                scale=(0.8, 1.0),
+                                ratio=(0.99, 1.01),
+                                interpolation=transforms.InterpolationMode.BICUBIC,
+                            ),
+                        ],
+                        p=argument_p,
+                    ),
+                    transforms.RandomApply(
+                        transforms=[
+                            transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
+                        ],
+                        p=argument_p,
+                    ),
+                    transforms.RandomApply(
+                        transforms=[
+                            RandomResize(size=224, resize_radio=(0.2, 1)),
+                        ],
+                        p=argument_p,
+                    ),
+                    transforms.Normalize(
+                        mean=[0.48145466, 0.4578275, 0.40821073],
+                        std=[0.26862954, 0.26130258, 0.27577711],
+                    ),
+                ]
+            )
+        else:
+            raise ValueError(f"invalid {transform_type=}")
+
+        if isinstance(versions, str):
+            versions = (versions,)
+
+        # Do not register the CLIP models as submodules of this class: 1. a checkpoint
+        # would otherwise store redundant copies of their weights; 2. Lightning would
+        # call .to() on them and also cast the LayerNorm weights to fp16.
+        clips = OrderedDict()
+
+        for v in versions:
+            # Because the clips are not submodules, passing device="cuda" here would
+            # incorrectly place every CLIP model's weights on cuda:0.
+            clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
+            delattr(clips[v], "transformer")
+            clips[v].eval()
+            clips[v].requires_grad_(False)
+
+        self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
+
+        if self.num_projection_vector == 0:
+            self.projection = nn.Identity()
+        else:
+            self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
+            self.projection.to(dtype=self.dtype)
+            nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
+
+        self.clips = clips
+
+        self._move_flag = False
+
+    def move(self):
+        if self._move_flag:
+            return
+
+        def convert_weights(model: nn.Module):
+            """Convert applicable model parameters to fp16"""
+
+            def _convert_weights_to_fp16(l):
+                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+                    l.weight.data = l.weight.data.type(self.dtype)
+                    if l.bias is not None:
+                        l.bias.data = l.bias.data.type(self.dtype)
+
+                if isinstance(l, nn.MultiheadAttention):
+                    for attr in [
+                        *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+                        "in_proj_bias",
+                        "bias_k",
+                        "bias_v",
+                    ]:
+                        tensor = getattr(l, attr)
+                        if tensor is not None:
+                            tensor.data = tensor.data.type(self.dtype)
+
+                for name in ["text_projection", "proj"]:
+                    if hasattr(l, name):
+                        attr = getattr(l, name)
+                        if attr is not None:
+                            attr.data = attr.data.type(self.dtype)
+
+            model.apply(_convert_weights_to_fp16)
+
+        for k in self.clips:
+            self.clips[k].to(self.device)
+            convert_weights(self.clips[k])  # fp32 -> self.dtype
+        self._move_flag = True
+
+    def unconditional_embedding(self, batch_size=None):
+        zero = torch.zeros(
+            batch_size,
+            self.clips_hidden_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        if self.num_projection_vector > 0:
+            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
+        return zero
+
+    def convert_embedding(self, z):
+        if self.num_projection_vector > 0:
+            z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
+        return z
+
+    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
+        if value_range is not None:
+            low, high = value_range
+            image = (image - low) / (high - low)
+
+        image = self.transform(image)
+
+        with torch.no_grad():
+            embs = []
+            for v in self.clips:
+                x = self.clips[v].encode_image(image)
+                if self.normalize:
+                    x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
+                    # clip_max only works with normalization
+                    if self.clip_max > 0:
+                        x = x.clamp(-self.clip_max, self.clip_max)
+                embs.append(x)
+
+            z = torch.cat(embs, dim=-1)
+            if self.normalize:
+                z /= z.size(-1) ** 0.5
+
+        if zero_embedding_radio > 0:
+            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
+            z = z * mask.to(z)  # zero out dropped embeddings, matching the other encoders
+
+        if self.num_projection_vector > 0:
+            z = self.projection(z).view(len(image), self.num_projection_vector, -1)
+        return z
+
+    def encode(self, image):
+        self.move()
+        return self(image, zero_embedding_radio=self.zero_embedding_radio)
diff --git a/primitive_anything/michelangelo/models/modules/__init__.py b/primitive_anything/michelangelo/models/modules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..0729b49eadf964584d3524d9c0f6adec3f04a6a9
--- /dev/null
+++ b/primitive_anything/michelangelo/models/modules/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+
+from 
.checkpoint import checkpoint diff --git a/primitive_anything/michelangelo/models/modules/checkpoint.py b/primitive_anything/michelangelo/models/modules/checkpoint.py new file mode 100755 index 0000000000000000000000000000000000000000..4fef818bc15de279a06f9175aeadf85924ff18c0 --- /dev/null +++ b/primitive_anything/michelangelo/models/modules/checkpoint.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +""" +Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 +""" + +import torch +from typing import Callable, Iterable, Sequence, Union + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + :param use_deepspeed: if True, use deepspeed + """ + if flag: + if use_deepspeed: + import deepspeed + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
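+        # Re-running run_function here rebuilds the activation graph that was discarded
+        # in forward(); allow_unused=True below lets parameters that do not influence
+        # the output receive None gradients instead of raising an error.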
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/primitive_anything/michelangelo/models/modules/diffusion_transformer.py b/primitive_anything/michelangelo/models/modules/diffusion_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..7efa7df4f37d87cd6c73fbf13a85251b94044530 --- /dev/null +++ b/primitive_anything/michelangelo/models/modules/diffusion_transformer.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +from typing import Optional + +from .checkpoint import checkpoint +from .transformer_blocks import ( + init_linear, + MLP, + MultiheadCrossAttention, + MultiheadAttention, + ResidualAttentionBlock +) + + +class AdaLayerNorm(nn.Module): + def __init__(self, + device: torch.device, + dtype: torch.dtype, + width: int): + + super().__init__() + + self.silu = nn.SiLU(inplace=True) + self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) + self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) + + def forward(self, x, timestep): + emb = self.linear(timestep) + scale, shift = torch.chunk(emb, 2, dim=2) + x = self.layernorm(x) * (1 + scale) + shift + return x + + +class DitBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + context_dim: int, + qkv_bias: bool = False, + init_scale: float = 1.0, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias + ) + self.ln_1 = AdaLayerNorm(device, dtype, width) + + if context_dim is not None: + self.ln_2 = AdaLayerNorm(device, dtype, width) + self.cross_attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + width=width, + heads=heads, + data_width=context_dim, + init_scale=init_scale, + qkv_bias=qkv_bias + ) + + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_3 = AdaLayerNorm(device, dtype, width) + + def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): + return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) + + def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): + x = x + self.attn(self.ln_1(x, t)) + if context is not None: + x = x + self.cross_attn(self.ln_2(x, t), context) + x = x + self.mlp(self.ln_3(x, t)) + return x + + +class DiT(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + context_dim: int, + init_scale: float = 0.25, + qkv_bias: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList( + [ + DitBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + context_dim=context_dim, + qkv_bias=qkv_bias, + init_scale=init_scale, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def 
forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): + for block in self.resblocks: + x = block(x, t, context) + return x + + +class UNetDiffusionTransformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = False, + skip_ln: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.n_ctx = n_ctx + self.width = width + self.layers = layers + + self.encoder = nn.ModuleList() + for _ in range(layers): + resblock = ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + self.encoder.append(resblock) + + self.middle_block = ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + + self.decoder = nn.ModuleList() + for _ in range(layers): + resblock = ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + linear = nn.Linear(width * 2, width, device=device, dtype=dtype) + init_linear(linear, init_scale) + + layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None + + self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) + + def forward(self, x: torch.Tensor): + + enc_outputs = [] + for block in self.encoder: + x = block(x) + enc_outputs.append(x) + + x = self.middle_block(x) + + for i, (resblock, linear, layer_norm) in enumerate(self.decoder): + x = torch.cat([enc_outputs.pop(), x], dim=-1) + x = linear(x) + + if layer_norm is not None: + x = layer_norm(x) + + x = resblock(x) + + return x diff --git a/primitive_anything/michelangelo/models/modules/distributions.py b/primitive_anything/michelangelo/models/modules/distributions.py new file mode 100755 index 0000000000000000000000000000000000000000..cf1cdcd53f1eb534b55d92ae1bd0b9854f6b890c --- /dev/null +++ b/primitive_anything/michelangelo/models/modules/distributions.py @@ -0,0 +1,100 @@ +import torch +import numpy as np +from typing import Union, List + + +class AbstractDistribution(object): + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 
0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/primitive_anything/michelangelo/models/modules/embedder.py b/primitive_anything/michelangelo/models/modules/embedder.py new file mode 100755 index 0000000000000000000000000000000000000000..223de828f44903a3ce96b59d1cc5621e0989b535 --- /dev/null +++ b/primitive_anything/michelangelo/models/modules/embedder.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import torch +import torch.nn as nn +import math + +VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. 
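+            For example, with the defaults (num_freqs=6, input_dim=3, include_input=True),
+            out_dim = 3 * (6 * 2 + 1) = 39.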
+ + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class LearnedFourierEmbedder(nn.Module): + """ following @crowsonkb "s lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, in_channels, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // in_channels + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + def forward(self, x): + """ + + Args: + x (torch.FloatTensor): [..., c] + + Returns: + x (torch.FloatTensor): [..., d] + """ + + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class TriplaneLearnedFourierEmbedder(nn.Module): + def __init__(self, in_channels, dim): + super().__init__() + + self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + + self.out_dim = in_channels + dim + + def forward(self, x): + + yz_embed = self.yz_plane_embedder(x) + xz_embed = self.xz_plane_embedder(x) + xy_embed = self.xy_plane_embedder(x) + + embed = yz_embed + xz_embed + xy_embed + + return embed + + +def sequential_pos_embed(num_len, embed_dim): + assert embed_dim % 2 == 0 + + pos = torch.arange(num_len, dtype=torch.float32) + omega = torch.arange(embed_dim // 2, dtype=torch.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return embeddings + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. 
+ :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, + num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, + log2_hashmap_size=19, desired_resolution=None): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, + logspace=True, include_input=True) + return embedder_obj, embedder_obj.out_dim + + elif embed_type == "hashgrid": + raise NotImplementedError + + elif embed_type == "sphere_harmonic": + raise NotImplementedError + + else: + raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") diff --git a/primitive_anything/michelangelo/models/modules/transformer_blocks.py b/primitive_anything/michelangelo/models/modules/transformer_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..64349ccf326636aaba16f880354397cc9e80285d --- /dev/null +++ b/primitive_anything/michelangelo/models/modules/transformer_blocks.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional + +from .checkpoint import checkpoint + + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool, + flash: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.flash = flash + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + if self.flash: + out = F.scaled_dot_product_attention(q, k, v) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = 
torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool = True, + flash: bool = False, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, + flash: bool = False, n_data: Optional[int] = None): + + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + self.flash = flash + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + if self.flash: + out = F.scaled_dot_product_attention(q, k, v) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + 
dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int, + init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/primitive_anything/michelangelo/models/modules/transformer_vit.py b/primitive_anything/michelangelo/models/modules/transformer_vit.py new file mode 100755 index 0000000000000000000000000000000000000000..4bd8822007dcb86c3328858a149dc5f8e51510df --- /dev/null +++ b/primitive_anything/michelangelo/models/modules/transformer_vit.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +from typing import Optional +import warnings + +from .checkpoint import checkpoint + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
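+    # (Why the shift to [2l-1, 2u-1]: norm_cdf(x) = (1 + erf(x / sqrt(2))) / 2,
+    # so applying erfinv to a uniform sample on [2l-1, 2u-1] and later scaling
+    # by sqrt(2) gives a standard normal truncated to [(a - mean) / std,
+    # (b - mean) / std]; the mul_/add_ below then restore the target mean/std.)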
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) + + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + qkv_bias: bool + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(attn_ch) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + weight = torch.einsum("bthc,bshc->bhts", q, k) * scale + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + qkv_bias: bool = True, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + qkv_bias=qkv_bias + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, 
dtype=dtype, width=width) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + qkv_bias: bool = True, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data + ) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(attn_ch) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + weight = torch.einsum("bthc,bshc->bhts", q, k) * scale + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + qkv_bias: bool = True + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + 
use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + self.apply(init_weights) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/primitive_anything/michelangelo/models/tsal/__init__.py b/primitive_anything/michelangelo/models/tsal/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/primitive_anything/michelangelo/models/tsal/asl_pl_module.py b/primitive_anything/michelangelo/models/tsal/asl_pl_module.py new file mode 100755 index 0000000000000000000000000000000000000000..e2ee9f1ac22476fe566c227efd3f1fbd079d7fc3 --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/asl_pl_module.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple, Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from typing import Union +from functools import partial + +from ...utils import instantiate_from_config + +from .inference_utils import extract_geometry +from .tsal_base import ( + AlignedShapeAsLatentModule, + ShapeAsLatentModule, + Latent2MeshOutput, + AlignedMeshOutput +) + + +class AlignedShapeAsLatentPLModule(pl.LightningModule): + + def __init__(self, *, + shape_module_cfg, + aligned_module_cfg, + loss_cfg, + optimizer_cfg: Optional[DictConfig] = None, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + shape_model: ShapeAsLatentModule = instantiate_from_config( + shape_module_cfg, device=None, dtype=None + ) + self.model: AlignedShapeAsLatentModule = instantiate_from_config( + aligned_module_cfg, shape_model=shape_model + ) + + self.loss = instantiate_from_config(loss_cfg) + + self.optimizer_cfg = optimizer_cfg + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.save_hyperparameters() + + def set_shape_model_only(self): + self.model.set_shape_model_only() + + @property + def latent_shape(self): + return self.model.shape_model.latent_shape + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def configure_optimizers(self) -> Tuple[List, List]: + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), 
weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + def forward(self, + surface: torch.FloatTensor, + image: torch.FloatTensor, + text: torch.FloatTensor, + volume_queries: torch.FloatTensor): + + """ + + Args: + surface (torch.FloatTensor): + image (torch.FloatTensor): + text (torch.FloatTensor): + volume_queries (torch.FloatTensor): + + Returns: + + """ + + embed_outputs, shape_z = self.model(surface, image, text) + + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) + latents = self.model.shape_model.decode(shape_zq) + logits = self.model.shape_model.query_geometry(volume_queries, latents) + + return embed_outputs, logits, posterior + + def encode(self, surface: torch.FloatTensor, sample_posterior=True): + + pc = surface[..., 0:3] + feats = surface[..., 3:6] + + shape_embed, shape_zq, posterior = self.model.shape_model.encode( + pc=pc, feats=feats, sample_posterior=sample_posterior + ) + + return shape_zq + + def encode_latents(self, surface: torch.FloatTensor): + + pc = surface[..., 0:3] + feats = surface[..., 3:6] + + shape_embed, shape_latents = self.model.shape_model.encode_latents( + pc=pc, feats=feats + ) + shape_embed = shape_embed.unsqueeze(1) + assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256 + cat_latents = torch.cat([shape_embed, shape_latents], dim=1) + + return cat_latents + + def to_shape_latents(self, latents): + + shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior = False) + return self.model.shape_model.decode(shape_zq) + + def decode(self, + z_q, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim] + outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) + + return outputs + + def training_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] + - image (torch.FloatTensor): [bs, 3, 224, 224] + - text (torch.FloatTensor): [bs, num_templates, 77] + - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + surface = batch["surface"] + image = batch["image"] + text = batch["text"] + + volume_queries = batch["geo_points"][..., 0:3] + shape_labels = batch["geo_points"][..., -1] + + embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) + + aeloss, log_dict_ae = self.loss( + **embed_outputs, + posteriors=posteriors, + shape_logits=shape_logits, + shape_labels=shape_labels, + split="train" + ) + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> 
torch.FloatTensor: + + surface = batch["surface"] + image = batch["image"] + text = batch["text"] + + volume_queries = batch["geo_points"][..., 0:3] + shape_labels = batch["geo_points"][..., -1] + + embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) + + aeloss, log_dict_ae = self.loss( + **embed_outputs, + posteriors=posteriors, + shape_logits=shape_logits, + shape_labels=shape_labels, + split="val" + ) + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def visual_alignment(self, + surface: torch.FloatTensor, + image: torch.FloatTensor, + text: torch.FloatTensor, + description: Optional[List[str]] = None, + bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000) -> List[AlignedMeshOutput]: + + """ + + Args: + surface: + image: + text: + description: + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + device = surface.device + bs = surface.shape[0] + + embed_outputs, shape_z = self.model(surface, image, text) + + # calculate the similarity + image_embed = embed_outputs["image_embed"] + text_embed = embed_outputs["text_embed"] + shape_embed = embed_outputs["shape_embed"] + + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # B x B + shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1) + + # B x B + shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1) + + # shape reconstruction + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) + latents = self.model.shape_model.decode(shape_zq) + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # 2. decode geometry + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=bs, + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = AlignedMeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + out.surface = surface[i].cpu().numpy() + out.image = image[i].cpu().numpy() + if description is not None: + out.text = description[i] + out.shape_text_similarity = shape_text_similarity[i, i] + out.shape_image_similarity = shape_image_similarity[i, i] + + outputs.append(out) + + return outputs + + def latent2mesh(self, + latents: torch.FloatTensor, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + """ + + Args: + latents: [bs, num_latents, dim] + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # 2. decode geometry + device = latents.device + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=len(latents), + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. 
decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Latent2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + + outputs.append(out) + + return outputs + diff --git a/primitive_anything/michelangelo/models/tsal/clip_asl_module.py b/primitive_anything/michelangelo/models/tsal/clip_asl_module.py new file mode 100755 index 0000000000000000000000000000000000000000..2a3ceda6eb04e26ed4845107296b58d85ec49cd7 --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/clip_asl_module.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +import torch +from torch import nn +from einops import rearrange +from transformers import CLIPModel + +from .tsal_base import AlignedShapeAsLatentModule + + +class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): + + def __init__(self, *, + shape_model, + projection_dim=768): + + super().__init__() + + self.shape_model = shape_model + self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, projection_dim)) + nn.init.normal_(self.shape_projection, std=projection_dim ** -0.5) + + def set_shape_model_only(self): + self.clip_model = None + + def encode_shape_embed(self, surface, return_latents: bool = False): + """ + + Args: + surface (torch.FloatTensor): [bs, n, 3 + c] + return_latents (bool): + + Returns: + x (torch.FloatTensor): [bs, projection_dim] + shape_latents (torch.FloatTensor): [bs, m, d] + """ + + pc = surface[..., 0:3] + feats = surface[..., 3:] + + shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) + x = shape_embed @ self.shape_projection + + if return_latents: + return x, shape_latents + else: + return x + + def encode_image_embed(self, image): + """ + + Args: + image (torch.FloatTensor): [bs, 3, h, w] + + Returns: + x (torch.FloatTensor): [bs, projection_dim] + """ + + x = self.clip_model.get_image_features(image) + + return x + + def encode_text_embed(self, text): + x = self.clip_model.get_text_features(text) + return x + + def forward(self, surface, image, text): + """ + + Args: + surface (torch.FloatTensor): + image (torch.FloatTensor): [bs, 3, 224, 224] + text (torch.LongTensor): [bs, num_templates, 77] + + Returns: + embed_outputs (dict): the embedding outputs, and it contains: + - image_embed (torch.FloatTensor): + - text_embed (torch.FloatTensor): + - shape_embed (torch.FloatTensor): + - logit_scale (float): + """ + + # # text embedding + # text_embed_all = [] + # for i in range(text.shape[0]): + # text_for_one_sample = text[i] + # text_embed = self.encode_text_embed(text_for_one_sample) + # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + # text_embed = text_embed.mean(dim=0) + # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + # text_embed_all.append(text_embed) + # text_embed_all = torch.stack(text_embed_all) + + b = text.shape[0] + text_tokens = rearrange(text, "b t l -> (b t) l") + text_embed = self.encode_text_embed(text_tokens) + text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) + text_embed = text_embed.mean(dim=1) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + + # image embedding + image_embed = self.encode_image_embed(image) + + # shape embedding + shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True) + + embed_outputs = { + "image_embed": image_embed, + "text_embed": text_embed, + "shape_embed": shape_embed, + "logit_scale": self.clip_model.logit_scale.exp() + } + + return 
embed_outputs, shape_latents diff --git a/primitive_anything/michelangelo/models/tsal/inference_utils.py b/primitive_anything/michelangelo/models/tsal/inference_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..4ba61a92e90d7d3ca90c6d9611e73904e5bd6e40 --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/inference_utils.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +import torch +from tqdm import tqdm +from einops import repeat +import numpy as np +from typing import Callable, Tuple, List, Union, Optional +from skimage import measure + +from ...graphics.primitives import generate_dense_grid_points + + +@torch.no_grad() +def extract_geometry(geometric_func: Callable, + device: torch.device, + batch_size: int = 1, + bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000, + disable: bool = True): + """ + + Args: + geometric_func: + device: + bounds: + octree_depth: + batch_size: + num_chunks: + disable: + + Returns: + + """ + + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_depth=octree_depth, + indexing="ij" + ) + xyz_samples = torch.FloatTensor(xyz_samples) + + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), + desc="Implicit Function:", disable=disable, leave=False): + queries = xyz_samples[start: start + num_chunks, :].to(device) + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + + logits = geometric_func(batch_queries) + batch_logits.append(logits.cpu()) + + grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() + + mesh_v_f = [] + has_surface = np.zeros((batch_size,), dtype=np.bool_) + for i in range(batch_size): + try: + vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") + vertices = vertices / grid_size * bbox_size + bbox_min + # vertices[:, [0, 1]] = vertices[:, [1, 0]] + mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) + has_surface[i] = True + + except ValueError: + mesh_v_f.append((None, None)) + has_surface[i] = False + + except RuntimeError: + mesh_v_f.append((None, None)) + has_surface[i] = False + + return mesh_v_f, has_surface diff --git a/primitive_anything/michelangelo/models/tsal/loss.py b/primitive_anything/michelangelo/models/tsal/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..f8250e488bc8c3e57d080f92927daa1bb5192724 --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/loss.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Dict + +from ..modules.distributions import DiagonalGaussianDistribution +from ...utils.eval import compute_psnr +from ...utils import misc + + +class KLNearFar(nn.Module): + def __init__(self, + near_weight: float = 0.1, + kl_weight: float = 1.0, + num_near_samples: Optional[int] = None): + + super().__init__() + + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + posteriors: Optional[DiagonalGaussianDistribution], + logits: 
torch.FloatTensor, + labels: torch.FloatTensor, + split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: + + """ + + Args: + posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): + logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; + labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; + split (str): + **kwargs: + + Returns: + loss (torch.Tensor): (,) + log (dict): + + """ + + if self.num_near_samples is None: + num_vol = logits.shape[1] // 2 + else: + num_vol = logits.shape[1] - self.num_near_samples + + vol_logits = logits[:, 0:num_vol] + vol_labels = labels[:, 0:num_vol] + + near_logits = logits[:, num_vol:] + near_labels = labels[:, num_vol:] + + # occupancy loss + # vol_bce = self.geo_criterion(vol_logits, vol_labels) + # near_bce = self.geo_criterion(near_logits, near_labels) + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + + with torch.no_grad(): + preds = logits >= 0 + accuracy = (preds == labels).float() + accuracy = accuracy.mean() + pos_ratio = torch.mean(labels) + + log = { + "{}/total_loss".format(split): loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/accuracy".format(split): accuracy, + "{}/pos_ratio".format(split): pos_ratio + } + + if posteriors is not None: + log[f"{split}/mean"] = posteriors.mean.mean().detach() + log[f"{split}/std_mean"] = posteriors.std.mean().detach() + log[f"{split}/std_max"] = posteriors.std.max().detach() + + return loss, log + + +class KLNearFarColor(nn.Module): + def __init__(self, + near_weight: float = 0.1, + kl_weight: float = 1.0, + color_weight: float = 1.0, + color_criterion: str = "mse", + num_near_samples: Optional[int] = None): + + super().__init__() + + self.color_weight = color_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + + if color_criterion == "mse": + self.color_criterion = nn.MSELoss() + + elif color_criterion == "l1": + self.color_criterion = nn.L1Loss() + + else: + raise ValueError(f"{color_criterion} must be [`mse`, `l1`].") + + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + posteriors: Optional[DiagonalGaussianDistribution], + logits: torch.FloatTensor, + labels: torch.FloatTensor, + pred_colors: torch.FloatTensor, + gt_colors: torch.FloatTensor, + split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: + + """ + + Args: + posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): + logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; + labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; + pred_colors (torch.FloatTensor): [B, M, 3] + gt_colors (torch.FloatTensor): [B, M, 3] + split (str): + **kwargs: + + Returns: + loss (torch.Tensor): (,) + log (dict): + + """ + + if self.num_near_samples is None: + num_vol = logits.shape[1] // 2 + else: + 
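+            # The last num_near_samples entries along dim 1 are the near-surface
+            # queries; everything before them is treated as uniform volume points.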
num_vol = logits.shape[1] - self.num_near_samples + + vol_logits = logits[:, 0:num_vol] + vol_labels = labels[:, 0:num_vol] + + near_logits = logits[:, num_vol:] + near_labels = labels[:, num_vol:] + + # occupancy loss + # vol_bce = self.geo_criterion(vol_logits, vol_labels) + # near_bce = self.geo_criterion(near_logits, near_labels) + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + # surface color loss + color = self.color_criterion(pred_colors, gt_colors) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight + + with torch.no_grad(): + preds = logits >= 0 + accuracy = (preds == labels).float() + accuracy = accuracy.mean() + psnr = compute_psnr(pred_colors, gt_colors) + + log = { + "{}/total_loss".format(split): loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/color".format(split): color.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/psnr".format(split): psnr.detach(), + "{}/accuracy".format(split): accuracy + } + + return loss, log + + +class ContrastKLNearFar(nn.Module): + def __init__(self, + contrast_weight: float = 1.0, + near_weight: float = 0.1, + kl_weight: float = 1.0, + num_near_samples: Optional[int] = None): + + super().__init__() + + self.labels = None + self.last_local_batch_size = None + + self.contrast_weight = contrast_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + shape_embed: torch.FloatTensor, + text_embed: torch.FloatTensor, + image_embed: torch.FloatTensor, + logit_scale: torch.FloatTensor, + posteriors: Optional[DiagonalGaussianDistribution], + shape_logits: torch.FloatTensor, + shape_labels: torch.FloatTensor, + split: Optional[str] = "train", **kwargs): + + local_batch_size = shape_embed.size(0) + + if local_batch_size != self.last_local_batch_size: + self.labels = local_batch_size * misc.get_rank() + torch.arange( + local_batch_size, device=shape_embed.device + ).long() + self.last_local_batch_size = local_batch_size + + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # gather features from all GPUs + shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( + [shape_embed, text_embed, image_embed] + ) + + # cosine similarity as logits + logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() + logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() + logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() + logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() + contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + + F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ + (F.cross_entropy(logits_per_shape_image, self.labels) + + F.cross_entropy(logits_per_image_shape, self.labels)) / 2 + + # shape reconstruction + if self.num_near_samples is None: + num_vol = shape_logits.shape[1] // 2 + else: + num_vol = shape_logits.shape[1] - self.num_near_samples + + vol_logits = 
shape_logits[:, 0:num_vol] + vol_labels = shape_labels[:, 0:num_vol] + + near_logits = shape_logits[:, num_vol:] + near_labels = shape_labels[:, num_vol:] + + # occupancy loss + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight + + # compute accuracy + with torch.no_grad(): + pred = torch.argmax(logits_per_shape_text, dim=-1) + correct = pred.eq(self.labels).sum() + shape_text_acc = 100 * correct / local_batch_size + + pred = torch.argmax(logits_per_shape_image, dim=-1) + correct = pred.eq(self.labels).sum() + shape_image_acc = 100 * correct / local_batch_size + + preds = shape_logits >= 0 + accuracy = (preds == shape_labels).float() + accuracy = accuracy.mean() + + log = { + "{}/contrast".format(split): contrast_loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/shape_text_acc".format(split): shape_text_acc, + "{}/shape_image_acc".format(split): shape_image_acc, + "{}/total_loss".format(split): loss.clone().detach(), + "{}/accuracy".format(split): accuracy, + } + + if posteriors is not None: + log[f"{split}/mean"] = posteriors.mean.mean().detach() + log[f"{split}/std_mean"] = posteriors.std.mean().detach() + log[f"{split}/std_max"] = posteriors.std.max().detach() + + return loss, log diff --git a/primitive_anything/michelangelo/models/tsal/sal_perceiver.py b/primitive_anything/michelangelo/models/tsal/sal_perceiver.py new file mode 100755 index 0000000000000000000000000000000000000000..1e66624015c706576cddc38c9c8afca1a336c3bd --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/sal_perceiver.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from typing import Optional +from einops import repeat +import math + +from ..modules import checkpoint +from ..modules.embedder import FourierEmbedder +from ..modules.distributions import DiagonalGaussianDistribution +from ..modules.transformer_blocks import ( + ResidualCrossAttentionBlock, + Transformer +) + +from .tsal_base import ShapeAsLatentModule + + +class CrossAttentionEncoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + + self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) + + self.fourier_embedder = fourier_embedder + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) + self.cross_attn = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + + self.self_attn = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + 
init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=False + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) + else: + self.ln_post = None + + def _forward(self, pc, feats): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + + bs = pc.shape[0] + + data = self.fourier_embedder(pc) + if feats is not None: + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + query = repeat(self.query, "m c -> b m c", b=bs) + latents = self.cross_attn(query, data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents, pc + + def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + dict + """ + + return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) + + +class CrossAttentionDecoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.fourier_embedder = fourier_embedder + + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + queries = self.query_proj(self.fourier_embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x = self.output_proj(x) + return x + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) + + +class ShapeAsLatentPerceiver(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + + self.embed_dim = embed_dim + if embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(width, embed_dim * 2, 
device=device, dtype=dtype) + self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) + self.latent_shape = (num_latents, embed_dim) + else: + self.latent_shape = (num_latents, width) + + self.transformer = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=num_decoder_layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + # geometry decoder + self.geo_decoder = CrossAttentionDecoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + sample_posterior (bool): + + Returns: + latents (torch.FloatTensor) + center_pos (torch.FloatTensor or None): + posterior (DiagonalGaussianDistribution or None): + """ + + latents, center_pos = self.encoder(pc, feats) + + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + latents = posterior.sample() + else: + latents = posterior.mode() + + return latents, center_pos, posterior + + def decode(self, latents: torch.FloatTensor): + latents = self.post_kl(latents) + return self.transformer(latents) + + def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + logits = self.geo_decoder(queries, latents).squeeze(-1) + return logits + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + logits (torch.FloatTensor): [B, P] + center_pos (torch.FloatTensor): [B, M, 3] + posterior (DiagonalGaussianDistribution or None). 
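+
+        Example (illustrative sketch only; the hyper-parameters below are plausible
+        values, and pc / queries are assumed to be [B, N, 3] / [B, P, 3] tensors):
+            >>> sal = ShapeAsLatentPerceiver(device=None, dtype=None, num_latents=256,
+            ...                              embed_dim=64, width=768, heads=12,
+            ...                              num_encoder_layers=8, num_decoder_layers=16)
+            >>> logits, center_pos, posterior = sal(pc, None, queries)
+            >>> logits.shape   # [B, P] occupancy logits at the query points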
+ + """ + + latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(latents) + logits = self.query_geometry(volume_queries, latents) + + return logits, center_pos, posterior + + +class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__( + device=device, + dtype=dtype, + num_latents=1 + num_latents, + point_feats=point_feats, + embed_dim=embed_dim, + num_freqs=num_freqs, + include_pi=include_pi, + width=width, + heads=heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + + self.width = width + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, c] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor) + kl_embed (torch.FloatTensor): + posterior (DiagonalGaussianDistribution or None): + """ + + shape_embed, latents = self.encode_latents(pc, feats) + kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) + + return shape_embed, kl_embed, posterior + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _ = self.encoder(pc, feats) + + shape_embed = x[:, 0] + latents = x[:, 1:] + + return shape_embed, latents + + def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + kl_embed = posterior.sample() + else: + kl_embed = posterior.mode() + else: + kl_embed = latents + + return kl_embed, posterior + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor): [B, projection_dim] + logits (torch.FloatTensor): [B, M] + posterior (DiagonalGaussianDistribution or None). 
+ + """ + + shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(kl_embed) + logits = self.query_geometry(volume_queries, latents) + + return shape_embed, logits, posterior diff --git a/primitive_anything/michelangelo/models/tsal/sal_pl_module.py b/primitive_anything/michelangelo/models/tsal/sal_pl_module.py new file mode 100755 index 0000000000000000000000000000000000000000..20f9d8fa46f934a89afbdfafab195b3be8a84eb9 --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/sal_pl_module.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple, Dict, Optional +from omegaconf import DictConfig + +import torch +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from typing import Union +from functools import partial + +from ...utils import instantiate_from_config + +from .inference_utils import extract_geometry +from .tsal_base import ( + ShapeAsLatentModule, + Latent2MeshOutput, + Point2MeshOutput +) + + +class ShapeAsLatentPLModule(pl.LightningModule): + + def __init__(self, *, + module_cfg, + loss_cfg, + optimizer_cfg: Optional[DictConfig] = None, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None) + + self.loss = instantiate_from_config(loss_cfg) + + self.optimizer_cfg = optimizer_cfg + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.save_hyperparameters() + + @property + def latent_shape(self): + return self.sal.latent_shape + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def configure_optimizers(self) -> Tuple[List, List]: + lr = self.learning_rate + + # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)] + # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters()) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor): + + logits, center_pos, posterior = self.sal(pc, feats, volume_queries) + + return posterior, logits + + def encode(self, surface: 
torch.FloatTensor, sample_posterior=True): + + pc = surface[..., 0:3] + feats = surface[..., 3:6] + + latents, center_pos, posterior = self.sal.encode( + pc=pc, feats=feats, sample_posterior=sample_posterior + ) + + return latents + + def decode(self, + z_q, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim] + outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) + + return outputs + + def training_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] + - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + pc = batch["surface"][..., 0:3] + feats = batch["surface"][..., 3:] + + volume_queries = batch["geo_points"][..., 0:3] + volume_labels = batch["geo_points"][..., -1] + + posterior, logits = self( + pc=pc, feats=feats, volume_queries=volume_queries + ) + aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train") + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: + + pc = batch["surface"][..., 0:3] + feats = batch["surface"][..., 3:] + + volume_queries = batch["geo_points"][..., 0:3] + volume_labels = batch["geo_points"][..., -1] + + posterior, logits = self( + pc=pc, feats=feats, volume_queries=volume_queries, + ) + aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val") + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def point2mesh(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Point2MeshOutput]: + + """ + + Args: + pc: + feats: + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + device = pc.device + bs = pc.shape[0] + + # 1. point encoder + latents transformer + latents, center_pos, posterior = self.sal.encode(pc, feats) + latents = self.sal.decode(latents) # latents: [bs, num_latents, dim] + + geometric_func = partial(self.sal.query_geometry, latents=latents) + + # 2. decode geometry + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=bs, + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. 
decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Point2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy() + + if center_pos is not None: + out.center = center_pos[i].cpu().numpy() + + outputs.append(out) + + return outputs + + def latent2mesh(self, + latents: torch.FloatTensor, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + """ + + Args: + latents: [bs, num_latents, dim] + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + geometric_func = partial(self.sal.query_geometry, latents=latents) + + # 2. decode geometry + device = latents.device + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=len(latents), + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Latent2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + + outputs.append(out) + + return outputs diff --git a/primitive_anything/michelangelo/models/tsal/tsal_base.py b/primitive_anything/michelangelo/models/tsal/tsal_base.py new file mode 100755 index 0000000000000000000000000000000000000000..233a8afbdd0eb24024a6f915e770a286361cf0fe --- /dev/null +++ b/primitive_anything/michelangelo/models/tsal/tsal_base.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from typing import Tuple, List, Optional +import pytorch_lightning as pl + + +class Point2MeshOutput(object): + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.center = None + self.pc = None + + +class Latent2MeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + + +class AlignedMeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.surface = None + self.image = None + self.text: Optional[str] = None + self.shape_text_similarity: Optional[float] = None + self.shape_image_similarity: Optional[float] = None + + +class ShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class ShapeAsLatentModule(nn.Module): + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + +class AlignedShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def set_shape_model_only(self): + raise NotImplementedError + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> 
List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class AlignedShapeAsLatentModule(nn.Module): + shape_model: ShapeAsLatentModule + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def set_shape_model_only(self): + raise NotImplementedError + + def encode_image_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_text_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_shape_embed(self, *args, **kwargs): + raise NotImplementedError + + +class TexturedShapeAsLatentModule(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + def query_color(self, *args, **kwargs): + raise NotImplementedError diff --git a/primitive_anything/michelangelo/shapevae-256.yaml b/primitive_anything/michelangelo/shapevae-256.yaml new file mode 100755 index 0000000000000000000000000000000000000000..3d158a40a7b0fa3f896b6ff89a9a9b9a7c7df2d4 --- /dev/null +++ b/primitive_anything/michelangelo/shapevae-256.yaml @@ -0,0 +1,42 @@ +model: + target: primitive_anything.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule + params: + shape_module_cfg: + target: primitive_anything.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver + params: + num_latents: 256 + embed_dim: 64 + point_feats: 3 # normal + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_checkpoint: true + aligned_module_cfg: + target: primitive_anything.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule + loss_cfg: + target: primitive_anything.michelangelo.models.tsal.loss.ContrastKLNearFar + params: + contrast_weight: 0.1 + near_weight: 0.1 + kl_weight: 0.001 + optimizer_cfg: + optimizer: + target: torch.optim.AdamW + params: + betas: [0.9, 0.99] + eps: 1.e-6 + weight_decay: 1.e-2 + + scheduler: + target: primitive_anything.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler + params: + warm_up_steps: 5000 + f_start: 1.e-6 + f_min: 1.e-3 + f_max: 1.0 diff --git a/primitive_anything/michelangelo/utils/__init__.py b/primitive_anything/michelangelo/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..76d2dd39781034eaa33293a2243ebee3b3c982c6 --- /dev/null +++ b/primitive_anything/michelangelo/utils/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .misc import get_config_from_file +from .misc import instantiate_from_config diff --git a/primitive_anything/michelangelo/utils/eval.py b/primitive_anything/michelangelo/utils/eval.py new file mode 100755 index 0000000000000000000000000000000000000000..954b9ae2643c8adb6c9af6141ede2b38a329db22 --- /dev/null +++ b/primitive_anything/michelangelo/utils/eval.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +import torch + + +def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7): + + mse = torch.mean((x - y) ** 2) + psnr = 10 * torch.log10(data_range / (mse + eps)) + + return psnr + diff --git a/primitive_anything/michelangelo/utils/io.py b/primitive_anything/michelangelo/utils/io.py new file mode 100755 index 
0000000000000000000000000000000000000000..e651e5a8750ab485b5fbd59a70b38e339b6ed79b --- /dev/null +++ b/primitive_anything/michelangelo/utils/io.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +import os +import io +import tarfile +import json +import numpy as np +import numpy.lib.format + + +def mkdir(path): + os.makedirs(path, exist_ok=True) + return path + + +def npy_loads(data): + stream = io.BytesIO(data) + return np.lib.format.read_array(stream) + + +def npz_loads(data): + return np.load(io.BytesIO(data)) + + +def json_loads(data): + return json.loads(data) + + +def load_json(filepath): + with open(filepath, "r") as f: + data = json.load(f) + return data + + +def write_json(filepath, data): + with open(filepath, "w") as f: + json.dump(data, f, indent=2) + + +def extract_tar(tar_path, tar_cache_folder): + + with tarfile.open(tar_path, "r") as tar: + tar.extractall(path=tar_cache_folder) + + tar_uids = sorted(os.listdir(tar_cache_folder)) + print(f"extract tar: {tar_path} to {tar_cache_folder}") + return tar_uids diff --git a/primitive_anything/michelangelo/utils/misc.py b/primitive_anything/michelangelo/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..bbef357bc7c63d3c7f33d048aec68dda2b0e3992 --- /dev/null +++ b/primitive_anything/michelangelo/utils/misc.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +import importlib +from omegaconf import OmegaConf, DictConfig, ListConfig + +import torch +import torch.distributed as dist +from typing import Union + + +def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]: + config_file = OmegaConf.load(config_file) + + if 'base_config' in config_file.keys(): + if config_file['base_config'] == "default_base": + base_config = OmegaConf.create() + # base_config = get_default_config() + elif config_file['base_config'].endswith(".yaml"): + base_config = get_config_from_file(config_file['base_config']) + else: + raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.") + + config_file = {key: value for key, value in config_file if key != "base_config"} + + return OmegaConf.merge(base_config, config_file) + + return config_file + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def get_obj_from_config(config): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + return get_obj_from_str(config["target"]) + + +def instantiate_from_config(config, **kwargs): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + cls = get_obj_from_str(config["target"]) + + params = config.get("params", dict()) + # params.update(kwargs) + # instance = cls(**params) + kwargs.update(params) + instance = cls(**kwargs) + + return instance + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def all_gather_batch(tensors): + """ + Performs all_gather operation on the provided tensors. 
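# A minimal sketch (assumed usage, mirroring how configure_optimizers in sal_pl_module.py wires its
# optimizer config) of the `target`/`params` convention handled by instantiate_from_config above:
# extra keyword arguments are merged with the config's own `params` before the target class is built.
import torch
from primitive_anything.michelangelo.utils import instantiate_from_config

cfg = {
    "target": "torch.optim.AdamW",
    "params": {"betas": [0.9, 0.99], "eps": 1.0e-6, "weight_decay": 1.0e-2},
}
model = torch.nn.Linear(8, 8)
optimizer = instantiate_from_config(cfg, params=model.parameters(), lr=1e-4)
print(type(optimizer).__name__)  # AdamW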
+ """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather( + tensor_all, + tensor, + async_op=False # performance opt + ) + + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor diff --git a/primitive_anything/michelangelo/utils/visualizers/__init__.py b/primitive_anything/michelangelo/utils/visualizers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/primitive_anything/michelangelo/utils/visualizers/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/primitive_anything/michelangelo/utils/visualizers/color_util.py b/primitive_anything/michelangelo/utils/visualizers/color_util.py new file mode 100755 index 0000000000000000000000000000000000000000..7983243fd37f5fee47bc51475dc58c460a067830 --- /dev/null +++ b/primitive_anything/michelangelo/utils/visualizers/color_util.py @@ -0,0 +1,43 @@ +import numpy as np +import matplotlib.pyplot as plt + + +# Helper functions +def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): + colormap = plt.cm.get_cmap(colormap) + if normalize: + vmin = np.min(inp) + vmax = np.max(inp) + + norm = plt.Normalize(vmin, vmax) + return colormap(norm(inp))[:, :3] + + +def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): + # tex dims need to be power of two. + array = np.ones((width, height, 3), dtype='float32') + + # width in texels of each checker + checker_w = width / n_checkers_x + checker_h = height / n_checkers_y + + for y in range(height): + for x in range(width): + color_key = int(x / checker_w) + int(y / checker_h) + if color_key % 2 == 0: + array[x, y, :] = [1., 0.874, 0.0] + else: + array[x, y, :] = [0., 0., 0.] + return array + + +def gen_circle(width=256, height=256): + xx, yy = np.mgrid[:width, :height] + circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 + array = np.ones((width, height, 4), dtype='float32') + array[:, :, 0] = (circle <= width) + array[:, :, 1] = (circle <= width) + array[:, :, 2] = (circle <= width) + array[:, :, 3] = circle <= width + return array + diff --git a/primitive_anything/michelangelo/utils/visualizers/html_util.py b/primitive_anything/michelangelo/utils/visualizers/html_util.py new file mode 100755 index 0000000000000000000000000000000000000000..f90fe6cfefe6108655b48c36d60db537589993d5 --- /dev/null +++ b/primitive_anything/michelangelo/utils/visualizers/html_util.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import io +import base64 +import numpy as np +from PIL import Image + + +def to_html_frame(content): + + html_frame = f""" + + + {content} + + + """ + + return html_frame + + +def to_single_row_table(caption: str, content: str): + + table_html = f""" + + + + + +
{caption}
{content}
+ """ + + return table_html + + +def to_image_embed_tag(image: np.ndarray): + + # Convert np.ndarray to bytes + img = Image.fromarray(image) + raw_bytes = io.BytesIO() + img.save(raw_bytes, "PNG") + + # Encode bytes to base64 + image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") + + image_tag = f""" + Embedded Image + """ + + return image_tag diff --git a/primitive_anything/michelangelo/utils/visualizers/pythreejs_viewer.py b/primitive_anything/michelangelo/utils/visualizers/pythreejs_viewer.py new file mode 100755 index 0000000000000000000000000000000000000000..b3ce0f88f26fcd5e007fde2cec4816901a74ad33 --- /dev/null +++ b/primitive_anything/michelangelo/utils/visualizers/pythreejs_viewer.py @@ -0,0 +1,534 @@ +import numpy as np +from ipywidgets import embed +import pythreejs as p3s +import uuid + +from .color_util import get_colors, gen_circle, gen_checkers + + +EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js" + + +class PyThreeJSViewer(object): + + def __init__(self, settings, render_mode="WEBSITE"): + self.render_mode = render_mode + self.__update_settings(settings) + self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6) + self._light2 = p3s.AmbientLight(intensity=0.5) + self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"], + aspect=self.__s["width"] / self.__s["height"], children=[self._light]) + self._orbit = p3s.OrbitControls(controlling=self._cam) + self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80" + self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit], + width=self.__s["width"], height=self.__s["height"], + antialias=self.__s["antialias"]) + + self.__objects = {} + self.__cnt = 0 + + def jupyter_mode(self): + self.render_mode = "JUPYTER" + + def offline(self): + self.render_mode = "OFFLINE" + + def website(self): + self.render_mode = "WEBSITE" + + def __get_shading(self, shading): + shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black", + "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None], + "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0, + "line_width": 1.0, "line_color": "black", + "point_color": "red", "point_size": 0.01, "point_shape": "circle", + "text_color": "red" + } + for k in shading: + shad[k] = shading[k] + return shad + + def __update_settings(self, settings={}): + sett = {"width": 600, "height": 600, "antialias": True, "scale": 1.5, "background": "#ffffff", + "fov": 30} + for k in settings: + sett[k] = settings[k] + self.__s = sett + + def __add_object(self, obj, parent=None): + if not parent: # Object is added to global scene and objects dict + self.__objects[self.__cnt] = obj + self.__cnt += 1 + self._scene.add(obj["mesh"]) + else: # Object is added to parent object and NOT to objects dict + parent.add(obj["mesh"]) + + self.__update_view() + + if self.render_mode == "JUPYTER": + return self.__cnt - 1 + elif self.render_mode == "WEBSITE": + return self + + def __add_line_geometry(self, lines, shading, obj=None): + lines = lines.astype("float32", copy=False) + mi = np.min(lines, axis=0) + ma = np.max(lines, axis=0) + + geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3))) + material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"]) + # , vertexColors='VertexColors'), + lines = p3s.LineSegments2(geometry=geometry, 
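# A minimal sketch (assumed usage) of the HTML helpers above: a uint8 RGB array becomes a
# base64-embedded image tag, which can then be wrapped by to_single_row_table / to_html_frame.
# The image content and output path are placeholders.
import numpy as np
from primitive_anything.michelangelo.utils.visualizers.html_util import (
    to_html_frame, to_single_row_table, to_image_embed_tag,
)

image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)   # stand-in for a rendered view
row = to_single_row_table("preview", to_image_embed_tag(image))
with open("preview.html", "w") as f:
    f.write(to_html_frame(row))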
material=material) # type='LinePieces') + line_obj = {"geometry": geometry, "mesh": lines, "material": material, + "max": ma, "min": mi, "type": "Lines", "wireframe": None} + + if obj: + return self.__add_object(line_obj, obj), line_obj + else: + return self.__add_object(line_obj) + + def __update_view(self): + if len(self.__objects) == 0: + return + ma = np.zeros((len(self.__objects), 3)) + mi = np.zeros((len(self.__objects), 3)) + for r, obj in enumerate(self.__objects): + ma[r] = self.__objects[obj]["max"] + mi[r] = self.__objects[obj]["min"] + ma = np.max(ma, axis=0) + mi = np.min(mi, axis=0) + diag = np.linalg.norm(ma - mi) + mean = ((ma - mi) / 2 + mi).tolist() + scale = self.__s["scale"] * (diag) + self._orbit.target = mean + self._cam.lookAt(mean) + self._cam.position = [mean[0], mean[1], mean[2] + scale] + self._light.position = [mean[0], mean[1], mean[2] + scale] + + self._orbit.exec_three_obj_method('update') + self._cam.exec_three_obj_method('updateProjectionMatrix') + + def __get_bbox(self, v): + m = np.min(v, axis=0) + M = np.max(v, axis=0) + + # Corners of the bounding box + v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]], + [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]]) + + f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4], + [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32) + return v_box, f_box + + def __get_colors(self, v, f, c, sh): + coloring = "VertexColors" + if type(c) == np.ndarray and c.size == 3: # Single color + colors = np.ones_like(v) + colors[:, 0] = c[0] + colors[:, 1] = c[1] + colors[:, 2] = c[2] + # print("Single colors") + elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for + if c.shape[0] == f.shape[0]: # faces + colors = np.hstack([c, c, c]).reshape((-1, 3)) + coloring = "FaceColors" + # print("Face color values") + elif c.shape[0] == v.shape[0]: # vertices + colors = c + # print("Vertex color values") + else: # Wrong size, fallback + print("Invalid color array given! Supported are numpy arrays.", type(c)) + colors = np.ones_like(v) + colors[:, 0] = 1.0 + colors[:, 1] = 0.874 + colors[:, 2] = 0.0 + elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + cc = get_colors(c, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + # print(cc.shape) + colors = np.hstack([cc, cc, cc]).reshape((-1, 3)) + coloring = "FaceColors" + # print("Face function values") + elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + colors = get_colors(c, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + # print("Vertex function values") + + else: + colors = np.ones_like(v) + colors[:, 0] = 1.0 + colors[:, 1] = 0.874 + colors[:, 2] = 0.0 + + # No color + if c is not None: + print("Invalid color array given! 
Supported are numpy arrays.", type(c)) + + return colors, coloring + + def __get_point_colors(self, v, c, sh): + v_color = True + if c is None: # No color given, use global color + # conv = mpl.colors.ColorConverter() + colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"])) + v_color = False + elif isinstance(c, str): # No color given, use global color + # conv = mpl.colors.ColorConverter() + colors = c # np.array(conv.to_rgb(c)) + v_color = False + elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3: + # Point color + colors = c.astype("float32", copy=False) + + elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3: + # Function values for vertices, but the colors are features + c_norm = np.linalg.norm(c, ord=2, axis=-1) + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + colors = get_colors(c_norm, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + colors = colors.astype("float32", copy=False) + + elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + colors = get_colors(c, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + colors = colors.astype("float32", copy=False) + # print("Vertex function values") + + else: + print("Invalid color array given! Supported are numpy arrays.", type(c)) + colors = sh["point_color"] + v_color = False + + return colors, v_color + + def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs): + shading.update(kwargs) + sh = self.__get_shading(shading) + mesh_obj = {} + + # it is a tet + if v.shape[1] == 3 and f.shape[1] == 4: + f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype) + for i in range(f.shape[0]): + f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]]) + f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]]) + f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]]) + f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]]) + f = f_tmp + + if v.shape[1] == 2: + v = np.append(v, np.zeros([v.shape[0], 1]), 1) + + # Type adjustment vertices + v = v.astype("float32", copy=False) + + # Color setup + colors, coloring = self.__get_colors(v, f, c, sh) + + # Type adjustment faces and colors + c = colors.astype("float32", copy=False) + + # Material and geometry setup + ba_dict = {"color": p3s.BufferAttribute(c)} + if coloring == "FaceColors": + verts = np.zeros((f.shape[0] * 3, 3), dtype="float32") + for ii in range(f.shape[0]): + # print(ii*3, f[ii]) + verts[ii * 3] = v[f[ii, 0]] + verts[ii * 3 + 1] = v[f[ii, 1]] + verts[ii * 3 + 2] = v[f[ii, 2]] + v = verts + else: + f = f.astype("uint32", copy=False).ravel() + ba_dict["index"] = p3s.BufferAttribute(f, normalized=False) + + ba_dict["position"] = p3s.BufferAttribute(v, normalized=False) + + if uv is not None: + uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv)) + if texture_data is None: + texture_data = gen_checkers(20, 20) + tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType") + material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"], + roughness=sh["roughness"], metalness=sh["metalness"], + flatShading=sh["flat"], + polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5) + ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False)) + else: + material = 
p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"], + side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"], + flatShading=sh["flat"], + polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5) + + if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well + ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True) + + geometry = p3s.BufferGeometry(attributes=ba_dict) + + if coloring == "VertexColors" and type(n) == type(None): + geometry.exec_three_obj_method('computeVertexNormals') + elif coloring == "FaceColors" and type(n) == type(None): + geometry.exec_three_obj_method('computeFaceNormals') + + # Mesh setup + mesh = p3s.Mesh(geometry=geometry, material=material) + + # Wireframe setup + mesh_obj["wireframe"] = None + if sh["wireframe"]: + wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry + wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"]) + wireframe = p3s.LineSegments(wf_geometry, wf_material) + mesh.add(wireframe) + mesh_obj["wireframe"] = wireframe + + # Bounding box setup + if sh["bbox"]: + v_box, f_box = self.__get_bbox(v) + _, bbox = self.add_edges(v_box, f_box, sh, mesh) + mesh_obj["bbox"] = [bbox, v_box, f_box] + + # Object setup + mesh_obj["max"] = np.max(v, axis=0) + mesh_obj["min"] = np.min(v, axis=0) + mesh_obj["geometry"] = geometry + mesh_obj["mesh"] = mesh + mesh_obj["material"] = material + mesh_obj["type"] = "Mesh" + mesh_obj["shading"] = sh + mesh_obj["coloring"] = coloring + mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed + + return self.__add_object(mesh_obj) + + def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs): + shading.update(kwargs) + if len(beginning.shape) == 1: + if len(beginning) == 2: + beginning = np.array([[beginning[0], beginning[1], 0]]) + else: + if beginning.shape[1] == 2: + beginning = np.append( + beginning, np.zeros([beginning.shape[0], 1]), 1) + if len(ending.shape) == 1: + if len(ending) == 2: + ending = np.array([[ending[0], ending[1], 0]]) + else: + if ending.shape[1] == 2: + ending = np.append( + ending, np.zeros([ending.shape[0], 1]), 1) + + sh = self.__get_shading(shading) + lines = np.hstack([beginning, ending]) + lines = lines.reshape((-1, 3)) + return self.__add_line_geometry(lines, sh, obj) + + def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs): + shading.update(kwargs) + if vertices.shape[1] == 2: + vertices = np.append( + vertices, np.zeros([vertices.shape[0], 1]), 1) + sh = self.__get_shading(shading) + lines = np.zeros((edges.size, 3)) + cnt = 0 + for e in edges: + lines[cnt, :] = vertices[e[0]] + lines[cnt + 1, :] = vertices[e[1]] + cnt += 2 + return self.__add_line_geometry(lines, sh, obj) + + def add_points(self, points, c=None, shading={}, obj=None, **kwargs): + shading.update(kwargs) + if len(points.shape) == 1: + if len(points) == 2: + points = np.array([[points[0], points[1], 0]]) + else: + if points.shape[1] == 2: + points = np.append( + points, np.zeros([points.shape[0], 1]), 1) + sh = self.__get_shading(shading) + points = points.astype("float32", copy=False) + mi = np.min(points, axis=0) + ma = np.max(points, axis=0) + + g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)} + m_attributes = {"size": sh["point_size"]} + + if sh["point_shape"] == "circle": # Plot circles + tex = p3s.DataTexture(data=gen_circle(16, 16), 
format="RGBAFormat", type="FloatType") + m_attributes["map"] = tex + m_attributes["alphaTest"] = 0.5 + m_attributes["transparency"] = True + else: # Plot squares + pass + + colors, v_colors = self.__get_point_colors(points, c, sh) + if v_colors: # Colors per point + m_attributes["vertexColors"] = 'VertexColors' + g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False) + + else: # Colors for all points + m_attributes["color"] = colors + + material = p3s.PointsMaterial(**m_attributes) + geometry = p3s.BufferGeometry(attributes=g_attributes) + points = p3s.Points(geometry=geometry, material=material) + point_obj = {"geometry": geometry, "mesh": points, "material": material, + "max": ma, "min": mi, "type": "Points", "wireframe": None} + + if obj: + return self.__add_object(point_obj, obj), point_obj + else: + return self.__add_object(point_obj) + + def remove_object(self, obj_id): + if obj_id not in self.__objects: + print("Invalid object id. Valid ids are: ", list(self.__objects.keys())) + return + self._scene.remove(self.__objects[obj_id]["mesh"]) + del self.__objects[obj_id] + self.__update_view() + + def reset(self): + for obj_id in list(self.__objects.keys()).copy(): + self._scene.remove(self.__objects[obj_id]["mesh"]) + del self.__objects[obj_id] + self.__update_view() + + def update_object(self, oid=0, vertices=None, colors=None, faces=None): + obj = self.__objects[oid] + if type(vertices) != type(None): + if obj["coloring"] == "FaceColors": + f = obj["arrays"][1] + verts = np.zeros((f.shape[0] * 3, 3), dtype="float32") + for ii in range(f.shape[0]): + # print(ii*3, f[ii]) + verts[ii * 3] = vertices[f[ii, 0]] + verts[ii * 3 + 1] = vertices[f[ii, 1]] + verts[ii * 3 + 2] = vertices[f[ii, 2]] + v = verts + + else: + v = vertices.astype("float32", copy=False) + obj["geometry"].attributes["position"].array = v + # self.wireframe.attributes["position"].array = v # Wireframe updates? + obj["geometry"].attributes["position"].needsUpdate = True + # obj["geometry"].exec_three_obj_method('computeVertexNormals') + if type(colors) != type(None): + colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"]) + colors = colors.astype("float32", copy=False) + obj["geometry"].attributes["color"].array = colors + obj["geometry"].attributes["color"].needsUpdate = True + if type(faces) != type(None): + if obj["coloring"] == "FaceColors": + print("Face updates are currently only possible in vertex color mode.") + return + f = faces.astype("uint32", copy=False).ravel() + print(obj["geometry"].attributes) + obj["geometry"].attributes["index"].array = f + # self.wireframe.attributes["position"].array = v # Wireframe updates? 
+ obj["geometry"].attributes["index"].needsUpdate = True + # obj["geometry"].exec_three_obj_method('computeVertexNormals') + # self.mesh.geometry.verticesNeedUpdate = True + # self.mesh.geometry.elementsNeedUpdate = True + # self.update() + if self.render_mode == "WEBSITE": + return self + + # def update(self): + # self.mesh.exec_three_obj_method('update') + # self.orbit.exec_three_obj_method('update') + # self.cam.exec_three_obj_method('updateProjectionMatrix') + # self.scene.exec_three_obj_method('update') + + def add_text(self, text, shading={}, **kwargs): + shading.update(kwargs) + sh = self.__get_shading(shading) + tt = p3s.TextTexture(string=text, color=sh["text_color"]) + sm = p3s.SpriteMaterial(map=tt) + text = p3s.Sprite(material=sm, scaleToTexture=True) + self._scene.add(text) + + # def add_widget(self, widget, callback): + # self.widgets.append(widget) + # widget.observe(callback, names='value') + + # def add_dropdown(self, options, default, desc, cb): + # widget = widgets.Dropdown(options=options, value=default, description=desc) + # self.__widgets.append(widget) + # widget.observe(cb, names="value") + # display(widget) + + # def add_button(self, text, cb): + # button = widgets.Button(description=text) + # self.__widgets.append(button) + # button.on_click(cb) + # display(button) + + def to_html(self, imports=True, html_frame=True): + # Bake positions (fixes centering bug in offline rendering) + if len(self.__objects) == 0: + return + ma = np.zeros((len(self.__objects), 3)) + mi = np.zeros((len(self.__objects), 3)) + for r, obj in enumerate(self.__objects): + ma[r] = self.__objects[obj]["max"] + mi[r] = self.__objects[obj]["min"] + ma = np.max(ma, axis=0) + mi = np.min(mi, axis=0) + diag = np.linalg.norm(ma - mi) + mean = (ma - mi) / 2 + mi + for r, obj in enumerate(self.__objects): + v = self.__objects[obj]["geometry"].attributes["position"].array + v -= mean + v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window + + scale = self.__s["scale"] * (diag) + self._orbit.target = [0.0, 0.0, 0.0] + self._cam.lookAt([0.0, 0.0, 0.0]) + # self._cam.position = [0.0, 0.0, scale] + self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window + self._light.position = [0.0, 0.0, scale] + + state = embed.dependency_state(self._renderer) + + # Somehow these entries are missing when the state is exported in python. + # Exporting from the GUI works, so we are inserting the missing entries. + for k in state: + if state[k]["model_name"] == "OrbitControlsModel": + state[k]["state"]["maxAzimuthAngle"] = "inf" + state[k]["state"]["maxDistance"] = "inf" + state[k]["state"]["maxZoom"] = "inf" + state[k]["state"]["minAzimuthAngle"] = "-inf" + + tpl = embed.load_requirejs_template + if not imports: + embed.load_requirejs_template = "" + + s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL) + # s = embed.embed_snippet(self.__w, state=state) + embed.load_requirejs_template = tpl + + if html_frame: + s = "\n\n" + s + "\n\n" + + # Revert changes + for r, obj in enumerate(self.__objects): + v = self.__objects[obj]["geometry"].attributes["position"].array + v += mean + self.__update_view() + + return s + + def save(self, filename=""): + if filename == "": + uid = str(uuid.uuid4()) + ".html" + else: + filename = filename.replace(".html", "") + uid = filename + '.html' + with open(uid, "w") as f: + f.write(self.to_html()) + print("Plot saved to file %s." 
% uid) diff --git a/primitive_anything/primitive_dataset.py b/primitive_anything/primitive_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..be292a9dd9576be19c5f5e473cf583eb5ecde0f8 --- /dev/null +++ b/primitive_anything/primitive_dataset.py @@ -0,0 +1,157 @@ +import copy +import json +import os + +import numpy as np +from scipy.linalg import polar +from scipy.spatial.transform import Rotation +import open3d as o3d +import torch +from torch.utils.data import Dataset + +from .utils import exists +from .utils.logger import print_log + + +def create_dataset(cfg_dataset): + kwargs = cfg_dataset + name = kwargs.pop('name') + dataset = get_dataset(name)(**kwargs) + print_log(f"Dataset '{name}' init: kwargs={kwargs}, len={len(dataset)}") + return dataset + +def get_dataset(name): + return { + 'base': PrimitiveDataset, + }[name] + + +SHAPE_CODE = { + 'CubeBevel': 0, + 'SphereSharp': 1, + 'CylinderSharp': 2, +} + + +class PrimitiveDataset(Dataset): + def __init__(self, + pc_dir, + bs_dir, + max_length=144, + range_scale=[0, 1], + range_rotation=[-180, 180], + range_translation=[-1, 1], + rotation_type='euler', + pc_format='pc', + ): + self.data_filename = os.listdir(pc_dir) + + self.pc_dir = pc_dir + self.max_length = max_length + self.range_scale = range_scale + self.range_rotation = range_rotation + self.range_translation = range_translation + self.rotation_type = rotation_type + self.pc_format = pc_format + + with open(os.path.join(bs_dir, 'basic_shapes.json'), 'r', encoding='utf-8') as f: + basic_shapes = json.load(f) + + self.typeid_map = { + 1101002001034001: 'CubeBevel', + 1101002001034010: 'SphereSharp', + 1101002001034002: 'CylinderSharp', + } + + def __len__(self): + return len(self.data_filename) + + def __getitem__(self, idx): + pc_file = os.path.join(self.pc_dir, self.data_filename[idx]) + pc = o3d.io.read_point_cloud(pc_file) + + model_data = {} + + points = torch.from_numpy(np.asarray(pc.points)).float() + colors = torch.from_numpy(np.asarray(pc.colors)).float() + normals = torch.from_numpy(np.asarray(pc.normals)).float() + if self.pc_format == 'pc': + model_data['pc'] = torch.concatenate([points, colors], dim=-1).T + elif self.pc_format == 'pn': + model_data['pc'] = torch.concatenate([points, normals], dim=-1) + elif self.pc_format == 'pcn': + model_data['pc'] = torch.concatenate([points, colors, normals], dim=-1) + else: + raise ValueError(f'invalid pc_format: {self.pc_format}') + + return model_data + + +def get_typeid_shapename_mapping(shapenames, config_data): + typeid_map = {} + for info in config_data.values(): + for shapename in shapenames: + if shapename[3:-4] in info['bpPath']: + typeid_map[info['typeId']] = shapename.split('_')[3] + break + return typeid_map + + +def check_valid_range(data, value_range): + lo, hi = value_range + assert hi > lo + return np.logical_and(data >= lo, hi <= hi).all() + + +def quat_to_euler(quat, degree=True): + return Rotation.from_quat(quat).as_euler('XYZ', degrees=degree) + + +def quat_to_rotvec(quat, degree=True): + return Rotation.from_quat(quat).as_rotvec(degrees=degree) + + +def rotate_axis(euler): + trans = np.eye(4, 4) + trans[:3, :3] = Rotation.from_euler('xyz', euler).as_matrix() + return trans + + +def SRT_quat_to_matrix(scale, quat, translation): + rotation_matrix = Rotation.from_quat(quat).as_matrix() + transform_matrix = np.eye(4) + transform_matrix[:3, :3] = rotation_matrix * scale + transform_matrix[:3, 3] = translation + return transform_matrix + + +def matrix_to_SRT_quat2(transform_matrix): # Polar 
Decomposition + transform_matrix = np.array(transform_matrix) + translation = transform_matrix[:3, 3] + rotation_matrix, scale_matrix = polar(transform_matrix[:3,:3]) + quat = Rotation.from_matrix(rotation_matrix).as_quat() + scale = np.diag(scale_matrix) + return scale, quat, translation + + +def apply_transform_to_block(block, trans_aug): + precision_loss = False + trans = SRT_quat_to_matrix( + block['data']['scale'], + block['data']['rotation'], + block['data']['location'] + ) + + trans = trans_aug @ trans + scale, quat, translation = matrix_to_SRT_quat2(trans) + + trans_rec = SRT_quat_to_matrix(scale, quat, translation) + if not np.allclose(trans, trans_rec, atol=1e-1): + precision_loss = True + return precision_loss, {} + + new_block = copy.deepcopy(block) + new_block['data']['scale'] = scale.tolist() + new_block['data']['rotation'] = quat.tolist() + new_block['data']['location'] = translation.tolist() + return precision_loss, new_block diff --git a/primitive_anything/primitive_transformer.py b/primitive_anything/primitive_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..9c3d748db3a643755ec5c66ae36fb257cfbdbf27 --- /dev/null +++ b/primitive_anything/primitive_transformer.py @@ -0,0 +1,948 @@ +from __future__ import annotations + +from functools import partial +from math import ceil +import os + +from accelerate.utils import DistributedDataParallelKwargs +from beartype.typing import Tuple, Callable, List + +from einops import rearrange, repeat, reduce, pack +from gateloop_transformer import SimpleGateLoopLayer +from huggingface_hub import PyTorchModelHubMixin +import numpy as np +import open3d as o3d +from tqdm import tqdm +import torch +from torch import nn, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F +from pytorch3d.loss import chamfer_distance +from pytorch3d.transforms import euler_angles_to_matrix +from x_transformers import Decoder +from x_transformers.x_transformers import LayerIntermediates +from x_transformers.autoregressive_wrapper import eval_decorator + +from .michelangelo import ShapeConditioner as ShapeConditioner_miche +from .utils import ( + discretize, + undiscretize, + set_module_requires_grad_, + default, + exists, + safe_cat, + identity, + is_tensor_empty, +) +from .utils.typing import Float, Int, Bool, typecheck + + +# constants + +DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs( + find_unused_parameters = True +) +SHAPE_CODE = { + 'CubeBevel': 0, + 'SphereSharp': 1, + 'CylinderSharp': 2, +} +BS_NAME = { + 0: 'CubeBevel', + 1: 'SphereSharp', + 2: 'CylinderSharp', +} + +# FiLM block + +class FiLM(Module): + def __init__(self, dim, dim_out = None): + super().__init__() + dim_out = default(dim_out, dim) + + self.to_gamma = nn.Linear(dim, dim_out, bias = False) + self.to_beta = nn.Linear(dim, dim_out) + + self.gamma_mult = nn.Parameter(torch.zeros(1,)) + self.beta_mult = nn.Parameter(torch.zeros(1,)) + + def forward(self, x, cond): + gamma, beta = self.to_gamma(cond), self.to_beta(cond) + gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta)) + + # for initializing to identity + + gamma = (1 + self.gamma_mult * gamma.tanh()) + beta = beta.tanh() * self.beta_mult + + # classic film + + return x * gamma + beta + +# gateloop layers + +class GateLoopBlock(Module): + def __init__( + self, + dim, + *, + depth, + use_heinsen = True + ): + super().__init__() + self.gateloops = ModuleList([]) + + for _ in range(depth): + gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = 
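# A minimal sketch of the FiLM block defined above: a pooled condition vector produces per-channel
# gamma/beta that modulate the token features. Because gamma_mult and beta_mult are initialized to
# zero, a freshly constructed block is exactly the identity map. Shapes are arbitrary.
import torch
from primitive_anything.primitive_transformer import FiLM  # the block defined above

film = FiLM(dim=512)
x = torch.randn(2, 10, 512)          # (batch, num_primitives, dim)
cond = torch.randn(2, 512)           # pooled shape-condition embedding, (batch, dim)
out = film(x, cond)
assert out.shape == x.shape
assert torch.allclose(out, x)        # identity at initialization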
use_heinsen) + self.gateloops.append(gateloop) + + def forward( + self, + x, + cache = None + ): + received_cache = exists(cache) + + if is_tensor_empty(x): + return x, None + + if received_cache: + prev, x = x[:, :-1], x[:, -1:] + + cache = default(cache, []) + cache = iter(cache) + + new_caches = [] + for gateloop in self.gateloops: + layer_cache = next(cache, None) + out, new_cache = gateloop(x, cache = layer_cache, return_cache = True) + new_caches.append(new_cache) + x = x + out + + if received_cache: + x = torch.cat((prev, x), dim = -2) + + return x, new_caches + + +def top_k_2(logits, frac_num_tokens=0.1, k=None): + num_tokens = logits.shape[-1] + + k = default(k, ceil(frac_num_tokens * num_tokens)) + k = min(k, num_tokens) + + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float('-inf')) + probs.scatter_(2, ind, val) + return probs + + +def soft_argmax(labels): + indices = torch.arange(labels.size(-1), dtype=labels.dtype, device=labels.device) + soft_argmax = torch.sum(labels * indices, dim=-1) + return soft_argmax + + +class PrimitiveTransformerDiscrete(Module, PyTorchModelHubMixin): + @typecheck + def __init__( + self, + *, + num_discrete_scale = 128, + continuous_range_scale: List[float, float] = [0, 1], + dim_scale_embed = 64, + num_discrete_rotation = 180, + continuous_range_rotation: List[float, float] = [-180, 180], + dim_rotation_embed = 64, + num_discrete_translation = 128, + continuous_range_translation: List[float, float] = [-1, 1], + dim_translation_embed = 64, + num_type = 3, + dim_type_embed = 64, + embed_order = 'ctrs', + bin_smooth_blur_sigma = 0.4, + dim: int | Tuple[int, int] = 512, + flash_attn = True, + attn_depth = 12, + attn_dim_head = 64, + attn_heads = 16, + attn_kwargs: dict = dict( + ff_glu = True, + attn_num_mem_kv = 4 + ), + max_primitive_len = 144, + dropout = 0., + coarse_pre_gateloop_depth = 2, + coarse_post_gateloop_depth = 0, + coarse_adaptive_rmsnorm = False, + gateloop_use_heinsen = False, + pad_id = -1, + num_sos_tokens = None, + condition_on_shape = True, + shape_cond_with_cross_attn = False, + shape_cond_with_film = False, + shape_cond_with_cat = False, + shape_condition_model_type = 'michelangelo', + shape_condition_len = 1, + shape_condition_dim = None, + cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out shape condition + loss_weight: dict = dict( + eos = 1.0, + type = 1.0, + scale = 1.0, + rotation = 1.0, + translation = 1.0, + reconstruction = 1.0, + scale_huber = 1.0, + rotation_huber = 1.0, + translation_huber = 1.0, + ), + bs_pc_dir=None, + ): + super().__init__() + + # feature embedding + self.num_discrete_scale = num_discrete_scale + self.continuous_range_scale = continuous_range_scale + self.discretize_scale = partial(discretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale) + self.undiscretize_scale = partial(undiscretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale) + self.scale_embed = nn.Embedding(num_discrete_scale, dim_scale_embed) + + self.num_discrete_rotation = num_discrete_rotation + self.continuous_range_rotation = continuous_range_rotation + self.discretize_rotation = partial(discretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation) + self.undiscretize_rotation = partial(undiscretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation) + self.rotation_embed = nn.Embedding(num_discrete_rotation, dim_rotation_embed) + + self.num_discrete_translation = 
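# A small sketch of the top_k_2 filter defined above: it keeps the k largest logits per position
# (k = ceil(frac_num_tokens * vocab) unless k is given) and sets the rest to -inf, so temperature
# sampling can only draw from the kept bins. It is the default filter_logits_fn used during
# generation. Shapes and values are arbitrary.
import torch
from primitive_anything.primitive_transformer import top_k_2  # defined above

logits = torch.randn(1, 1, 128)                       # (batch, seq, num_discrete_bins)
filtered = top_k_2(logits, frac_num_tokens=0.1)       # keeps ceil(0.1 * 128) = 13 logits
assert torch.isfinite(filtered).sum().item() == 13
probs = torch.softmax(filtered, dim=-1)               # -inf bins get zero probability
sample = torch.multinomial(probs.view(-1, 128), 1)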
num_discrete_translation + self.continuous_range_translation = continuous_range_translation + self.discretize_translation = partial(discretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation) + self.undiscretize_translation = partial(undiscretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation) + self.translation_embed = nn.Embedding(num_discrete_translation, dim_translation_embed) + + self.num_type = num_type + self.type_embed = nn.Embedding(num_type, dim_type_embed) + + self.embed_order = embed_order + self.bin_smooth_blur_sigma = bin_smooth_blur_sigma + + # initial dimension + + self.dim = dim + init_dim = 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed + + # project into model dimension + self.project_in = nn.Linear(init_dim, dim) + + num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_shape or not shape_cond_with_film else 4) + assert num_sos_tokens > 0 + + self.num_sos_tokens = num_sos_tokens + self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim)) + + # the transformer eos token + self.eos_token = nn.Parameter(torch.randn(1, dim)) + + self.emb_layernorm = nn.LayerNorm(dim) + self.max_seq_len = max_primitive_len + + # shape condition + + self.condition_on_shape = condition_on_shape + self.shape_cond_with_cross_attn = False + self.shape_cond_with_cat = False + self.shape_condition_model_type = '' + self.conditioner = None + dim_shape = None + + if condition_on_shape: + assert shape_cond_with_cross_attn or shape_cond_with_film or shape_cond_with_cat + self.shape_cond_with_cross_attn = shape_cond_with_cross_attn + self.shape_cond_with_cat = shape_cond_with_cat + self.shape_condition_model_type = shape_condition_model_type + if 'michelangelo' in shape_condition_model_type: + self.conditioner = ShapeConditioner_miche(dim_latent=shape_condition_dim) + self.to_cond_dim = nn.Linear(self.conditioner.dim_model_out * 2, self.conditioner.dim_latent) + self.to_cond_dim_head = nn.Linear(self.conditioner.dim_model_out, self.conditioner.dim_latent) + else: + raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}') + + dim_shape = self.conditioner.dim_latent + set_module_requires_grad_(self.conditioner, False) + + self.shape_coarse_film_cond = FiLM(dim_shape, dim) if shape_cond_with_film else identity + + self.coarse_gateloop_block = GateLoopBlock(dim, depth=coarse_pre_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None + self.coarse_post_gateloop_block = GateLoopBlock(dim, depth=coarse_post_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None + self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm + + self.decoder = Decoder( + dim=dim, + depth=attn_depth, + heads=attn_heads, + attn_dim_head=attn_dim_head, + attn_flash=flash_attn, + attn_dropout=dropout, + ff_dropout=dropout, + use_adaptive_rmsnorm=coarse_adaptive_rmsnorm, + dim_condition=dim_shape, + cross_attend=self.shape_cond_with_cross_attn, + cross_attn_dim_context=dim_shape, + cross_attn_num_mem_kv=cross_attn_num_mem_kv, + **attn_kwargs + ) + + # to logits + self.to_eos_logits = nn.Sequential( + nn.Linear(dim, dim), + nn.ReLU(), + nn.Linear(dim, 1) + ) + self.to_type_logits = nn.Sequential( + nn.Linear(dim, dim), + nn.ReLU(), + nn.Linear(dim, num_type) + ) + self.to_translation_logits = nn.Sequential( + nn.Linear(dim + dim_type_embed, dim), + nn.ReLU(), + nn.Linear(dim, 3 * num_discrete_translation) 
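# note the cascade across the prediction heads: the type head sees only the decoder feature,
# while the translation, rotation and scale heads additionally consume the embeddings of every
# attribute already predicted for the current primitive (see sample_primitives below)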
+ ) + self.to_rotation_logits = nn.Sequential( + nn.Linear(dim + dim_type_embed + 3 * dim_translation_embed, dim), + nn.ReLU(), + nn.Linear(dim, 3 * num_discrete_rotation) + ) + self.to_scale_logits = nn.Sequential( + nn.Linear(dim + dim_type_embed + 3 * (dim_translation_embed + dim_rotation_embed), dim), + nn.ReLU(), + nn.Linear(dim, 3 * num_discrete_scale) + ) + + self.pad_id = pad_id + + bs_pc_map = {} + for bs_name, type_code in SHAPE_CODE.items(): + pc = o3d.io.read_point_cloud(os.path.join(bs_pc_dir, f'SM_GR_BS_{bs_name}_001.ply')) + bs_pc_map[type_code] = torch.from_numpy(np.asarray(pc.points)).float() + bs_pc_list = [] + for i in range(len(bs_pc_map)): + bs_pc_list.append(bs_pc_map[i]) + self.bs_pc = torch.stack(bs_pc_list, dim=0) + + self.rotation_matrix_align_coord = euler_angles_to_matrix( + torch.Tensor([np.pi/2, 0, 0]), 'XYZ').unsqueeze(0).unsqueeze(0) + + @property + def device(self): + return next(self.parameters()).device + + @typecheck + @torch.no_grad() + def embed_pc(self, pc: Tensor): + if 'michelangelo' in self.shape_condition_model_type: + pc_head, pc_embed = self.conditioner(shape=pc) + pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2).detach() + else: + raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}') + + return pc_embed + + @typecheck + def recon_primitives( + self, + scale_logits: Float['b np 3 nd'], + rotation_logits: Float['b np 3 nd'], + translation_logits: Float['b np 3 nd'], + type_logits: Int['b np nd'], + primitive_mask: Bool['b np'] + ): + recon_scale = self.undiscretize_scale(scale_logits.argmax(dim=-1)) + recon_scale = recon_scale.masked_fill(~primitive_mask.unsqueeze(-1), float('nan')) + recon_rotation = self.undiscretize_rotation(rotation_logits.argmax(dim=-1)) + recon_rotation = recon_rotation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan')) + recon_translation = self.undiscretize_translation(translation_logits.argmax(dim=-1)) + recon_translation = recon_translation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan')) + recon_type_code = type_logits.argmax(dim=-1) + recon_type_code = recon_type_code.masked_fill(~primitive_mask, -1) + + return { + 'scale': recon_scale, + 'rotation': recon_rotation, + 'translation': recon_translation, + 'type_code': recon_type_code + } + + @typecheck + def sample_primitives( + self, + scale: Float['b np 3 nd'], + rotation: Float['b np 3 nd'], + translation: Float['b np 3 nd'], + type_code: Int['b np nd'], + next_embed: Float['b 1 nd'], + temperature: float = 1., + filter_logits_fn: Callable = top_k_2, + filter_kwargs: dict = dict() + ): + def sample_func(logits): + if logits.ndim == 4: + enable_squeeze = True + logits = logits.squeeze(1) + else: + enable_squeeze = False + + filtered_logits = filter_logits_fn(logits, **filter_kwargs) + + if temperature == 0.: + sample = filtered_logits.argmax(dim=-1) + else: + probs = F.softmax(filtered_logits / temperature, dim=-1) + + sample = torch.zeros((probs.shape[0], probs.shape[1]), dtype=torch.long, device=probs.device) + for b_i in range(probs.shape[0]): + sample[b_i] = torch.multinomial(probs[b_i], 1).squeeze() + + if enable_squeeze: + sample = sample.unsqueeze(1) + + return sample + + next_type_logits = self.to_type_logits(next_embed) + next_type_code = sample_func(next_type_logits) + type_code_new, _ = pack([type_code, next_type_code], 'b *') + + type_embed = self.type_embed(next_type_code) + next_embed_packed, _ = pack([next_embed, type_embed], 'b np *') + next_translation_logits = 
rearrange(self.to_translation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_translation) + next_discretize_translation = sample_func(next_translation_logits) + next_translation = self.undiscretize_translation(next_discretize_translation) + translation_new, _ = pack([translation, next_translation], 'b * nd') + + next_translation_embed = self.translation_embed(next_discretize_translation) + next_embed_packed, _ = pack([next_embed_packed, next_translation_embed], 'b np *') + next_rotation_logits = rearrange(self.to_rotation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_rotation) + next_discretize_rotation = sample_func(next_rotation_logits) + next_rotation = self.undiscretize_rotation(next_discretize_rotation) + rotation_new, _ = pack([rotation, next_rotation], 'b * nd') + + next_rotation_embed = self.rotation_embed(next_discretize_rotation) + next_embed_packed, _ = pack([next_embed_packed, next_rotation_embed], 'b np *') + next_scale_logits = rearrange(self.to_scale_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_scale) + next_discretize_scale = sample_func(next_scale_logits) + next_scale = self.undiscretize_scale(next_discretize_scale) + scale_new, _ = pack([scale, next_scale], 'b * nd') + + return ( + scale_new, + rotation_new, + translation_new, + type_code_new + ) + + @eval_decorator + @torch.no_grad() + @typecheck + def generate( + self, + batch_size: int | None = None, + filter_logits_fn: Callable = top_k_2, + filter_kwargs: dict = dict(), + temperature: float = 1., + scale: Float['b np 3'] | None = None, + rotation: Float['b np 3'] | None = None, + translation: Float['b np 3'] | None = None, + type_code: Int['b np'] | None = None, + pc: Tensor | None = None, + pc_embed: Tensor | None = None, + cache_kv = True, + max_seq_len = None, + ): + max_seq_len = default(max_seq_len, self.max_seq_len) + + if exists(scale) and exists(rotation) and exists(translation) and exists(type_code): + assert not exists(batch_size) + assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1] + assert scale.shape[1] <= self.max_seq_len + + batch_size = scale.shape[0] + + if self.condition_on_shape: + assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in' + if exists(pc): + pc_embed = self.embed_pc(pc) + + batch_size = default(batch_size, pc_embed.shape[0]) + + batch_size = default(batch_size, 1) + + scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device)) + rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device)) + translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device)) + type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device)) + + curr_length = scale.shape[1] + + cache = None + eos_codes = None + + for i in tqdm(range(curr_length, max_seq_len)): + can_eos = i != 0 + output = self.forward( + scale=scale, + rotation=rotation, + translation=translation, + type_code=type_code, + pc_embed=pc_embed, + return_loss=False, + return_cache=cache_kv, + append_eos=False, + cache=cache + ) + if cache_kv: + next_embed, cache = output + else: + next_embed = output + ( + scale, + rotation, + translation, + type_code + ) = self.sample_primitives( + scale, + rotation, + translation, + type_code, + next_embed, + temperature=temperature, + filter_logits_fn=filter_logits_fn, + filter_kwargs=filter_kwargs + ) + + 
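# each loop iteration above appends exactly one primitive: sample_primitives predicts its type,
# translation, rotation and scale in that order, each conditioned on the attributes already
# sampled for this primitive, and the eos head below then decides whether to stop the sequence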
next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1) + next_eos_code = (F.sigmoid(next_eos_logits) > 0.5) + eos_codes = safe_cat([eos_codes, next_eos_code], 1) + if can_eos and eos_codes.any(dim=-1).all(): + break + + # mask out to padding anything after the first eos + mask = eos_codes.float().cumsum(dim=-1) >= 1 + + # concat cur_length to mask + mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1) + type_code = type_code.masked_fill(mask, self.pad_id) + scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id) + rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id) + translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id) + + recon_primitives = { + 'scale': scale, + 'rotation': rotation, + 'translation': translation, + 'type_code': type_code + } + primitive_mask = ~eos_codes + + return recon_primitives, primitive_mask + + + @eval_decorator + @torch.no_grad() + @typecheck + def generate_w_recon_loss( + self, + batch_size: int | None = None, + filter_logits_fn: Callable = top_k_2, + filter_kwargs: dict = dict(), + temperature: float = 1., + scale: Float['b np 3'] | None = None, + rotation: Float['b np 3'] | None = None, + translation: Float['b np 3'] | None = None, + type_code: Int['b np'] | None = None, + pc: Tensor | None = None, + pc_embed: Tensor | None = None, + cache_kv = True, + max_seq_len = None, + single_directional = True, + ): + max_seq_len = default(max_seq_len, self.max_seq_len) + + if exists(scale) and exists(rotation) and exists(translation) and exists(type_code): + assert not exists(batch_size) + assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1] + assert scale.shape[1] <= self.max_seq_len + + batch_size = scale.shape[0] + + if self.condition_on_shape: + assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in' + if exists(pc): + pc_embed = self.embed_pc(pc) + + batch_size = default(batch_size, pc_embed.shape[0]) + + batch_size = default(batch_size, 1) + assert batch_size == 1 # TODO: support any batch size + + scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device)) + rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device)) + translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device)) + type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device)) + + curr_length = scale.shape[1] + + cache = None + eos_codes = None + last_recon_loss = 1 + for i in tqdm(range(curr_length, max_seq_len)): + can_eos = i != 0 + output = self.forward( + scale=scale, + rotation=rotation, + translation=translation, + type_code=type_code, + pc_embed=pc_embed, + return_loss=False, + return_cache=cache_kv, + append_eos=False, + cache=cache + ) + if cache_kv: + next_embed, cache = output + else: + next_embed = output + ( + scale_new, + rotation_new, + translation_new, + type_code_new + ) = self.sample_primitives( + scale, + rotation, + translation, + type_code, + next_embed, + temperature=temperature, + filter_logits_fn=filter_logits_fn, + filter_kwargs=filter_kwargs + ) + + next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1) + next_eos_code = (F.sigmoid(next_eos_logits) > 0.5) + eos_codes = safe_cat([eos_codes, next_eos_code], 1) + if can_eos and eos_codes.any(dim=-1).all(): + scale, rotation, translation, type_code = ( + scale_new, rotation_new, translation_new, 
type_code_new) + break + + recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc, single_directional) + if recon_loss < last_recon_loss: + last_recon_loss = recon_loss + scale, rotation, translation, type_code = ( + scale_new, rotation_new, translation_new, type_code_new) + else: + best_recon_loss = recon_loss + best_primitives = dict( + scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new) + success_flag = False + print(f'last_recon_loss:{last_recon_loss}, recon_loss:{recon_loss} -> to find better primitive') + for try_i in range(5): + ( + scale_new, + rotation_new, + translation_new, + type_code_new + ) = self.sample_primitives( + scale, + rotation, + translation, + type_code, + next_embed, + temperature=1.0, + filter_logits_fn=filter_logits_fn, + filter_kwargs=filter_kwargs + ) + recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc) + print(f'[try_{try_i}] last_recon_loss:{last_recon_loss}, best_recon_loss:{best_recon_loss}, cur_recon_loss:{recon_loss}') + if recon_loss < last_recon_loss: + last_recon_loss = recon_loss + scale, rotation, translation, type_code = ( + scale_new, rotation_new, translation_new, type_code_new) + success_flag = True + break + else: + if recon_loss < best_recon_loss: + best_recon_loss = recon_loss + best_primitives = dict( + scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new) + + if not success_flag: + last_recon_loss = best_recon_loss + scale, rotation, translation, type_code = ( + best_primitives['scale'], best_primitives['rotation'], best_primitives['translation'], best_primitives['type_code']) + print(f'new_last_recon_loss:{last_recon_loss}') + + # mask out to padding anything after the first eos + mask = eos_codes.float().cumsum(dim=-1) >= 1 + type_code = type_code.masked_fill(mask, self.pad_id) + scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id) + rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id) + translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id) + + recon_primitives = { + 'scale': scale, + 'rotation': rotation, + 'translation': translation, + 'type_code': type_code + } + primitive_mask = ~eos_codes + + return recon_primitives, primitive_mask + + + @typecheck + def encode( + self, + *, + scale: Float['b np 3'], + rotation: Float['b np 3'], + translation: Float['b np 3'], + type_code: Int['b np'], + primitive_mask: Bool['b np'], + return_primitives = False + ): + """ + einops: + b - batch + np - number of primitives + c - coordinates (3) + d - embed dim + """ + + # compute feature embedding + discretize_scale = self.discretize_scale(scale) + scale_embed = self.scale_embed(discretize_scale) + scale_embed = rearrange(scale_embed, 'b np c d -> b np (c d)') + + discretize_rotation = self.discretize_rotation(rotation) + rotation_embed = self.rotation_embed(discretize_rotation) + rotation_embed = rearrange(rotation_embed, 'b np c d -> b np (c d)') + + discretize_translation = self.discretize_translation(translation) + translation_embed = self.translation_embed(discretize_translation) + translation_embed = rearrange(translation_embed, 'b np c d -> b np (c d)') + + type_embed = self.type_embed(type_code.masked_fill(~primitive_mask, 0)) + + # combine all features and project into model dimension + if self.embed_order == 'srtc': + primitive_embed, _ = pack([scale_embed, rotation_embed, translation_embed, type_embed], 
'b np *') + else: + primitive_embed, _ = pack([type_embed, translation_embed, rotation_embed, scale_embed], 'b np *') + + primitive_embed = self.project_in(primitive_embed) + primitive_embed = primitive_embed.masked_fill(~primitive_mask.unsqueeze(-1), 0.) + + if not return_primitives: + return primitive_embed + + primitive_embed_unpacked = { + 'scale': scale_embed, + 'rotation': rotation_embed, + 'translation': translation_embed, + 'type_code': type_embed + } + + primitives_gt = { + 'scale': discretize_scale, + 'rotation': discretize_rotation, + 'translation': discretize_translation, + 'type_code': type_code + } + + return primitive_embed, primitive_embed_unpacked, primitives_gt + + @typecheck + def compute_chamfer_distance( + self, + scale_pred: Float['b np 3'], + rotation_pred: Float['b np 3'], + translation_pred: Float['b np 3'], + type_pred: Int['b np'], + primitive_mask: Bool['b np'], + pc: Tensor, # b, num_points, c + single_directional = True + ): + scale_pred = scale_pred.float() + rotation_pred = rotation_pred.float() + translation_pred = translation_pred.float() + + pc_pred = apply_transformation(self.bs_pc.to(type_pred.device)[type_pred], scale_pred, torch.deg2rad(rotation_pred), translation_pred) + pc_pred = torch.matmul(pc_pred, self.rotation_matrix_align_coord.to(type_pred.device)) + pc_pred_flat = rearrange(pc_pred, 'b np p c -> b (np p) c') + pc_pred_sampled = random_sample_pc(pc_pred_flat, primitive_mask.sum(dim=-1, keepdim=True), n_points=self.bs_pc.shape[1]) + + if single_directional: + recon_loss, _ = chamfer_distance(pc[:, :, :3].float(), pc_pred_sampled.float(), single_directional=True) # single directional + else: + recon_loss, _ = chamfer_distance(pc_pred_sampled.float(), pc[:, :, :3].float()) + + return recon_loss + + def forward( + self, + *, + scale: Float['b np 3'], + rotation: Float['b np 3'], + translation: Float['b np 3'], + type_code: Int['b np'], + loss_reduction: str = 'mean', + return_cache = False, + append_eos = True, + cache: LayerIntermediates | None = None, + pc: Tensor | None = None, + pc_embed: Tensor | None = None, + **kwargs + ): + + primitive_mask = reduce(scale != self.pad_id, 'b np 3 -> b np', 'all') + + if scale.shape[1] > 0: + codes, primitives_embeds, primitives_gt = self.encode( + scale=scale, + rotation=rotation, + translation=translation, + type_code=type_code, + primitive_mask=primitive_mask, + return_primitives=True + ) + else: + codes = torch.empty((scale.shape[0], 0, self.dim), dtype=torch.float32, device=self.device) + + # handle shape conditions + + attn_context_kwargs = dict() + + if self.condition_on_shape: + assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in' + + if exists(pc): + if 'michelangelo' in self.shape_condition_model_type: + pc_head, pc_embed = self.conditioner(shape=pc) + pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2) + else: + raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}') + + assert pc_embed.shape[0] == codes.shape[0], 'batch size of point cloud is not equal to the batch size of the primitive codes' + + pooled_pc_embed = pc_embed.mean(dim=1) # (b, shape_condition_dim) + + if self.shape_cond_with_cross_attn: + attn_context_kwargs = dict( + context=pc_embed + ) + + if self.coarse_adaptive_rmsnorm: + attn_context_kwargs.update( + condition=pooled_pc_embed + ) + + batch, seq_len, _ = codes.shape # (b, np, dim) + device = codes.device + assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} 
but needs to be less than or equal to set max_seq_len {self.max_seq_len}' + + if append_eos: + assert exists(codes) + code_lens = primitive_mask.sum(dim=-1) + codes = pad_tensor(codes) + + batch_arange = torch.arange(batch, device=device) + batch_arange = rearrange(batch_arange, '... -> ... 1') + code_lens = rearrange(code_lens, '... -> ... 1') + codes[batch_arange, code_lens] = self.eos_token # (b, np+1, dim) + + primitive_codes = codes # (b, np, dim) + + primitive_codes_len = primitive_codes.shape[-2] + + ( + coarse_cache, + coarse_gateloop_cache, + coarse_post_gateloop_cache, + ) = cache if exists(cache) else ((None,) * 3) + + if not exists(cache): + sos = repeat(self.sos_token, 'n d -> b n d', b=batch) + + if self.shape_cond_with_cat: + sos, _ = pack([pc_embed, sos], 'b * d') + primitive_codes, packed_sos_shape = pack([sos, primitive_codes], 'b * d') # (b, n_sos+np, dim) + + # condition primitive codes with shape if needed + if self.condition_on_shape: + primitive_codes = self.shape_coarse_film_cond(primitive_codes, pooled_pc_embed) + + # attention on primitive codes (coarse) + + if exists(self.coarse_gateloop_block): + primitive_codes, coarse_gateloop_cache = self.coarse_gateloop_block(primitive_codes, cache=coarse_gateloop_cache) + + attended_primitive_codes, coarse_cache = self.decoder( # (b, n_sos+np, dim) + primitive_codes, + cache=coarse_cache, + return_hiddens=True, + **attn_context_kwargs + ) + + if exists(self.coarse_post_gateloop_block): + primitive_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(primitive_codes, cache=coarse_post_gateloop_cache) + + embed = attended_primitive_codes[:, -(primitive_codes_len + 1):] # (b, np+1, dim) + + if not return_cache: + return embed[:, -1:] + + next_cache = ( + coarse_cache, + coarse_gateloop_cache, + coarse_post_gateloop_cache + ) + + return embed[:, -1:], next_cache + + +def pad_tensor(tensor): + if tensor.dim() == 3: + bs, seq_len, dim = tensor.shape + padding = torch.zeros((bs, 1, dim), dtype=tensor.dtype, device=tensor.device) + elif tensor.dim() == 2: + bs, seq_len = tensor.shape + padding = torch.zeros((bs, 1), dtype=tensor.dtype, device=tensor.device) + else: + raise ValueError('Unsupported tensor shape: {}'.format(tensor.shape)) + + return torch.cat([tensor, padding], dim=1) + + +def apply_transformation(pc, scale, rotation_vector, translation): + bs, np, num_points, _ = pc.shape + scaled_pc = pc * scale.unsqueeze(2) + + rotation_matrix = euler_angles_to_matrix(rotation_vector.view(-1, 3), 'XYZ').view(bs, np, 3, 3) # euler tmp + rotated_pc = torch.einsum('bnij,bnpj->bnpi', rotation_matrix, scaled_pc) + + transformed_pc = rotated_pc + translation.unsqueeze(2) + + return transformed_pc + + +def random_sample_pc(pc, max_lens, n_points=10000): + bs = max_lens.shape[0] + max_len = max_lens.max().item() * n_points + + random_values = torch.rand(bs, max_len, device=max_lens.device) + mask = torch.arange(max_len).expand(bs, max_len).to(max_lens.device) < (max_lens * n_points) + masked_random_values = random_values * mask.float() + _, indices = torch.topk(masked_random_values, n_points, dim=1) + + return pc[torch.arange(bs).unsqueeze(1), indices] \ No newline at end of file diff --git a/primitive_anything/utils/__init__.py b/primitive_anything/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..778612545fe486e30d41cb17cb1bf107421b2d57 --- /dev/null +++ b/primitive_anything/utils/__init__.py @@ -0,0 +1,275 @@ +from math import ceil +from pathlib import Path +import os +import re + 
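+# General-purpose helpers shared across primitive_anything: YAML config loading
+# and recursive merging, latest-checkpoint discovery, device transfer, and the
+# tensor discretization/undiscretization utilities defined further below.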
+from beartype.typing import Tuple +from einops import rearrange, repeat +from toolz import valmap +import torch +from torch import Tensor +from torch.nn import Module +import torch.nn.functional as F +import yaml + +from .typing import typecheck + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def path_mkdir(path): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + return path + + +def load_yaml(path, default_path=None): + path = path_exists(path) + with open(path, mode='r') as fp: + cfg_s = yaml.load(fp, Loader=yaml.FullLoader) + + if default_path is not None: + default_path = path_exists(default_path) + with open(default_path, mode='r') as fp: + cfg = yaml.load(fp, Loader=yaml.FullLoader) + else: + # try current dir default + default_path = path.parent / 'default.yml' + if default_path.exists(): + with open(default_path, mode='r') as fp: + cfg = yaml.load(fp, Loader=yaml.FullLoader) + else: + cfg = {} + + update_recursive(cfg, cfg_s) + return cfg + + +def dump_yaml(cfg, path): + with open(path, mode='w') as f: + return yaml.safe_dump(cfg, f) + + +def update_recursive(dict1, dict2): + ''' Update two config dictionaries recursively. + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary which entries should be used + ''' + for k, v in dict2.items(): + if k not in dict1: + dict1[k] = dict() + if isinstance(v, dict): + update_recursive(dict1[k], v) + else: + dict1[k] = v + + +def load_latest_checkpoint(checkpoint_dir): + pattern = re.compile(rf".+\.ckpt\.(\d+)\.pt") + max_epoch = -1 + latest_checkpoint = None + + for filename in os.listdir(checkpoint_dir): + match = pattern.match(filename) + if match: + num_epoch = int(match.group(1)) + if num_epoch > max_epoch: + max_epoch = num_epoch + latest_checkpoint = checkpoint_dir / filename + + if not exists(latest_checkpoint): + raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}") + + checkpoint = torch.load(latest_checkpoint) + return checkpoint, latest_checkpoint + + +def torch_to(inp, device, non_blocking=False): + nb = non_blocking # set to True when doing distributed jobs + if isinstance(inp, torch.Tensor): + return inp.to(device, non_blocking=nb) + elif isinstance(inp, (list, tuple)): + return type(inp)(map(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp)) + elif isinstance(inp, dict): + return valmap(lambda t: t.to(device, non_blocking=nb) if isinstance(t, torch.Tensor) else t, inp) + else: + raise NotImplementedError + + +# helper functions + +def exists(v): + return v is not None + +def default(v, d): + return v if exists(v) else d + +def first(it): + return it[0] + +def identity(t, *args, **kwargs): + return t + +def divisible_by(num, den): + return (num % den) == 0 + +def is_odd(n): + return not divisible_by(n, 2) + +def is_empty(x): + return len(x) == 0 + +def is_tensor_empty(t: Tensor): + return t.numel() == 0 + +def set_module_requires_grad_( + module: Module, + requires_grad: bool +): + for param in module.parameters(): + param.requires_grad = requires_grad + +def l1norm(t): + return F.normalize(t, dim = -1, p = 1) + +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + +def safe_cat(tensors, dim): + tensors = [*filter(exists, tensors)] + + if len(tensors) == 0: + return None + elif len(tensors) == 1: + return first(tensors) + + return torch.cat(tensors, dim = dim) + +def pad_at_dim(t, padding, dim = -1, value = 0): + ndim = t.ndim + right_dims = (ndim - dim - 1) if 
dim >= 0 else (-dim - 1) + zeros = (0, 0) * right_dims + return F.pad(t, (*zeros, *padding), value = value) + +def pad_to_length(t, length, dim = -1, value = 0, right = True): + curr_length = t.shape[dim] + remainder = length - curr_length + + if remainder <= 0: + return t + + padding = (0, remainder) if right else (remainder, 0) + return pad_at_dim(t, padding, dim = dim, value = value) + +def masked_mean(tensor, mask, dim = -1, eps = 1e-5): + if not exists(mask): + return tensor.mean(dim = dim) + + mask = rearrange(mask, '... -> ... 1') + tensor = tensor.masked_fill(~mask, 0.) + + total_el = mask.sum(dim = dim) + num = tensor.sum(dim = dim) + den = total_el.float().clamp(min = eps) + mean = num / den + mean = mean.masked_fill(total_el == 0, 0.) + return mean + +def cycle(dl): + while True: + for data in dl: + yield data + +def maybe_del(d: dict, *keys): + for key in keys: + if key not in d: + continue + + del d[key] + + +# tensor helper functions + +@typecheck +def discretize( + t: Tensor, + *, + continuous_range: Tuple[float, float], + num_discrete: int = 128 +) -> Tensor: + lo, hi = continuous_range + assert hi > lo + + t = (t - lo) / (hi - lo) + t *= num_discrete + t -= 0.5 + + return t.round().long().clamp(min = 0, max = num_discrete - 1) + +@typecheck +def undiscretize( + t: Tensor, + *, + continuous_range = Tuple[float, float], + num_discrete: int = 128 +) -> Tensor: + lo, hi = continuous_range + assert hi > lo + + t = t.float() + + t += 0.5 + t /= num_discrete + return t * (hi - lo) + lo + +@typecheck +def gaussian_blur_1d( + t: Tensor, + *, + sigma: float = 1., + kernel_size: int = 5 +) -> Tensor: + + _, _, channels, device, dtype = *t.shape, t.device, t.dtype + + width = int(ceil(sigma * kernel_size)) + width += (width + 1) % 2 + half_width = width // 2 + + distance = torch.arange(-half_width, half_width + 1, dtype = dtype, device = device) + + gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2)) + gaussian = l1norm(gaussian) + + kernel = repeat(gaussian, 'n -> c 1 n', c = channels) + + t = rearrange(t, 'b n c -> b c n') + out = F.conv1d(t, kernel, padding = half_width, groups = channels) + return rearrange(out, 'b c n -> b n c') + +@typecheck +def scatter_mean( + tgt: Tensor, + indices: Tensor, + src = Tensor, + *, + dim: int = -1, + eps: float = 1e-5 +): + """ + todo: update to pytorch 2.1 and try https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_ + """ + num = tgt.scatter_add(dim, indices, src) + den = torch.zeros_like(tgt).scatter_add(dim, indices, torch.ones_like(src)) + return num / den.clamp(min = eps) + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device = device, dtype = torch.bool) + elif prob == 0: + return torch.zeros(shape, device = device, dtype = torch.bool) + else: + return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob \ No newline at end of file diff --git a/primitive_anything/utils/logger.py b/primitive_anything/utils/logger.py new file mode 100755 index 0000000000000000000000000000000000000000..cf5ee29d21cfbf2e6331f0151528bcc29d615291 --- /dev/null +++ b/primitive_anything/utils/logger.py @@ -0,0 +1,65 @@ +import logging +import time +import os + + +class Verbose: + mute = False + + +def print_log(s, logger=None, level='info'): + if Verbose.mute: + return None + + if logger is None: + logger = logging.getLogger('trainer') + if level == 'info': + print_info(s) + logger.info(s) + elif level == 'warning': + print_warning(s) + logger.warning(s) + 
    elif level == 'error':
+        print_error(s)
+        logger.error(s)
+    else:
+        raise NotImplementedError
+
+
+def create_logger(log_dir, name='trainer'):
+    assert os.path.exists(log_dir), 'log_dir {} does not exist.'.format(log_dir)
+    logger = logging.getLogger(name)
+    file_path = log_dir / '{}.log'.format(name)
+    hdlr = logging.FileHandler(file_path)
+    formatter = logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s')
+    hdlr.setFormatter(formatter)
+    logger.addHandler(hdlr)
+    logger.setLevel(logging.INFO)
+    return logger
+
+
+class TerminalColors:
+    HEADER = '\033[95m'
+    OKBLUE = '\033[94m'
+    OKGREEN = '\033[92m'
+    WARNING = '\033[93m'
+    FAIL = '\033[91m'
+    ENDC = '\033[0m'
+    BOLD = '\033[1m'
+    UNDERLINE = '\033[4m'
+
+
+def get_time():
+    return time.strftime('%Y-%m-%d %H:%M:%S')
+
+
+def print_info(s):
+    print(TerminalColors.OKBLUE + '[' + get_time() + '] ' + str(s) + TerminalColors.ENDC)
+
+
+def print_warning(s):
+    print(TerminalColors.WARNING + '[' + get_time() + '] WARN ' + str(s) + TerminalColors.ENDC)
+
+
+def print_error(s):
+    print(TerminalColors.FAIL + '[' + get_time() + '] ERROR ' + str(s) + TerminalColors.ENDC)
diff --git a/primitive_anything/utils/typing.py b/primitive_anything/utils/typing.py
new file mode 100755
index 0000000000000000000000000000000000000000..a92116b431fcb97796d9ced7a2c4c2065ba05964
--- /dev/null
+++ b/primitive_anything/utils/typing.py
@@ -0,0 +1,57 @@
+from environs import Env
+
+from torch import Tensor
+
+from beartype import beartype
+from beartype.door import is_bearable
+
+from jaxtyping import (
+    Float,
+    Int,
+    Bool,
+    jaxtyped
+)
+
+# environment
+
+env = Env()
+env.read_env()
+
+# function
+
+def always(value):
+    def inner(*args, **kwargs):
+        return value
+    return inner
+
+def identity(t):
+    return t
+
+# jaxtyping is a misnomer, works for pytorch
+
+class TorchTyping:
+    def __init__(self, abstract_dtype):
+        self.abstract_dtype = abstract_dtype
+
+    def __getitem__(self, shapes: str):
+        return self.abstract_dtype[Tensor, shapes]
+
+Float = TorchTyping(Float)
+Int = TorchTyping(Int)
+Bool = TorchTyping(Bool)
+
+# use env variable TYPECHECK to control whether to use beartype + jaxtyping
+
+should_typecheck = env.bool('TYPECHECK', False)
+
+typecheck = jaxtyped(typechecker = beartype) if should_typecheck else identity
+
+beartype_isinstance = is_bearable if should_typecheck else always(True)
+
+__all__ = [
+    'Float',
+    'Int',
+    'Bool',
+    'typecheck',
+    'beartype_isinstance'
+]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..500f927b30797b396a9e399efecc56394eb08b1f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+ninja
+pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/pytorch3d-0.7.6-cp310-cp310-linux_x86_64.whl
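+# Note: the pinned pytorch3d wheel above is built for Python 3.10 / CUDA 12.1 /
+# PyTorch 2.2.1 (py310_cu121_pyt221); on a different stack a matching wheel or a
+# source build of pytorch3d would be needed.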