diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4705d431dc4e2ef746899bd678977c5133f1ab81 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: MIDI 3D +emoji: 📚 +colorFrom: purple +colorTo: red +sdk: gradio +sdk_version: 4.44.1 +app_file: app.py +pinned: false +license: apache-2.0 +short_description: Image to Compositional 3D Scene Generation +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..86d5dcd035764fd2ae58d3a4cd68e01a0184ee2b --- /dev/null +++ b/app.py @@ -0,0 +1,334 @@ +import json +import os +import random +import tempfile +from typing import Any, List, Union + +import gradio as gr +import numpy as np +import spaces +import torch +import trimesh +from gradio_image_prompter import ImagePrompter +from gradio_litmodel3d import LitModel3D +from huggingface_hub import snapshot_download +from PIL import Image +from skimage import measure +from transformers import AutoModelForMaskGeneration, AutoProcessor + +from midi.pipelines.pipeline_midi import MIDIPipeline +from midi.utils.smoothing import smooth_gpu +from scripts.grounding_sam import plot_segmentation, segment +from scripts.inference_midi import preprocess_image, split_rgb_mask + +# Constants +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp") +DTYPE = torch.bfloat16 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +REPO_ID = "VAST-AI/MIDI-3D" + +MARKDOWN = """ +## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/) +Important! 
Please check out our [instruction video](https://github.com/user-attachments/assets/4fc8aea4-010f-40c7-989d-6b1d9d3e3e09)! +1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then clik "Run Segmentation" to generate the segmentation result. Ensure instances should not be too small and bounding boxes fit snugly around each instance. +2. Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border. Then click "Run Generation" to generate a 3D scene from the image and segmentation result. +3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button. +""" + +EXAMPLES = [ + [ + { + "image": "assets/example_data/Cartoon-Style/03_rgb.png", + }, + "assets/example_data/Cartoon-Style/03_seg.png", + 42, + False, + False, + ], + [ + { + "image": "assets/example_data/Cartoon-Style/01_rgb.png", + }, + "assets/example_data/Cartoon-Style/01_seg.png", + 42, + False, + False, + ], + [ + { + "image": "assets/example_data/Realistic-Style/02_rgb.png", + }, + "assets/example_data/Realistic-Style/02_seg.png", + 42, + False, + False, + ], + [ + { + "image": "assets/example_data/Cartoon-Style/00_rgb.png", + }, + "assets/example_data/Cartoon-Style/00_seg.png", + 42, + False, + False, + ], + [ + { + "image": "assets/example_data/Realistic-Style/00_rgb.png", + }, + "assets/example_data/Realistic-Style/00_seg.png", + 42, + False, + True, + ], + [ + { + "image": "assets/example_data/Realistic-Style/01_rgb.png", + }, + "assets/example_data/Realistic-Style/01_seg.png", + 42, + False, + True, + ], + [ + { + "image": "assets/example_data/Realistic-Style/05_rgb.png", + }, + "assets/example_data/Realistic-Style/05_seg.png", + 42, + False, + False, + ], +] + +os.makedirs(TMP_DIR, exist_ok=True) + +# Prepare models +## Grounding SAM +segmenter_id = "facebook/sam-vit-base" +sam_processor = AutoProcessor.from_pretrained(segmenter_id) +sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to( + DEVICE, DTYPE +) +## MIDI-3D +local_dir = "pretrained_weights/MIDI-3D" +snapshot_download(repo_id=REPO_ID, local_dir=local_dir) +pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE) +pipe.init_custom_adapter( + set_self_attn_module_names=[ + "blocks.8", + "blocks.9", + "blocks.10", + "blocks.11", + "blocks.12", + ] +) + + +# Utils +def get_random_hex(): + random_bytes = os.urandom(8) + random_hex = random_bytes.hex() + return random_hex + + +@spaces.GPU() +@torch.no_grad() +@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16) +def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image: + rgb_image = image_prompts["image"].convert("RGB") + + # pre-process the layers and get the xyxy boxes of each layer + if len(image_prompts["points"]) == 0: + gr.Error("Please draw bounding boxes for each instance on the image.") + boxes = [ + [ + [int(box[0]), int(box[1]), int(box[3]), int(box[4])] + for box in image_prompts["points"] + ] + ] + + # run the segmentation + detections = segment( + sam_processor, + sam_segmentator, + rgb_image, + boxes=[boxes], + polygon_refinement=polygon_refinement, + ) + seg_map_pil = plot_segmentation(rgb_image, detections) + + torch.cuda.empty_cache() + + return seg_map_pil + + +@torch.no_grad() +def run_midi( + pipe: Any, + rgb_image: Union[str, Image.Image], + seg_image: Union[str, Image.Image], + seed: int, + num_inference_steps: int = 50, + guidance_scale: float = 7.0, + do_image_padding: bool = False, +) -> 
trimesh.Scene: + if do_image_padding: + rgb_image, seg_image = preprocess_image(rgb_image, seg_image) + instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image) + + num_instances = len(instance_rgbs) + outputs = pipe( + image=instance_rgbs, + mask=instance_masks, + image_scene=scene_rgbs, + attention_kwargs={"num_instances": num_instances}, + generator=torch.Generator(device=pipe.device).manual_seed(seed), + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + decode_progressive=True, + return_dict=False, + ) + + return outputs + + +@spaces.GPU(duration=300) +@torch.no_grad() +@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16) +def run_generation( + rgb_image: Any, + seg_image: Union[str, Image.Image], + seed: int, + randomize_seed: bool = False, + num_inference_steps: int = 50, + guidance_scale: float = 7.0, + do_image_padding: bool = False, +): + if randomize_seed: + seed = random.randint(0, MAX_SEED) + + if not isinstance(rgb_image, Image.Image) and "image" in rgb_image: + rgb_image = rgb_image["image"] + + outputs = run_midi( + pipe, + rgb_image, + seg_image, + seed, + num_inference_steps, + guidance_scale, + do_image_padding, + ) + + # marching cubes + trimeshes = [] + for _, (logits_, grid_size, bbox_size, bbox_min, bbox_max) in enumerate( + zip(*outputs) + ): + grid_logits = logits_.view(grid_size) + grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1) + torch.cuda.empty_cache() + vertices, faces, normals, _ = measure.marching_cubes( + grid_logits.float().cpu().numpy(), 0, method="lewiner" + ) + vertices = vertices / grid_size * bbox_size + bbox_min + + # Trimesh + mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces)) + trimeshes.append(mesh) + + # compose the output meshes + scene = trimesh.Scene(trimeshes) + + tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb") + scene.export(tmp_path) + + torch.cuda.empty_cache() + + return tmp_path, tmp_path, seed + + +# Demo +with gr.Blocks() as demo: + gr.Markdown(MARKDOWN) + + with gr.Row(): + with gr.Column(): + with gr.Row(): + image_prompts = ImagePrompter(label="Input Image", type="pil") + seg_image = gr.Image( + label="Segmentation Result", type="pil", format="png" + ) + + with gr.Accordion("Segmentation Settings", open=False): + polygon_refinement = gr.Checkbox( + label="Polygon Refinement", value=False + ) + seg_button = gr.Button("Run Segmentation") + + with gr.Accordion("Generation Settings", open=False): + do_image_padding = gr.Checkbox(label="Do image padding", value=False) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + num_inference_steps = gr.Slider( + label="Number of inference steps", + minimum=1, + maximum=50, + step=1, + value=50, + ) + guidance_scale = gr.Slider( + label="CFG scale", + minimum=0.0, + maximum=10.0, + step=0.1, + value=7.0, + ) + gen_button = gr.Button("Run Generation", variant="primary") + + with gr.Column(): + model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500) + download_glb = gr.DownloadButton(label="Download GLB", interactive=False) + + with gr.Row(): + gr.Examples( + examples=EXAMPLES, + fn=run_generation, + inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding], + outputs=[model_output, download_glb, seed], + cache_examples=False, + ) + + seg_button.click( + run_segmentation, + inputs=[ + image_prompts, + polygon_refinement, + ], + 
outputs=[seg_image], + ).then(lambda: gr.Button(interactive=True), outputs=[gen_button]) + + gen_button.click( + run_generation, + inputs=[ + image_prompts, + seg_image, + seed, + randomize_seed, + num_inference_steps, + guidance_scale, + do_image_padding, + ], + outputs=[model_output, download_glb, seed], + ).then(lambda: gr.Button(interactive=True), outputs=[download_glb]) + + +demo.launch() diff --git a/assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png b/assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..c138e538ceaa573cae9ea2b1447851cc3c3046fc Binary files /dev/null and b/assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png differ diff --git a/assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png b/assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..8ae73d6d9132afe2959ac3d95ed0e10e35e8f749 Binary files /dev/null and b/assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png differ diff --git a/assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png b/assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..39297e4f95a8c9dc6b472194372f80814ccf6e9b Binary files /dev/null and b/assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png differ diff --git a/assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png b/assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..0b793840de558f847d52f77ce3cdfaabcec863dd Binary files /dev/null and b/assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png differ diff --git a/assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png b/assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..f0a1d903539252879c12f18489e1e7baefa42c0d Binary files /dev/null and b/assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png differ diff --git a/assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png b/assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..ced625ea874dd4e7a4580868fc75ca24306b7c67 Binary files /dev/null and b/assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png differ diff --git a/assets/example_data/Cartoon-Style/00_rgb.png b/assets/example_data/Cartoon-Style/00_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..2b915376f7f5dea9944c7cced2d4850f698f7516 Binary files /dev/null and b/assets/example_data/Cartoon-Style/00_rgb.png differ diff --git a/assets/example_data/Cartoon-Style/00_seg.png b/assets/example_data/Cartoon-Style/00_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..544364133f8036f02e8640d7e75ff95b16d5c983 Binary files /dev/null and 
b/assets/example_data/Cartoon-Style/00_seg.png differ diff --git a/assets/example_data/Cartoon-Style/01_rgb.png b/assets/example_data/Cartoon-Style/01_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..22040aded42ad419ce8a4c7d507569383e68214f Binary files /dev/null and b/assets/example_data/Cartoon-Style/01_rgb.png differ diff --git a/assets/example_data/Cartoon-Style/01_seg.png b/assets/example_data/Cartoon-Style/01_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..e78104e1a4eeae206bf56444e70cdec229023d35 Binary files /dev/null and b/assets/example_data/Cartoon-Style/01_seg.png differ diff --git a/assets/example_data/Cartoon-Style/02_rgb.png b/assets/example_data/Cartoon-Style/02_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..4f63141113b8bb3d78e965e6afdeb75633c91bf8 Binary files /dev/null and b/assets/example_data/Cartoon-Style/02_rgb.png differ diff --git a/assets/example_data/Cartoon-Style/02_seg.png b/assets/example_data/Cartoon-Style/02_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..4e5e1fa4a8db0da6afae14fdf2694d3048f47939 Binary files /dev/null and b/assets/example_data/Cartoon-Style/02_seg.png differ diff --git a/assets/example_data/Cartoon-Style/03_rgb.png b/assets/example_data/Cartoon-Style/03_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..a2834b5d6f839e63766a2399110dc0b3f2fab5ac Binary files /dev/null and b/assets/example_data/Cartoon-Style/03_rgb.png differ diff --git a/assets/example_data/Cartoon-Style/03_seg.png b/assets/example_data/Cartoon-Style/03_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..260d56aa85c203ba45525d3658a00e8d82629639 Binary files /dev/null and b/assets/example_data/Cartoon-Style/03_seg.png differ diff --git a/assets/example_data/Cartoon-Style/04_rgb.png b/assets/example_data/Cartoon-Style/04_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..2ba8287a928ce99f6960fe51c4115a6687e383c6 Binary files /dev/null and b/assets/example_data/Cartoon-Style/04_rgb.png differ diff --git a/assets/example_data/Cartoon-Style/04_seg.png b/assets/example_data/Cartoon-Style/04_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..53c4afa174d93169be40dc8ab687aec99befa9ab Binary files /dev/null and b/assets/example_data/Cartoon-Style/04_seg.png differ diff --git a/assets/example_data/Realistic-Style/00_rgb.png b/assets/example_data/Realistic-Style/00_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..24cd187fbaf61a535598ec493460db0c58ca0f3f Binary files /dev/null and b/assets/example_data/Realistic-Style/00_rgb.png differ diff --git a/assets/example_data/Realistic-Style/00_seg.png b/assets/example_data/Realistic-Style/00_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..ae4f24c58b72a0c8fa0d320524a1665e0fdc2d83 Binary files /dev/null and b/assets/example_data/Realistic-Style/00_seg.png differ diff --git a/assets/example_data/Realistic-Style/01_rgb.png b/assets/example_data/Realistic-Style/01_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..0a62c98d0f563a019916fc5f486048f26281879f Binary files /dev/null and b/assets/example_data/Realistic-Style/01_rgb.png differ diff --git a/assets/example_data/Realistic-Style/01_seg.png b/assets/example_data/Realistic-Style/01_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..7a41f69a603a70dbaf4d2209aa15cfc94c879a8a Binary 
files /dev/null and b/assets/example_data/Realistic-Style/01_seg.png differ diff --git a/assets/example_data/Realistic-Style/02_rgb.png b/assets/example_data/Realistic-Style/02_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..2cd66431280833ae08e41181ada56a0048acc218 Binary files /dev/null and b/assets/example_data/Realistic-Style/02_rgb.png differ diff --git a/assets/example_data/Realistic-Style/02_seg.png b/assets/example_data/Realistic-Style/02_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..f72ad670e2c5fc24dc709147a0cb2eac0e79c601 Binary files /dev/null and b/assets/example_data/Realistic-Style/02_seg.png differ diff --git a/assets/example_data/Realistic-Style/03_rgb.png b/assets/example_data/Realistic-Style/03_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..fc3cb7d17b1039b9f2b7708e808887055ba4c5b5 Binary files /dev/null and b/assets/example_data/Realistic-Style/03_rgb.png differ diff --git a/assets/example_data/Realistic-Style/03_seg.png b/assets/example_data/Realistic-Style/03_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..f22f26262738f6084cbf3bb3ef2e0a1e620cfb4f Binary files /dev/null and b/assets/example_data/Realistic-Style/03_seg.png differ diff --git a/assets/example_data/Realistic-Style/04_rgb.png b/assets/example_data/Realistic-Style/04_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..1270b2f4d99ac96c20bf17cb0163e0a0590f6608 Binary files /dev/null and b/assets/example_data/Realistic-Style/04_rgb.png differ diff --git a/assets/example_data/Realistic-Style/04_seg.png b/assets/example_data/Realistic-Style/04_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..082ced5d8eee381e55f3ba5a8e196ab0f31d312c Binary files /dev/null and b/assets/example_data/Realistic-Style/04_seg.png differ diff --git a/assets/example_data/Realistic-Style/05_rgb.png b/assets/example_data/Realistic-Style/05_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..1aad6afff2873b3a098848ce5682ca7956599630 Binary files /dev/null and b/assets/example_data/Realistic-Style/05_rgb.png differ diff --git a/assets/example_data/Realistic-Style/05_seg.png b/assets/example_data/Realistic-Style/05_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..8aa305ba18331d16112e6363dbe9687796f1c1df Binary files /dev/null and b/assets/example_data/Realistic-Style/05_seg.png differ diff --git a/assets/example_data/Realistic-Style/06_rgb.png b/assets/example_data/Realistic-Style/06_rgb.png new file mode 100644 index 0000000000000000000000000000000000000000..9968cdee68c59c63ace8245c2e02add77234db16 Binary files /dev/null and b/assets/example_data/Realistic-Style/06_rgb.png differ diff --git a/assets/example_data/Realistic-Style/06_seg.png b/assets/example_data/Realistic-Style/06_seg.png new file mode 100644 index 0000000000000000000000000000000000000000..46e0609b31df1113109f4d53887b1778d57c4388 Binary files /dev/null and b/assets/example_data/Realistic-Style/06_seg.png differ diff --git a/midi/inference_utils.py b/midi/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34498e0b80db4145ebd43067161ac81b6b3e5917 --- /dev/null +++ b/midi/inference_utils.py @@ -0,0 +1,22 @@ +from typing import List, Tuple + +import numpy as np +import PIL +import torch.nn.functional as F +from PIL import Image + + +def generate_dense_grid_points( + bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, 
indexing: str = "ij" +): + length = bbox_max - bbox_min + num_cells = np.exp2(octree_depth) + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + xyz = xyz.reshape(-1, 3) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length diff --git a/midi/loaders/__init__.py b/midi/loaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f93ecab4e9cb397dc0050b42edce64daa84b74f0 --- /dev/null +++ b/midi/loaders/__init__.py @@ -0,0 +1 @@ +from .custom_adapter import CustomAdapterMixin diff --git a/midi/loaders/custom_adapter.py b/midi/loaders/custom_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..b6032ffde48f0434a9062ec26b81bee7012b7cae --- /dev/null +++ b/midi/loaders/custom_adapter.py @@ -0,0 +1,99 @@ +import os +from typing import Dict, Optional, Union + +import safetensors +import torch +from diffusers.utils import _get_model_file, logging +from safetensors import safe_open + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CustomAdapterMixin: + def init_custom_adapter(self, *args, **kwargs): + self._init_custom_adapter(*args, **kwargs) + + def _init_custom_adapter(self, *args, **kwargs): + raise NotImplementedError + + def load_custom_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str, + subfolder: Optional[str] = None, + **kwargs, + ): + # Load the main state dict first. 
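+        # The kwargs popped below are standard Hugging Face Hub download options; they are
+        # forwarded to `_get_model_file`, which resolves `weight_name` (optionally under
+        # `subfolder`) either from a local directory or from the Hub. The resolved checkpoint
+        # is then read with `safetensors.safe_open` if it ends in `.safetensors`, and with
+        # `torch.load(..., map_location="cpu")` otherwise, before being handed to
+        # `_load_custom_adapter`.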
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path_or_dict + + self._load_custom_adapter(state_dict) + + def _load_custom_adapter(self, state_dict): + raise NotImplementedError + + def save_custom_adapter( + self, + save_directory: Union[str, os.PathLike], + weight_name: str, + safe_serialization: bool = False, + **kwargs, + ): + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + if safe_serialization: + + def save_function(weights, filename): + return safetensors.torch.save_file( + weights, filename, metadata={"format": "pt"} + ) + + else: + save_function = torch.save + + # Save the model + state_dict = self._save_custom_adapter(**kwargs) + save_function(state_dict, os.path.join(save_directory, weight_name)) + logger.info( + f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}" + ) + + def _save_custom_adapter(self): + raise NotImplementedError diff --git a/midi/models/attention_processor.py b/midi/models/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..018304f5fd11d9b92e4bb83772add3c4b7f8ccce --- /dev/null +++ b/midi/models/attention_processor.py @@ -0,0 +1,412 @@ +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.utils import logging +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph +from einops import rearrange +from torch import nn + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TripoSGAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the TripoSG model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from diffusers.models.embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim) + # instead of .view(..., 3, attn.heads, dim). So we need to re-split here. 
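+        # In other words: the pre-trained fused projection is laid out head-first (per head,
+        # [q | k | v] are contiguous), so `to_q`/`to_k`/`to_v` each return one third of that fused
+        # output rather than clean q/k/v. Re-concatenating restores the fused layout, the view to
+        # (B, N, heads, 3 * head_dim) groups the chunks per head, and splitting by
+        # `split_size` (= head_dim) recovers the actual per-head q, k, v of shape
+        # (B, N, heads, head_dim).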
+ if not attn.is_cross_attention: + qkv = torch.cat((query, key, value), dim=-1) + split_size = qkv.shape[-1] // attn.heads // 3 + qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + kv = torch.cat((key, value), dim=-1) + split_size = kv.shape[-1] // attn.heads // 2 + kv = kv.view(batch_size, -1, attn.heads, split_size * 2) + key, value = torch.split(kv, split_size, dim=-1) + + head_dim = key.shape[-1] + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FusedTripoSGAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused + projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on + query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from diffusers.models.embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + # NOTE that pre-trained split heads first, then split qkv + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // attn.heads // 3 + qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // attn.heads // 2 + kv = kv.view(batch_size, -1, attn.heads, split_size * 2) + key, value = torch.split(kv, split_size, dim=-1) + + head_dim = key.shape[-1] + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class MIAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the MIDI model. 
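+    When `use_mi` is enabled and `num_instances` is passed, keys and values are shared across all
+    instances of a scene (multi-instance attention), so each instance attends to the full scene context.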
It applies a normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self, use_mi: bool = True): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.use_mi = use_mi + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + num_instances: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + from diffusers.models.embeddings import apply_rotary_emb + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim) + # instead of .view(..., 3, attn.heads, dim). So we need to re-split here. 
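+        # Note: further below, when `self.use_mi` is set and `num_instances` is given, the keys and
+        # values are rearranged from (b * ni, heads, nt, c) to (b, heads, ni * nt, c) and repeated
+        # `ni` times along the batch axis, so the queries of every instance attend over the tokens
+        # of all `ni` instances in its scene.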
+ if not attn.is_cross_attention: + qkv = torch.cat((query, key, value), dim=-1) + split_size = qkv.shape[-1] // attn.heads // 3 + qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + else: + kv = torch.cat((key, value), dim=-1) + split_size = kv.shape[-1] // attn.heads // 2 + kv = kv.view(batch_size, -1, attn.heads, split_size * 2) + key, value = torch.split(kv, split_size, dim=-1) + + head_dim = key.shape[-1] + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + if not attn.is_cross_attention: + key = apply_rotary_emb(key, image_rotary_emb) + + if self.use_mi and num_instances is not None: + key = rearrange( + key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances + ).repeat_interleave(num_instances, dim=0) + value = rearrange( + value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances + ).repeat_interleave(num_instances, dim=0) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=0.0, + is_causal=False, + ) + else: + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/midi/models/autoencoders/__init__.py b/midi/models/autoencoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64a9ea290fb3c006737000b0e046ac492e0ec26f --- /dev/null +++ b/midi/models/autoencoders/__init__.py @@ -0,0 +1 @@ +from .autoencoder_kl_triposg import TripoSGVAEModel diff --git a/midi/models/autoencoders/autoencoder_kl_triposg.py b/midi/models/autoencoders/autoencoder_kl_triposg.py new file mode 100644 index 0000000000000000000000000000000000000000..9aed3d0c4ba836a1e9fed51affec6ac86711eea7 --- /dev/null +++ b/midi/models/autoencoders/autoencoder_kl_triposg.py @@ -0,0 +1,541 @@ +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.autoencoders.vae import DecoderOutput +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import FP32LayerNorm, LayerNorm +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook +from einops import repeat +from tqdm import tqdm +from torch_cluster import fps 
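+# `fps` (farthest point sampling) is used in `TripoSGVAEModel._sample_features`: the point cloud is
+# first oversampled to 4 * num_tokens random points and then reduced with fps at ratio 1/4, yielding
+# roughly `num_tokens` well-spread query points.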
+ +from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0 +from ..embeddings import FrequencyPositionalEmbedding +from ..transformers.triposg_transformer import DiTBlock +from .vae import DiagonalGaussianDistribution + +import subprocess +import sys + + +def install_package(package_name): + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", package_name]) + return True + except subprocess.CalledProcessError: + return False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TripoSGEncoder(nn.Module): + def __init__( + self, + in_channels: int = 3, + dim: int = 512, + num_attention_heads: int = 8, + num_layers: int = 8, + ): + super().__init__() + + self.proj_in = nn.Linear(in_channels, dim, bias=True) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=False, + use_cross_attention=True, + cross_attention_dim=dim, + cross_attention_norm_type="layer_norm", + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) # cross attention + ] + + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=True, + self_attention_norm_type="fp32_layer_norm", + use_cross_attention=False, + use_cross_attention_2=False, + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) + for _ in range(num_layers) # self attention + ] + ) + + self.norm_out = LayerNorm(dim) + + def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor): + hidden_states = self.proj_in(sample_1) + encoder_hidden_states = self.proj_in(sample_2) + + for layer, block in enumerate(self.blocks): + if layer == 0: + hidden_states = block( + hidden_states, encoder_hidden_states=encoder_hidden_states + ) + else: + hidden_states = block(hidden_states) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states + + +class TripoSGDecoder(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 1, + dim: int = 512, + num_attention_heads: int = 8, + num_layers: int = 16, + grad_type: str = "analytical", + grad_interval: float = 0.001, + ): + super().__init__() + + if grad_type not in ["numerical", "analytical"]: + raise ValueError(f"grad_type must be one of ['numerical', 'analytical']") + self.grad_type = grad_type + self.grad_interval = grad_interval + + self.blocks = nn.ModuleList( + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=True, + self_attention_norm_type="fp32_layer_norm", + use_cross_attention=False, + use_cross_attention_2=False, + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) + for _ in range(num_layers) # self attention + ] + + [ + DiTBlock( + dim=dim, + num_attention_heads=num_attention_heads, + use_self_attention=False, + use_cross_attention=True, + cross_attention_dim=dim, + cross_attention_norm_type="layer_norm", + activation_fn="gelu", + norm_type="fp32_layer_norm", + norm_eps=1e-5, + qk_norm=False, + qkv_bias=False, + ) # cross attention + ] + ) + + self.proj_query = nn.Linear(in_channels, dim, bias=True) + + self.norm_out = LayerNorm(dim) + self.proj_out = nn.Linear(dim, out_channels, bias=True) + + def query_geometry( + self, + model_fn: callable, + queries: torch.Tensor, + sample: torch.Tensor, + grad: bool = False, + ): + logits = model_fn(queries, sample) + if grad: + with 
torch.autocast(device_type="cuda", dtype=torch.float32): + if self.grad_type == "numerical": + interval = self.grad_interval + grad_value = [] + for offset in [ + (interval, 0, 0), + (0, interval, 0), + (0, 0, interval), + ]: + offset_tensor = torch.tensor(offset, device=queries.device)[ + None, : + ] + res_p = model_fn(queries + offset_tensor, sample)[..., 0] + res_n = model_fn(queries - offset_tensor, sample)[..., 0] + grad_value.append((res_p - res_n) / (2 * interval)) + grad_value = torch.stack(grad_value, dim=-1) + else: + queries_d = torch.clone(queries) + queries_d.requires_grad = True + with torch.enable_grad(): + res_d = model_fn(queries_d, sample) + grad_value = torch.autograd.grad( + res_d, + [queries_d], + grad_outputs=torch.ones_like(res_d), + create_graph=self.training, + )[0] + else: + grad_value = None + + return logits, grad_value + + def forward( + self, + sample: torch.Tensor, + queries: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + ): + if kv_cache is None: + hidden_states = sample + for _, block in enumerate(self.blocks[:-1]): + hidden_states = block(hidden_states) + kv_cache = hidden_states + + # query grid logits by cross attention + def query_fn(q, kv): + q = self.proj_query(q) + l = self.blocks[-1](q, encoder_hidden_states=kv) + return self.proj_out(self.norm_out(l)) + + logits, grad = self.query_geometry( + query_fn, queries, kv_cache, grad=self.training + ) + logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1 + + return logits, kv_cache + + +class TripoSGVAEModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 3, # NOTE xyz instead of feature dim + latent_channels: int = 64, + num_attention_heads: int = 8, + width_encoder: int = 512, + width_decoder: int = 1024, + num_layers_encoder: int = 8, + num_layers_decoder: int = 16, + embedding_type: str = "frequency", + embed_frequency: int = 8, + embed_include_pi: bool = False, + ): + super().__init__() + + self.out_channels = 1 + + if embedding_type == "frequency": + self.embedder = FrequencyPositionalEmbedding( + num_freqs=embed_frequency, + logspace=True, + input_dim=in_channels, + include_pi=embed_include_pi, + ) + else: + raise NotImplementedError( + f"Embedding type {embedding_type} is not supported." + ) + + self.encoder = TripoSGEncoder( + in_channels=in_channels + self.embedder.out_dim, + dim=width_encoder, + num_attention_heads=num_attention_heads, + num_layers=num_layers_encoder, + ) + self.decoder = TripoSGDecoder( + in_channels=self.embedder.out_dim, + out_channels=self.out_channels, + dim=width_decoder, + num_attention_heads=num_attention_heads, + num_layers=num_layers_decoder, + ) + + self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True) + self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True) + + self.use_slicing = False + self.slicing_length = 1 + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. 
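+        Fusing replaces the separate `to_q`/`to_k`/`to_v` (or `to_k`/`to_v`) projections of each
+        attention layer with a single fused linear and switches the processors to
+        `FusedTripoSGAttnProcessor2_0`; `unfuse_qkv_projections()` restores the original processors.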
+ + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedTripoSGAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(TripoSGAttnProcessor2_0()) + + def enable_slicing(self, slicing_length: int = 1) -> None: + r""" + Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + self.slicing_length = slicing_length + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _sample_features( + self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None + ): + """ + Sample points from features of the input point cloud. + + Args: + x (torch.Tensor): The input point cloud. shape: (B, N, C) + num_tokens (int, optional): The number of points to sample. Defaults to 2048. + seed (Optional[int], optional): The random seed. Defaults to None. + """ + rng = np.random.default_rng(seed) + indices = rng.choice( + x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1] + ) + selected_points = x[:, indices] + + batch_size, num_points, num_channels = selected_points.shape + flattened_points = selected_points.view(batch_size * num_points, num_channels) + batch_indices = ( + torch.arange(batch_size).to(x.device).repeat_interleave(num_points) + ) + + # fps sampling + sampling_ratio = 1.0 / 4 + sampled_indices = fps( + flattened_points[:, :3], + batch_indices, + ratio=sampling_ratio, + random_start=self.training, + ) + sampled_points = flattened_points[sampled_indices].view( + batch_size, -1, num_channels + ) + + return sampled_points + + def _encode( + self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None + ): + position_channels = self.config.in_channels + positions, features = x[..., :position_channels], x[..., position_channels:] + x_kv = torch.cat([self.embedder(positions), features], dim=-1) + + sampled_x = self._sample_features(x, num_tokens, seed) + positions, features = ( + sampled_x[..., :position_channels], + sampled_x[..., position_channels:], + ) + x_q = torch.cat([self.embedder(positions), features], dim=-1) + + x = self.encoder(x_q, x_kv) + + x = self.quant(x) + + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True, **kwargs + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of point features into latents. 
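+
+        Args:
+            x (`torch.Tensor`): Input point cloud of shape `(B, N, C)`, where the first
+                `config.in_channels` channels are xyz positions and the remaining channels are
+                per-point features.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return an `AutoencoderKLOutput` instead of a plain tuple.
+
+        Returns:
+            The `DiagonalGaussianDistribution` over latents, wrapped in an `AutoencoderKLOutput`
+            when `return_dict` is `True`.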
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [ + self._encode(x_slice, **kwargs) + for x_slice in x.split(self.slicing_length) + ] + h = torch.cat(encoded_slices) + else: + h = self._encode(x, **kwargs) + + posterior = DiagonalGaussianDistribution(h, feature_dim=-1) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + sampled_points: torch.Tensor, + num_chunks: int = 50000, + to_cpu: bool = False, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + xyz_samples = sampled_points + + z = self.post_quant(z) + + num_points = xyz_samples.shape[1] + kv_cache = None + dec = [] + + for i in range(0, num_points, num_chunks): + queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype) + queries = self.embedder(queries) + + z_, kv_cache = self.decoder(z, queries, kv_cache) + dec.append(z_ if not to_cpu else z_.cpu()) + + z = torch.cat(dec, dim=1) + + if not return_dict: + return (z,) + + return DecoderOutput(sample=z) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + sampled_points: torch.Tensor, + return_dict: bool = True, + **kwargs, + ) -> Union[DecoderOutput, torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [ + self._decode(z_slice, p_slice, **kwargs).sample + for z_slice, p_slice in zip( + z.split(self.slicing_length), + sampled_points.split(self.slicing_length), + ) + ] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, sampled_points, **kwargs).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def forward(self, x: torch.Tensor): + pass diff --git a/midi/models/autoencoders/vae.py b/midi/models/autoencoders/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..01aae603a3f1b298bacd723053d396c614028018 --- /dev/null +++ b/midi/models/autoencoders/vae.py @@ -0,0 +1,69 @@ +from typing import Optional, Tuple + +import numpy as np +import torch +from diffusers.utils.torch_utils import randn_tensor + + +class DiagonalGaussianDistribution(object): + def __init__( + self, + parameters: torch.Tensor, + deterministic: bool = False, + feature_dim: int = 1, + ): + self.parameters = parameters + self.feature_dim = feature_dim + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll( + self, sample: torch.Tensor, dims: Tuple[int, 
...] = [1, 2, 3] + ) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean diff --git a/midi/models/embeddings.py b/midi/models/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1fdb5eaa8570c87d2ffbfc7eeeaba8b070874a69 --- /dev/null +++ b/midi/models/embeddings.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn + + +class FrequencyPositionalEmbedding(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__( + self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True, + ) -> None: + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32) + else: + frequencies = torch.linspace( + 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. 
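+
+        For example, with the defaults (`input_dim=3`, `num_freqs=6`, `include_input=True`) the
+        embedding size is 3 * (2 * 6 + 1) = 39:
+
+        >>> embedder = FrequencyPositionalEmbedding()
+        >>> embedder(torch.rand(2, 2048, 3)).shape
+        torch.Size([2, 2048, 39])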
+ """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view( + *x.shape[:-1], -1 + ) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x diff --git a/midi/models/transformers/__init__.py b/midi/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..952297cfd3141c281c18bca37523949a78529a9f --- /dev/null +++ b/midi/models/transformers/__init__.py @@ -0,0 +1,61 @@ +from typing import Callable, Optional + +from .triposg_transformer import TripoSGDiTModel + + +def default_set_attn_proc_func( + name: str, + hidden_size: int, + cross_attention_dim: Optional[int], + ori_attn_proc: object, +) -> object: + return ori_attn_proc + + +def set_transformer_attn_processor( + transformer: TripoSGDiTModel, + set_self_attn_proc_func: Callable = default_set_attn_proc_func, + set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func, + set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func, + set_self_attn_module_names: Optional[list[str]] = None, + set_cross_attn_1_module_names: Optional[list[str]] = None, + set_cross_attn_2_module_names: Optional[list[str]] = None, +) -> None: + do_set_processor = lambda name, module_names: ( + any([name.startswith(module_name) for module_name in module_names]) + if module_names is not None + else True + ) # prefix match + + attn_procs = {} + for name, attn_processor in transformer.attn_processors.items(): + hidden_size = transformer.config.width + if name.endswith("attn1.processor"): + # self attention + attn_procs[name] = ( + set_self_attn_proc_func(name, hidden_size, None, attn_processor) + if do_set_processor(name, set_self_attn_module_names) + else attn_processor + ) + elif name.endswith("attn2.processor"): + # cross attention + cross_attention_dim = transformer.config.cross_attention_dim + attn_procs[name] = ( + set_cross_attn_1_proc_func( + name, hidden_size, cross_attention_dim, attn_processor + ) + if do_set_processor(name, set_cross_attn_1_module_names) + else attn_processor + ) + elif name.endswith("attn2_2.processor"): + # cross attention 2 + cross_attention_dim = transformer.config.cross_attention_2_dim + attn_procs[name] = ( + set_cross_attn_2_proc_func( + name, hidden_size, cross_attention_dim, attn_processor + ) + if do_set_processor(name, set_cross_attn_2_module_names) + else attn_processor + ) + + transformer.set_attn_processor(attn_procs) diff --git a/midi/models/transformers/modeling_outputs.py b/midi/models/transformers/modeling_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..0928fa0ca39275a85f8b7fa49c68af745d4c74c5 --- /dev/null +++ b/midi/models/transformers/modeling_outputs.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + +import torch + + +@dataclass +class Transformer1DModelOutput: + sample: torch.FloatTensor diff --git a/midi/models/transformers/triposg_transformer.py b/midi/models/transformers/triposg_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb4d4cb119216696f01b5d34436fe10f88785f8 --- /dev/null +++ b/midi/models/transformers/triposg_transformer.py @@ -0,0 +1,690 @@ +# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention, AttentionProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import ( + AdaLayerNormContinuous, + FP32LayerNorm, + LayerNorm, +) +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import maybe_allow_in_graph +from torch import nn + +from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0 +from .modeling_outputs import Transformer1DModelOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class DiTBlock(nn.Module): + r""" + Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and + QKNorm + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of headsto use for multi-head attention. + cross_attention_dim (`int`,*optional*): + The size of the encoder_hidden_states vector for cross attention. + dropout(`float`, *optional*, defaults to 0.0): + The dropout probability to use. + activation_fn (`str`,*optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. . + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, *optional*, defaults to 1e-6): + A small constant added to the denominator in normalization layers to prevent division by zero. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): + The size of the hidden layer in the feed-forward block. Defaults to `None`. + ff_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the feed-forward block. + skip (`bool`, *optional*, defaults to `False`): + Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks. + qk_norm (`bool`, *optional*, defaults to `True`): + Whether to use normalization in QK calculation. Defaults to `True`. 
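+
+ This variant also supports a second, parallel cross-attention branch (`attn2_2`) conditioned on
+ `encoder_hidden_states_2`, and a long skip connection (`skip_linear` + `skip_norm`) that is enabled
+ for the second half of the transformer blocks.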
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + use_self_attention: bool = True, + use_cross_attention: bool = False, + self_attention_norm_type: Optional[str] = None, # ada layer norm + cross_attention_dim: Optional[int] = None, + cross_attention_norm_type: Optional[str] = "fp32_layer_norm", + # parallel second cross attention + use_cross_attention_2: bool = False, + cross_attention_2_dim: Optional[int] = None, + cross_attention_2_norm_type: Optional[str] = None, + dropout=0.0, + activation_fn: str = "gelu", + norm_type: str = "fp32_layer_norm", # TODO + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + final_dropout: bool = False, + ff_inner_dim: Optional[int] = None, # int(dim * 4) if None + ff_bias: bool = True, + skip: bool = False, + skip_concat_front: bool = False, # [x, skip] or [skip, x] + skip_norm_last: bool = False, # this is an error + qk_norm: bool = True, + qkv_bias: bool = True, + ): + super().__init__() + + self.use_self_attention = use_self_attention + self.use_cross_attention = use_cross_attention + self.use_cross_attention_2 = use_cross_attention_2 + self.skip_concat_front = skip_concat_front + self.skip_norm_last = skip_norm_last + # Define 3 blocks. Each block has its own normalization layer. + # NOTE: when new version comes, check norm2 and norm 3 + # 1. Self-Attn + if use_self_attention: + if ( + self_attention_norm_type == "fp32_layer_norm" + or self_attention_norm_type is None + ): + self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + raise NotImplementedError + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-6, + bias=qkv_bias, + processor=TripoSGAttnProcessor2_0(), + ) + + # 2. Cross-Attn + if use_cross_attention: + assert cross_attention_dim is not None + + self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="rms_norm" if qk_norm else None, + cross_attention_norm=cross_attention_norm_type, + eps=1e-6, + bias=qkv_bias, + processor=TripoSGAttnProcessor2_0(), + ) + + # 2'. Parallel Second Cross-Attn + if use_cross_attention_2: + assert cross_attention_2_dim is not None + + self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2_2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_2_dim, + dim_head=dim // num_attention_heads, + heads=num_attention_heads, + qk_norm="rms_norm" if qk_norm else None, + cross_attention_norm=cross_attention_2_norm_type, + eps=1e-6, + bias=qkv_bias, + processor=TripoSGAttnProcessor2_0(), + ) + + # 3. Feed-forward + self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, ### 0.0 + activation_fn=activation_fn, ### approx GeLU + final_dropout=final_dropout, ### 0.0 + inner_dim=ff_inner_dim, ### int(dim * mlp_ratio) + bias=ff_bias, + ) + + # 4. 
Skip Connection + if skip: + self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True) + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + skip: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + # Prepare attention kwargs + attention_kwargs = attention_kwargs or {} + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat( + ( + [skip, hidden_states] + if self.skip_concat_front + else [hidden_states, skip] + ), + dim=-1, + ) + if self.skip_norm_last: + # don't do this + hidden_states = self.skip_linear(cat) + hidden_states = self.skip_norm(hidden_states) + else: + cat = self.skip_norm(cat) + hidden_states = self.skip_linear(cat) + + # 1. Self-Attention + if self.use_self_attention: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1( + norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + hidden_states = hidden_states + attn_output + + # 2. Cross-Attention + if self.use_cross_attention: + if self.use_cross_attention_2: + hidden_states = ( + hidden_states + + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + + self.attn2_2( + self.norm2_2(hidden_states), + encoder_hidden_states=encoder_hidden_states_2, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + ) + else: + hidden_states = hidden_states + self.attn2( + self.norm2(hidden_states), + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + + # FFN Layer ### TODO: switch norm2 and norm3 in the state dict + mlp_inputs = self.norm3(hidden_states) + hidden_states = hidden_states + self.ff(mlp_inputs) + + return hidden_states + + +class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + """ + TripoSG: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): + The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + patch_size (`int`, *optional*): + The size of the patch to use for the input. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. + sample_size (`int`, *optional*): + The width of the latent images. This is fixed during training since it is used to learn a number of + position embeddings. 
+ dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The number of dimension in the clip text embedding. + hidden_size (`int`, *optional*): + The size of hidden layer in the conditioning embedding layers. + num_layers (`int`, *optional*, defaults to 1): + The number of layers of Transformer blocks to use. + mlp_ratio (`float`, *optional*, defaults to 4.0): + The ratio of the hidden layer size to the input size. + learn_sigma (`bool`, *optional*, defaults to `True`): + Whether to predict variance. + cross_attention_dim_t5 (`int`, *optional*): + The number dimensions in t5 text embedding. + pooled_projection_dim (`int`, *optional*): + The size of the pooled projection. + text_len (`int`, *optional*): + The length of the clip text embedding. + text_len_t5 (`int`, *optional*): + The length of the T5 text embedding. + use_style_cond_and_image_meta_size (`bool`, *optional*): + Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + width: int = 2048, + in_channels: int = 64, + num_layers: int = 21, + cross_attention_dim: int = 768, + cross_attention_2_dim: int = 1024, + ): + super().__init__() + self.out_channels = in_channels + self.num_heads = num_attention_heads + self.inner_dim = width + self.mlp_ratio = 4.0 + + time_embed_dim, timestep_input_dim = self._set_time_proj( + "positional", + inner_dim=self.inner_dim, + flip_sin_to_cos=False, + freq_shift=0, + time_embedding_dim=None, + ) + self.time_proj = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim + ) + self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True) + + self.blocks = nn.ModuleList( + [ + DiTBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + use_self_attention=True, + use_cross_attention=True, + self_attention_norm_type="fp32_layer_norm", + cross_attention_dim=self.config.cross_attention_dim, + cross_attention_norm_type=None, + use_cross_attention_2=True, + cross_attention_2_dim=self.config.cross_attention_2_dim, + cross_attention_2_norm_type=None, + activation_fn="gelu", + norm_type="fp32_layer_norm", # TODO + norm_eps=1e-5, + ff_inner_dim=int(self.inner_dim * self.mlp_ratio), + skip=layer > num_layers // 2, + skip_concat_front=True, + skip_norm_last=True, # this is an error + qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details. + qkv_bias=False, + ) + for layer in range(num_layers) + ] + ) + + self.norm_out = LayerNorm(self.inner_dim) + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def _set_time_proj( + self, + time_embedding_type: str, + inner_dim: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or inner_dim * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." 
+ ) + self.time_embed = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or inner_dim * 4 + + self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + timestep_input_dim = inner_dim + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedTripoSGAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(TripoSGAttnProcessor2_0()) + + def forward( + self, + hidden_states: Optional[torch.Tensor], + timestep: Union[int, float, torch.LongTensor], + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + """ + The [`HunyuanDiT2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`): + The input tensor. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. + encoder_hidden_states_2 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. + return_dict: bool + Whether to return a dictionary. + """ + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + _, N, _ = hidden_states.shape + + temb = self.time_embed(timestep).to(hidden_states.dtype) + temb = self.time_proj(temb) + temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states + + hidden_states = self.proj_in(hidden_states) + + # N + 1 token + hidden_states = torch.cat([temb, hidden_states], dim=1) + + skips = [] + for layer, block in enumerate(self.blocks): + skip = None if layer <= self.config.num_layers // 2 else skips.pop() + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + encoder_hidden_states_2, + temb, + image_rotary_emb, + skip, + attention_kwargs, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_2=encoder_hidden_states_2, + temb=temb, + image_rotary_emb=image_rotary_emb, + skip=skip, + attention_kwargs=attention_kwargs, + ) # (N, L, D) + + if layer < self.config.num_layers // 2: + skips.append(hidden_states) + + # final layer + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states[:, -N:] + hidden_states = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer1DModelOutput(sample=hidden_states) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking( + self, chunk_size: Optional[int] = None, dim: int = 0 + ) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). 
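+
+ Example (illustrative, assuming `transformer` is a `TripoSGDiTModel` instance):
+
+ # chunk the feed-forward computation over the sequence dimension to reduce peak memory
+ transformer.enable_forward_chunking(chunk_size=1, dim=1)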
+ """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward( + module: torch.nn.Module, chunk_size: int, dim: int + ): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) diff --git a/midi/pipelines/pipeline_midi.py b/midi/pipelines/pipeline_midi.py new file mode 100644 index 0000000000000000000000000000000000000000..de208c91a9d306f0cdf63d9ad78492d3ce53765d --- /dev/null +++ b/midi/pipelines/pipeline_midi.py @@ -0,0 +1,497 @@ +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import PIL.Image +import torch +import torch.nn.functional as F +from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler # not sure +from diffusers.utils import logging +from diffusers.utils.torch_utils import randn_tensor +from peft import LoraConfig, get_peft_model_state_dict +from transformers import ( + BitImageProcessor, + CLIPImageProcessor, + CLIPVisionModelWithProjection, + Dinov2Model, +) + +from ..inference_utils import generate_dense_grid_points +from ..loaders import CustomAdapterMixin +from ..models.attention_processor import MIAttnProcessor2_0 +from ..models.autoencoders import TripoSGVAEModel +from ..models.transformers import TripoSGDiTModel, set_transformer_attn_processor +from .pipeline_triposg_output import TripoSGPipelineOutput +from .pipeline_utils import TransformerDiffusionMixin + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MIDIPipeline(DiffusionPipeline, TransformerDiffusionMixin, CustomAdapterMixin): + """ + Pipeline for image-to-scene generation based on pre-trained shape diffusion. 
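+
+ Args:
+ vae ([`TripoSGVAEModel`]):
+ Variational auto-encoder used to decode the denoised shape latents into field (SDF) values at the
+ sampled query points.
+ transformer ([`TripoSGDiTModel`]):
+ Diffusion transformer used to denoise the shape latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the latents.
+ image_encoder_1 ([`CLIPVisionModelWithProjection`]):
+ CLIP image encoder for the first image condition.
+ image_encoder_2 ([`Dinov2Model`]):
+ DINOv2 image encoder for the second condition, built from the instance image, the scene image and
+ the instance mask.
+ feature_extractor_1 ([`CLIPImageProcessor`]):
+ Image processor for `image_encoder_1`.
+ feature_extractor_2 ([`BitImageProcessor`]):
+ Image processor for `image_encoder_2`.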
+ """ + + def __init__( + self, + vae: TripoSGVAEModel, + transformer: TripoSGDiTModel, + scheduler: FlowMatchEulerDiscreteScheduler, + image_encoder_1: CLIPVisionModelWithProjection, + image_encoder_2: Dinov2Model, + feature_extractor_1: CLIPImageProcessor, + feature_extractor_2: BitImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + transformer=transformer, + scheduler=scheduler, + image_encoder_1=image_encoder_1, + image_encoder_2=image_encoder_2, + feature_extractor_1=feature_extractor_1, + feature_extractor_2=feature_extractor_2, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @property + def decode_progressive(self): + return self._decode_progressive + + def encode_image_1(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder_1.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor_1(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder_1(image).image_embeds + image_embeds = image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ).unsqueeze(1) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def encode_image_2( + self, + image_one, + image_two, + mask, + device, + num_images_per_prompt, + ): + dtype = next(self.image_encoder_2.parameters()).dtype + + images = [image_one, image_two, mask] + images_new = [] + for i, image in enumerate(images): + if not isinstance(image, torch.Tensor): + if i <= 1: + images_new.append( + self.feature_extractor_2( + image, return_tensors="pt" + ).pixel_values + ) + else: + image = [ + torch.from_numpy( + (np.array(im) / 255.0).astype(np.float32) + ).unsqueeze(0) + for im in image + ] + image = torch.stack(image, dim=0) + images_new.append( + F.interpolate( + image, size=images_new[0].shape[-2:], mode="nearest" + ) + ) + + image = torch.cat(images_new, dim=1).to(device=device, dtype=dtype) + image_embeds = self.image_encoder_2(image).last_hidden_state + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_latents( + self, + batch_size, + num_tokens, + num_channels_latents, + dtype, + device, + generator, + latents: Optional[torch.Tensor] = None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = (batch_size, num_tokens, num_channels_latents) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def decode_latents( + self, + latents: torch.Tensor, + sampled_points: torch.Tensor, + decode_progressive: bool = False, + decode_to_cpu: bool = False, + # Params for sampling points + bbox_min: np.ndarray = np.array([-1.005, -1.005, -1.005]), + bbox_max: np.ndarray = np.array([1.005, 1.005, 1.005]), + octree_depth: int = 8, + indexing: str = "ij", + padding: float = 0.05, + ): + device, dtype = latents.device, latents.dtype + batch_size = latents.shape[0] + + grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = [], [], [], [] + + if sampled_points is None: + sampled_points, grid_size, bbox_size = generate_dense_grid_points( + bbox_min, bbox_max, octree_depth, indexing + ) + sampled_points = torch.FloatTensor(sampled_points).to( + device=device, dtype=dtype + ) + sampled_points = sampled_points.unsqueeze(0).expand(batch_size, -1, -1) + + grid_sizes.append(grid_size) + bbox_sizes.append(bbox_size) + bbox_mins.append(bbox_min) + bbox_maxs.append(bbox_max) + + self.vae: TripoSGVAEModel + output = self.vae.decode( + latents, sampled_points=sampled_points, to_cpu=decode_to_cpu + ).sample + + if not decode_progressive: + return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs) + + grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = [], [], [], [] + sampled_points_list = [] + + for i in range(batch_size): + sdf_ = output[i].squeeze(-1) # [num_points] + sampled_points_ = sampled_points[i] + occupied_points = sampled_points_[sdf_ <= 0] # [num_occupied_points, 3] + + if occupied_points.shape[0] == 0: + logger.warning( + f"No occupied points found in batch {i}. Using original bounding box." + ) + else: + bbox_min = occupied_points.min(dim=0).values + bbox_max = occupied_points.max(dim=0).values + bbox_min = (bbox_min - padding).float().cpu().numpy() + bbox_max = (bbox_max + padding).float().cpu().numpy() + + sampled_points_, grid_size, bbox_size = generate_dense_grid_points( + bbox_min, bbox_max, octree_depth, indexing + ) + sampled_points_ = torch.FloatTensor(sampled_points_).to( + device=device, dtype=dtype + ) + sampled_points_list.append(sampled_points_) + + grid_sizes.append(grid_size) + bbox_sizes.append(bbox_size) + bbox_mins.append(bbox_min) + bbox_maxs.append(bbox_max) + + sampled_points = torch.stack(sampled_points_list, dim=0) + + # Re-decode the new sampled points + output = self.vae.decode( + latents, sampled_points=sampled_points, to_cpu=decode_to_cpu + ).sample + + return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs) + + @torch.no_grad() + def __call__( + self, + image: PipelineImageInput, + mask: PipelineImageInput, + image_scene: PipelineImageInput, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: int = 1, + sampled_points: Optional[torch.Tensor] = None, + decode_progressive: bool = False, + decode_to_cpu: bool = False, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + output_type: Optional[str] = "mesh_vf", + return_dict: bool = True, + ): + # 1. Check inputs. 
Raise error if not correct + # TODO + + self._decode_progressive = decode_progressive + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + raise ValueError("Invalid input type for image") + + device = self._execution_device + + # 3. Encode condition + image_embeds_1, negative_image_embeds_1 = self.encode_image_1( + image, device, num_images_per_prompt + ) + image_embeds_2, negative_image_embeds_2 = self.encode_image_2( + image, image_scene, mask, device, num_images_per_prompt + ) + + if self.do_classifier_free_guidance: + image_embeds_1 = torch.cat([negative_image_embeds_1, image_embeds_1], dim=0) + image_embeds_2 = torch.cat([negative_image_embeds_2, image_embeds_2], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_tokens = self.transformer.config.width + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_tokens, + num_channels_latents, + image_embeds_1.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + latent_model_input, + timestep, + encoder_hidden_states=image_embeds_1, + encoder_hidden_states_2=image_embeds_2, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_image = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_image - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False + )[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds_1 = callback_outputs.pop( + "image_embeds_1", image_embeds_1 + ) + negative_image_embeds_1 = callback_outputs.pop( + "negative_image_embeds_1", negative_image_embeds_1 + ) + image_embeds_2 = callback_outputs.pop( + "image_embeds_2", image_embeds_2 + ) + negative_image_embeds_2 = callback_outputs.pop( + "negative_image_embeds_2", negative_image_embeds_2 + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = None, None, None, None + + if output_type == "latent": + output = latents + else: + output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = self.decode_latents( + latents, + sampled_points=sampled_points, + decode_progressive=decode_progressive, + decode_to_cpu=decode_to_cpu, + ) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs) + + return TripoSGPipelineOutput( + samples=output, + grid_sizes=grid_sizes, + bbox_sizes=bbox_sizes, + bbox_mins=bbox_mins, + bbox_maxs=bbox_maxs, + ) + + def _init_custom_adapter( + self, set_self_attn_module_names: Optional[List[str]] = None + ): + # Set attention processor + func_default = lambda name, hs, cad, ap: MIAttnProcessor2_0(use_mi=False) + set_transformer_attn_processor( # avoid warning + self.transformer, + set_self_attn_proc_func=func_default, + set_cross_attn_1_proc_func=func_default, + set_cross_attn_2_proc_func=func_default, + ) + set_transformer_attn_processor( + self.transformer, + set_self_attn_proc_func=lambda name, hs, cad, ap: MIAttnProcessor2_0(), + set_self_attn_module_names=set_self_attn_module_names, + ) diff --git a/midi/pipelines/pipeline_triposg_output.py b/midi/pipelines/pipeline_triposg_output.py new file mode 100644 index 0000000000000000000000000000000000000000..d9030716b3f528dfe280059e83e812cbdd0c7c74 --- /dev/null +++ b/midi/pipelines/pipeline_triposg_output.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +from diffusers.utils import BaseOutput + +PipelineBoxOutput = Union[ + List[List[int]], # [[257, 257, 257], ...] + List[List[float]], # [[-1.05, -1.05, -1.05], ...] + List[np.ndarray], +] + + +@dataclass +class TripoSGPipelineOutput(BaseOutput): + r""" + Output class for TripoSG pipelines. 
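+
+ Args:
+ samples (`torch.Tensor`):
+ The decoded field values at the sampled grid points, or the raw latents when `output_type="latent"`.
+ grid_sizes, bbox_sizes, bbox_mins, bbox_maxs (`PipelineBoxOutput`, *optional*):
+ Per-sample grid resolutions and bounding-box extents describing how `samples` are laid out in space;
+ `None` when `output_type="latent"`.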
+ """ + + samples: torch.Tensor + grid_sizes: Optional[PipelineBoxOutput] = None + bbox_sizes: Optional[PipelineBoxOutput] = None + bbox_mins: Optional[PipelineBoxOutput] = None + bbox_maxs: Optional[PipelineBoxOutput] = None diff --git a/midi/pipelines/pipeline_utils.py b/midi/pipelines/pipeline_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a2dc4413b7c859e0dc88ade2aa12629414026274 --- /dev/null +++ b/midi/pipelines/pipeline_utils.py @@ -0,0 +1,96 @@ +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + + +class TransformerDiffusionMixin: + r""" + Helper for DiffusionPipeline with vae and transformer.(mainly for DIT) + """ + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + transformer (`bool`, defaults to `True`): To apply fusion on the Transformer. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_transformer = False + self.fusing_vae = False + + if transformer: + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + if vae: + self.fusing_vae = True + self.vae.fuse_qkv_projections() + + def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + transformer (`bool`, defaults to `True`): To apply fusion on the Transformer. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if transformer: + if not self.fusing_transformer: + logger.warning( + "The UNet was not initially fused for QKV projections. Doing nothing." + ) + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + if vae: + if not self.fusing_vae: + logger.warning( + "The VAE was not initially fused for QKV projections. Doing nothing." 
+ ) + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False diff --git a/midi/schedulers/__init__.py b/midi/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..619c3313f7b568c7500aeafa1ed91ad1b85400e5 --- /dev/null +++ b/midi/schedulers/__init__.py @@ -0,0 +1,5 @@ +from .scheduling_rectified_flow import ( + RectifiedFlowScheduler, + compute_density_for_timestep_sampling, + compute_loss_weighting, +) diff --git a/midi/schedulers/scheduling_rectified_flow.py b/midi/schedulers/scheduling_rectified_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..d97c1bb1d833165a6150b1e5173232a49e93b7ea --- /dev/null +++ b/midi/schedulers/scheduling_rectified_flow.py @@ -0,0 +1,327 @@ +""" +Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py. +""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput, logging +from torch.distributions import LogisticNormal + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO: may move to training_utils.py +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = 0.0, + logit_std: float = 1.0, + mode_scale: float = None, +): + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal( + mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu" + ) + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "logit_normal_dist": + u = ( + LogisticNormal(loc=logit_mean, scale=logit_std) + .sample((batch_size,))[:, 0] + .to("cpu") + ) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting(weighting_scheme: str, sigmas=None): + """ + Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +@dataclass +class RectifiedFlowSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin): + """ + The rectified flow scheduler is a scheduler that is used to propagate the diffusion process in the rectified flow. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. 
+ + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + ): + # pre-compute timesteps and sigmas; no use in fact + # NOTE that shape diffusion sample timesteps randomly or in a distribution, + # instead of sampling from the pre-defined linspace + timesteps = np.array( + [ + (1.0 - i / num_train_timesteps) * num_train_timesteps + for i in range(num_train_timesteps) + ] + ) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = self.time_shift(sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _t_to_sigma(self, timestep): + return timestep / self.config.num_train_timesteps + + def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def time_shift(self, t: torch.Tensor): + return self.config.shift * t / (1 + (self.config.shift - 1) * t) + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
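+ sigmas (`List[float]`, *optional*):
+ Custom sigma values to build the schedule from instead of the linearly spaced timesteps. When
+ given, `num_inference_steps` is not used.
+ mu (`float`, *optional*):
+ Shift value for dynamic timestep shifting; required when `use_dynamic_shifting` is `True`.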
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + self.num_inference_steps = num_inference_steps + timesteps = np.array( + [ + (1.0 - i / num_inference_steps) * self.config.num_train_timesteps + for i in range(num_inference_steps) + ] + ) # different from the original code in SD3 + sigmas = timesteps / self.config.num_train_timesteps + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift_dynamic(mu, 1.0, sigmas) + else: + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.config.num_train_timesteps + + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RectifiedFlowSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. 
Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + # Here different directions are used for the flow matching + prev_sample = sample + (sigma - sigma_next) * model_output + + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return RectifiedFlowSchedulerOutput(prev_sample=prev_sample) + + def scale_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + """ + Forward function for the noise scaling in the flow matching. + """ + sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32)) + + while len(sigmas.shape) < len(original_samples.shape): + sigmas = sigmas.unsqueeze(-1) + + return (1.0 - sigmas) * original_samples + sigmas * noise + + def __len__(self): + return self.config.num_train_timesteps diff --git a/midi/utils/smoothing.py b/midi/utils/smoothing.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7911ec01469c581ecd0782c4380f893e8f9880 --- /dev/null +++ b/midi/utils/smoothing.py @@ -0,0 +1,615 @@ +# -*- coding: utf-8 -*- + +""" +Utilities for smoothing the occ/sdf grids. +""" + +import logging +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import ndimage as ndi +from scipy import sparse + +__all__ = [ + "smooth", + "smooth_constrained", + "smooth_gaussian", + "signed_distance_function", + "smooth_gpu", + "smooth_constrained_gpu", + "smooth_gaussian_gpu", + "signed_distance_function_gpu", +] + + +def _build_variable_indices(band: np.ndarray) -> np.ndarray: + num_variables = np.count_nonzero(band) + variable_indices = np.full(band.shape, -1, dtype=np.int_) + variable_indices[band] = np.arange(num_variables) + return variable_indices + + +def _buildq3d(variable_indices: np.ndarray): + """ + Builds the filterq matrix for the given variables. 
+ """ + + num_variables = variable_indices.max() + 1 + filterq = sparse.lil_matrix((3 * num_variables, num_variables)) + + # Pad variable_indices to simplify out-of-bounds accesses + variable_indices = np.pad( + variable_indices, [(0, 1), (0, 1), (0, 1)], mode="constant", constant_values=-1 + ) + + coords = np.nonzero(variable_indices >= 0) + for count, (i, j, k) in enumerate(zip(*coords)): + + assert variable_indices[i, j, k] == count + + filterq[3 * count, count] = -2 + neighbor = variable_indices[i - 1, j, k] + if neighbor >= 0: + filterq[3 * count, neighbor] = 1 + else: + filterq[3 * count, count] += 1 + + neighbor = variable_indices[i + 1, j, k] + if neighbor >= 0: + filterq[3 * count, neighbor] = 1 + else: + filterq[3 * count, count] += 1 + + filterq[3 * count + 1, count] = -2 + neighbor = variable_indices[i, j - 1, k] + if neighbor >= 0: + filterq[3 * count + 1, neighbor] = 1 + else: + filterq[3 * count + 1, count] += 1 + + neighbor = variable_indices[i, j + 1, k] + if neighbor >= 0: + filterq[3 * count + 1, neighbor] = 1 + else: + filterq[3 * count + 1, count] += 1 + + filterq[3 * count + 2, count] = -2 + neighbor = variable_indices[i, j, k - 1] + if neighbor >= 0: + filterq[3 * count + 2, neighbor] = 1 + else: + filterq[3 * count + 2, count] += 1 + + neighbor = variable_indices[i, j, k + 1] + if neighbor >= 0: + filterq[3 * count + 2, neighbor] = 1 + else: + filterq[3 * count + 2, count] += 1 + + filterq = filterq.tocsr() + return filterq.T.dot(filterq) + + +def _buildq3d_gpu(variable_indices: torch.Tensor, chunk_size=10000): + """ + Builds the filterq matrix for the given variables on GPU, using chunking to reduce memory usage. + """ + device = variable_indices.device + num_variables = variable_indices.max().item() + 1 + + # Pad variable_indices to simplify out-of-bounds accesses + variable_indices = torch.nn.functional.pad( + variable_indices, (0, 1, 0, 1, 0, 1), mode="constant", value=-1 + ) + + coords = torch.nonzero(variable_indices >= 0) + i, j, k = coords[:, 0], coords[:, 1], coords[:, 2] + + # Function to process a chunk of data + def process_chunk(start, end): + row_indices = [] + col_indices = [] + values = [] + + for axis in range(3): + row_indices.append(3 * torch.arange(start, end, device=device) + axis) + col_indices.append( + variable_indices[i[start:end], j[start:end], k[start:end]] + ) + values.append(torch.full((end - start,), -2, device=device)) + + for offset in [-1, 1]: + if axis == 0: + neighbor = variable_indices[ + i[start:end] + offset, j[start:end], k[start:end] + ] + elif axis == 1: + neighbor = variable_indices[ + i[start:end], j[start:end] + offset, k[start:end] + ] + else: + neighbor = variable_indices[ + i[start:end], j[start:end], k[start:end] + offset + ] + + mask = neighbor >= 0 + row_indices.append( + 3 * torch.arange(start, end, device=device)[mask] + axis + ) + col_indices.append(neighbor[mask]) + values.append(torch.ones(mask.sum(), device=device)) + + # Add 1 to the diagonal for out-of-bounds neighbors + row_indices.append( + 3 * torch.arange(start, end, device=device)[~mask] + axis + ) + col_indices.append( + variable_indices[i[start:end], j[start:end], k[start:end]][~mask] + ) + values.append(torch.ones((~mask).sum(), device=device)) + + return torch.cat(row_indices), torch.cat(col_indices), torch.cat(values) + + # Process data in chunks + all_row_indices = [] + all_col_indices = [] + all_values = [] + + for start in range(0, coords.shape[0], chunk_size): + end = min(start + chunk_size, coords.shape[0]) + row_indices, col_indices, values = 
process_chunk(start, end) + all_row_indices.append(row_indices) + all_col_indices.append(col_indices) + all_values.append(values) + + # Concatenate all chunks + row_indices = torch.cat(all_row_indices) + col_indices = torch.cat(all_col_indices) + values = torch.cat(all_values) + + # Create sparse tensor + indices = torch.stack([row_indices, col_indices]) + filterq = torch.sparse_coo_tensor( + indices, values, (3 * num_variables, num_variables) + ) + + # Compute filterq.T @ filterq + return torch.sparse.mm(filterq.t(), filterq) + + +# Usage example: +# variable_indices = torch.tensor(...).cuda() # Your input tensor on GPU +# result = _buildq3d_gpu(variable_indices) + + +def _buildq2d(variable_indices: np.ndarray): + """ + Builds the filterq matrix for the given variables. + + Version for 2 dimensions. + """ + + num_variables = variable_indices.max() + 1 + filterq = sparse.lil_matrix((3 * num_variables, num_variables)) + + # Pad variable_indices to simplify out-of-bounds accesses + variable_indices = np.pad( + variable_indices, [(0, 1), (0, 1)], mode="constant", constant_values=-1 + ) + + coords = np.nonzero(variable_indices >= 0) + for count, (i, j) in enumerate(zip(*coords)): + assert variable_indices[i, j] == count + + filterq[2 * count, count] = -2 + neighbor = variable_indices[i - 1, j] + if neighbor >= 0: + filterq[2 * count, neighbor] = 1 + else: + filterq[2 * count, count] += 1 + + neighbor = variable_indices[i + 1, j] + if neighbor >= 0: + filterq[2 * count, neighbor] = 1 + else: + filterq[2 * count, count] += 1 + + filterq[2 * count + 1, count] = -2 + neighbor = variable_indices[i, j - 1] + if neighbor >= 0: + filterq[2 * count + 1, neighbor] = 1 + else: + filterq[2 * count + 1, count] += 1 + + neighbor = variable_indices[i, j + 1] + if neighbor >= 0: + filterq[2 * count + 1, neighbor] = 1 + else: + filterq[2 * count + 1, count] += 1 + + filterq = filterq.tocsr() + return filterq.T.dot(filterq) + + +def _jacobi( + filterq, + x0: np.ndarray, + lower_bound: np.ndarray, + upper_bound: np.ndarray, + max_iters: int = 10, + rel_tol: float = 1e-6, + weight: float = 0.5, +): + """Jacobi method with constraints.""" + + jacobi_r = sparse.lil_matrix(filterq) + shp = jacobi_r.shape + jacobi_d = 1.0 / filterq.diagonal() + jacobi_r.setdiag((0,) * shp[0]) + jacobi_r = jacobi_r.tocsr() + + x = x0 + + # We check the stopping criterion each 10 iterations + check_each = 10 + cum_rel_tol = 1 - (1 - rel_tol) ** check_each + + energy_now = np.dot(x, filterq.dot(x)) / 2 + logging.info("Energy at iter %d: %.6g", 0, energy_now) + for i in range(max_iters): + + x_1 = -jacobi_d * jacobi_r.dot(x) + x = weight * x_1 + (1 - weight) * x + + # Constraints. + x = np.maximum(x, lower_bound) + x = np.minimum(x, upper_bound) + + # Stopping criterion + if (i + 1) % check_each == 0: + # Update energy + energy_before = energy_now + energy_now = np.dot(x, filterq.dot(x)) / 2 + + logging.info("Energy at iter %d: %.6g", i + 1, energy_now) + + # Check stopping criterion + cum_rel_improvement = (energy_before - energy_now) / energy_before + if cum_rel_improvement < cum_rel_tol: + break + + return x + + +def signed_distance_function( + levelset: np.ndarray, band_radius: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Return the distance to the 0.5 levelset of a function, the mask of the + border (i.e., the nearest cells to the 0.5 level-set) and the mask of the + band (i.e., the cells of the function whose distance to the 0.5 level-set + is less of equal to `band_radius`). 
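+
+    The distance is computed with `scipy.ndimage.distance_transform_edt`,
+    positive inside the object and negative outside, shifted by 0.5 so the
+    zero crossing falls between occupied and empty cells.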
+ """ + + binary_array = np.where(levelset > 0, True, False) + + # Compute the band and the border. + dist_func = ndi.distance_transform_edt + distance = np.where( + binary_array, dist_func(binary_array) - 0.5, -dist_func(~binary_array) + 0.5 + ) + border = np.abs(distance) < 1 + band = np.abs(distance) <= band_radius + + return distance, border, band + + +def signed_distance_function_iso0( + levelset: np.ndarray, band_radius: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Return the distance to the 0 levelset of a function, the mask of the + border (i.e., the nearest cells to the 0 level-set) and the mask of the + band (i.e., the cells of the function whose distance to the 0 level-set + is less of equal to `band_radius`). + """ + + binary_array = levelset > 0 + + # Compute the band and the border. + dist_func = ndi.distance_transform_edt + distance = np.where( + binary_array, dist_func(binary_array), -dist_func(~binary_array) + ) + border = np.zeros_like(levelset, dtype=bool) + border[:-1, :, :] |= levelset[:-1, :, :] * levelset[1:, :, :] <= 0 + border[:, :-1, :] |= levelset[:, :-1, :] * levelset[:, 1:, :] <= 0 + border[:, :, :-1] |= levelset[:, :, :-1] * levelset[:, :, 1:] <= 0 + band = np.abs(distance) <= band_radius + + return distance, border, band + + +def signed_distance_function_gpu(levelset: torch.Tensor, band_radius: int): + binary_array = (levelset > 0).float() + + # Compute distance transform + dist_pos = ( + F.max_pool3d( + -binary_array.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1 + ) + .squeeze(0) + .squeeze(0) + + binary_array + ) + dist_neg = F.max_pool3d( + (binary_array - 1).unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1 + ).squeeze(0).squeeze(0) + (1 - binary_array) + + distance = torch.where(binary_array > 0, dist_pos - 0.5, -dist_neg + 0.5) + + # breakpoint() + + # Use levelset as distance directly + # distance = levelset + # print(distance.shape) + # Compute border and band + border = torch.abs(distance) < 1 + band = torch.abs(distance) <= band_radius + + return distance, border, band + + +def smooth_constrained( + binary_array: np.ndarray, + band_radius: int = 4, + max_iters: int = 250, + rel_tol: float = 1e-6, +) -> np.ndarray: + """ + Implementation of the smoothing method from + + "Surface Extraction from Binary Volumes with Higher-Order Smoothness" + Victor Lempitsky, CVPR10 + """ + + # # Compute the distance map, the border and the band. + logging.info("Computing distance transform...") + # distance, _, band = signed_distance_function(binary_array, band_radius) + binary_array_gpu = torch.from_numpy(binary_array).cuda() + distance, _, band = signed_distance_function_gpu(binary_array_gpu, band_radius) + distance = distance.cpu().numpy() + band = band.cpu().numpy() + + variable_indices = _build_variable_indices(band) + + # Compute filterq. + logging.info("Building matrix filterq...") + if binary_array.ndim == 3: + filterq = _buildq3d(variable_indices) + # variable_indices_gpu = torch.from_numpy(variable_indices).cuda() + # filterq_gpu = _buildq3d_gpu(variable_indices_gpu) + # filterq = filterq_gpu.cpu().numpy() + elif binary_array.ndim == 2: + filterq = _buildq2d(variable_indices) + else: + raise ValueError("binary_array.ndim not in [2, 3]") + + # Initialize the variables. + res = np.asarray(distance, dtype=np.double) + x = res[band] + upper_bound = np.where(x < 0, x, np.inf) + lower_bound = np.where(x > 0, x, -np.inf) + + upper_bound[np.abs(upper_bound) < 1] = 0 + lower_bound[np.abs(lower_bound) < 1] = 0 + + # Solve. 
+ logging.info("Minimizing energy...") + x = _jacobi( + filterq=filterq, + x0=x, + lower_bound=lower_bound, + upper_bound=upper_bound, + max_iters=max_iters, + rel_tol=rel_tol, + ) + + res[band] = x + return res + + +def total_variation_denoising(x, weight=0.1, num_iterations=5, eps=1e-8): + diff_x = torch.diff(x, dim=0, prepend=x[:1]) + diff_y = torch.diff(x, dim=1, prepend=x[:, :1]) + diff_z = torch.diff(x, dim=2, prepend=x[:, :, :1]) + + norm = torch.sqrt(diff_x**2 + diff_y**2 + diff_z**2 + eps) + + div_x = torch.diff(diff_x / norm, dim=0, append=diff_x[-1:] / norm[-1:]) + div_y = torch.diff(diff_y / norm, dim=1, append=diff_y[:, -1:] / norm[:, -1:]) + div_z = torch.diff(diff_z / norm, dim=2, append=diff_z[:, :, -1:] / norm[:, :, -1:]) + + return x - weight * (div_x + div_y + div_z) + + +def smooth_constrained_gpu( + binary_array: torch.Tensor, + band_radius: int = 4, + max_iters: int = 250, + rel_tol: float = 1e-4, +): + distance, _, band = signed_distance_function_gpu(binary_array, band_radius) + + # Initialize variables + x = distance[band] + upper_bound = torch.where(x < 0, x, torch.tensor(float("inf"), device=x.device)) + lower_bound = torch.where(x > 0, x, torch.tensor(float("-inf"), device=x.device)) + + upper_bound[torch.abs(upper_bound) < 1] = 0 + lower_bound[torch.abs(lower_bound) < 1] = 0 + + # Define the 3D Laplacian kernel + laplacian_kernel = torch.tensor( + [ + [ + [ + [[0, 1, 0], [1, -6, 1], [0, 1, 0]], + [[1, 0, 1], [0, 0, 0], [1, 0, 1]], + [[0, 1, 0], [1, 0, 1], [0, 1, 0]], + ] + ] + ], + device=x.device, + ).float() + + laplacian_kernel = laplacian_kernel / laplacian_kernel.abs().sum() + + breakpoint() + + # Simplified Jacobi iteration + for i in range(max_iters): + # Reshape x to 5D tensor (batch, channel, depth, height, width) + x_5d = x.view(1, 1, *band.shape) + x_3d = x.view(*band.shape) + + # Apply 3D convolution + laplacian = F.conv3d(x_5d, laplacian_kernel, padding=1) + + # Reshape back to original dimensions + laplacian = laplacian.view(x.shape) + + # Use a small relaxation factor to improve stability + relaxation_factor = 0.1 + tv_weight = 0.1 + # x_new = x + relaxation_factor * laplacian + x_new = total_variation_denoising(x_3d, weight=tv_weight) + # Print laplacian min and max + # print(f"Laplacian min: {laplacian.min().item():.4f}, max: {laplacian.max().item():.4f}") + + # Apply constraints + # Reshape x_new to match the dimensions of lower_bound and upper_bound + x_new = x_new.view(x.shape) + x_new = torch.clamp(x_new, min=lower_bound, max=upper_bound) + + # Check for convergence + diff_norm = torch.norm(x_new - x) + print(diff_norm) + x_norm = torch.norm(x) + + if x_norm > 1e-8: # Avoid division by very small numbers + relative_change = diff_norm / x_norm + if relative_change < rel_tol: + break + elif diff_norm < rel_tol: # If x_norm is very small, check absolute change + break + + x = x_new + + # Check for NaN and break if found, also check for inf + if torch.isnan(x).any() or torch.isinf(x).any(): + print(f"NaN or Inf detected at iteration {i}") + breakpoint() + break + + result = distance.clone() + result[band] = x + return result + + +def smooth_gaussian(binary_array: np.ndarray, sigma: float = 3) -> np.ndarray: + vol = np.float_(binary_array) - 0.5 + return ndi.gaussian_filter(vol, sigma=sigma) + + +def smooth_gaussian_gpu(binary_array: torch.Tensor, sigma: float = 3): + # vol = binary_array.float() + vol = binary_array + kernel_size = int(2 * sigma + 1) + kernel = torch.ones( + 1, + 1, + kernel_size, + kernel_size, + kernel_size, + 
device=binary_array.device, + dtype=vol.dtype, + ) / (kernel_size**3) + return F.conv3d( + vol.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2 + ).squeeze() + + +def smooth(binary_array: np.ndarray, method: str = "auto", **kwargs) -> np.ndarray: + """ + Smooths the 0.5 level-set of a binary array. Returns a floating-point + array with a smoothed version of the original level-set in the 0 isovalue. + + This function can apply two different methods: + + - A constrained smoothing method which preserves details and fine + structures, but it is slow and requires a large amount of memory. This + method is recommended when the input array is small (smaller than + (500, 500, 500)). + - A Gaussian filter applied over the binary array. This method is fast, but + not very precise, as it can destroy fine details. It is only recommended + when the input array is large and the 0.5 level-set does not contain + thin structures. + + Parameters + ---------- + binary_array : ndarray + Input binary array with the 0.5 level-set to smooth. + method : str, one of ['auto', 'gaussian', 'constrained'] + Smoothing method. If 'auto' is given, the method will be automatically + chosen based on the size of `binary_array`. + + Parameters for 'gaussian' + ------------------------- + sigma : float + Size of the Gaussian filter (default 3). + + Parameters for 'constrained' + ---------------------------- + max_iters : positive integer + Number of iterations of the constrained optimization method + (default 250). + rel_tol: float + Relative tolerance as a stopping criterion (default 1e-6). + + Output + ------ + res : ndarray + Floating-point array with a smoothed 0 level-set. + """ + + binary_array = np.asarray(binary_array) + + if method == "auto": + if binary_array.size > 512**3: + method = "gaussian" + else: + method = "constrained" + + if method == "gaussian": + return smooth_gaussian(binary_array, **kwargs) + + if method == "constrained": + return smooth_constrained(binary_array, **kwargs) + + raise ValueError("Unknown method '{}'".format(method)) + + +def smooth_gpu(binary_array: torch.Tensor, method: str = "auto", **kwargs): + if method == "auto": + method = "gaussian" if binary_array.numel() > 512**3 else "constrained" + + if method == "gaussian": + return smooth_gaussian_gpu(binary_array, **kwargs) + elif method == "constrained": + return smooth_constrained_gpu(binary_array, **kwargs) + else: + raise ValueError(f"Unknown method '{method}'") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..80b7911acdb1266f9d3befd8bffc19fced7dd731 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +gradio==4.44.1 +gradio_litmodel3d +gradio_image_prompter +diffusers +transformers +einops +torch-cluster +huggingface_hub +opencv-python +trimesh +omegaconf +scikit-image +numpy==1.22.3 +peft \ No newline at end of file diff --git a/scripts/grounding_sam.py b/scripts/grounding_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..edf499af7690e86b38fa5a0fe4fcd8c189030ca0 --- /dev/null +++ b/scripts/grounding_sam.py @@ -0,0 +1,363 @@ +# Adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb + +import argparse +import os +import random +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import requests +import torch +from PIL import Image +from transformers import 
AutoModelForMaskGeneration, AutoProcessor, pipeline + + +def create_palette(): + # Define a palette with 24 colors for labels 0-23 (example colors) + palette = [ + 0, + 0, + 0, # Label 0 (black) + 255, + 0, + 0, # Label 1 (red) + 0, + 255, + 0, # Label 2 (green) + 0, + 0, + 255, # Label 3 (blue) + 255, + 255, + 0, # Label 4 (yellow) + 255, + 0, + 255, # Label 5 (magenta) + 0, + 255, + 255, # Label 6 (cyan) + 128, + 0, + 0, # Label 7 (dark red) + 0, + 128, + 0, # Label 8 (dark green) + 0, + 0, + 128, # Label 9 (dark blue) + 128, + 128, + 0, # Label 10 + 128, + 0, + 128, # Label 11 + 0, + 128, + 128, # Label 12 + 64, + 0, + 0, # Label 13 + 0, + 64, + 0, # Label 14 + 0, + 0, + 64, # Label 15 + 64, + 64, + 0, # Label 16 + 64, + 0, + 64, # Label 17 + 0, + 64, + 64, # Label 18 + 192, + 192, + 192, # Label 19 (light gray) + 128, + 128, + 128, # Label 20 (gray) + 255, + 165, + 0, # Label 21 (orange) + 75, + 0, + 130, # Label 22 (indigo) + 238, + 130, + 238, # Label 23 (violet) + ] + # Extend the palette to have 768 values (256 * 3) + palette.extend([0] * (768 - len(palette))) + return palette + + +PALETTE = create_palette() + + +# Result Utils +@dataclass +class BoundingBox: + xmin: int + ymin: int + xmax: int + ymax: int + + @property + def xyxy(self) -> List[float]: + return [self.xmin, self.ymin, self.xmax, self.ymax] + + +@dataclass +class DetectionResult: + score: Optional[float] = None + label: Optional[str] = None + box: Optional[BoundingBox] = None + mask: Optional[np.array] = None + + @classmethod + def from_dict(cls, detection_dict: Dict) -> "DetectionResult": + return cls( + score=detection_dict["score"], + label=detection_dict["label"], + box=BoundingBox( + xmin=detection_dict["box"]["xmin"], + ymin=detection_dict["box"]["ymin"], + xmax=detection_dict["box"]["xmax"], + ymax=detection_dict["box"]["ymax"], + ), + ) + + +# Utils +def mask_to_polygon(mask: np.ndarray) -> List[List[int]]: + # Find contours in the binary mask + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + + # Find the contour with the largest area + largest_contour = max(contours, key=cv2.contourArea) + + # Extract the vertices of the contour + polygon = largest_contour.reshape(-1, 2).tolist() + + return polygon + + +def polygon_to_mask( + polygon: List[Tuple[int, int]], image_shape: Tuple[int, int] +) -> np.ndarray: + """ + Convert a polygon to a segmentation mask. + + Args: + - polygon (list): List of (x, y) coordinates representing the vertices of the polygon. + - image_shape (tuple): Shape of the image (height, width) for the mask. + + Returns: + - np.ndarray: Segmentation mask with the polygon filled. 
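+
+    Example (illustrative values):
+        >>> mask = polygon_to_mask([(10, 10), (50, 10), (30, 40)], (64, 64))
+        >>> mask.shape
+        (64, 64)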
+ """ + # Create an empty mask + mask = np.zeros(image_shape, dtype=np.uint8) + + # Convert polygon to an array of points + pts = np.array(polygon, dtype=np.int32) + + # Fill the polygon with white color (255) + cv2.fillPoly(mask, [pts], color=(255,)) + + return mask + + +def load_image(image_str: str) -> Image.Image: + if image_str.startswith("http"): + image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB") + else: + image = Image.open(image_str).convert("RGB") + + return image + + +def get_boxes(results: DetectionResult) -> List[List[List[float]]]: + boxes = [] + for result in results: + xyxy = result.box.xyxy + boxes.append(xyxy) + + return [boxes] + + +def refine_masks( + masks: torch.BoolTensor, polygon_refinement: bool = False +) -> List[np.ndarray]: + masks = masks.cpu().float() + masks = masks.permute(0, 2, 3, 1) + masks = masks.mean(axis=-1) + masks = (masks > 0).int() + masks = masks.numpy().astype(np.uint8) + masks = list(masks) + + if polygon_refinement: + for idx, mask in enumerate(masks): + shape = mask.shape + polygon = mask_to_polygon(mask) + mask = polygon_to_mask(polygon, shape) + masks[idx] = mask + + return masks + + +# Post-processing Utils +def generate_colored_segmentation(label_image): + # Create a PIL Image from the label image (assuming it's a 2D numpy array) + label_image_pil = Image.fromarray(label_image.astype(np.uint8), mode="P") + + # Apply the palette to the image + palette = create_palette() + label_image_pil.putpalette(palette) + + return label_image_pil + + +def plot_segmentation(image, detections): + seg_map = np.zeros(image.size[::-1], dtype=np.uint8) + for i, detection in enumerate(detections): + mask = detection.mask + seg_map[mask > 0] = i + 1 + seg_map_pil = generate_colored_segmentation(seg_map) + return seg_map_pil + + +# Grounded SAM +def prepare_model( + device: str = "cuda", + detector_id: Optional[str] = None, + segmenter_id: Optional[str] = None, +): + detector_id = ( + detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny" + ) + object_detector = pipeline( + model=detector_id, task="zero-shot-object-detection", device=device + ) + + segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base" + processor = AutoProcessor.from_pretrained(segmenter_id) + segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device) + + return object_detector, processor, segmentator + + +def detect( + object_detector: Any, + image: Image.Image, + labels: List[str], + threshold: float = 0.3, +) -> List[Dict[str, Any]]: + """ + Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion. + """ + labels = [label if label.endswith(".") else label + "." for label in labels] + + results = object_detector(image, candidate_labels=labels, threshold=threshold) + results = [DetectionResult.from_dict(result) for result in results] + + return results + + +def segment( + processor: Any, + segmentator: Any, + image: Image.Image, + boxes: Optional[List[List[List[float]]]] = None, + detection_results: Optional[List[Dict[str, Any]]] = None, + polygon_refinement: bool = False, +) -> List[DetectionResult]: + """ + Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes. + """ + if detection_results is None and boxes is None: + raise ValueError( + "Either detection_results or detection_boxes must be provided." 
+ ) + + if boxes is None: + boxes = get_boxes(detection_results) + + inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to( + segmentator.device, segmentator.dtype + ) + + outputs = segmentator(**inputs) + masks = processor.post_process_masks( + masks=outputs.pred_masks, + original_sizes=inputs.original_sizes, + reshaped_input_sizes=inputs.reshaped_input_sizes, + )[0] + + masks = refine_masks(masks, polygon_refinement) + + if detection_results is None: + detection_results = [DetectionResult() for _ in masks] + + for detection_result, mask in zip(detection_results, masks): + detection_result.mask = mask + + return detection_results + + +def grounded_segmentation( + object_detector, + processor, + segmentator, + image: Union[Image.Image, str], + labels: Union[str, List[str]], + threshold: float = 0.3, + polygon_refinement: bool = False, +) -> Tuple[np.ndarray, List[DetectionResult], Image.Image]: + if isinstance(image, str): + image = load_image(image) + if isinstance(labels, str): + labels = labels.split(",") + + detections = detect(object_detector, image, labels, threshold) + detections = segment(processor, segmentator, image, detections, polygon_refinement) + + seg_map_pil = plot_segmentation(image, detections) + + return np.array(image), detections, seg_map_pil + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image", type=str, required=True) + parser.add_argument("--labels", type=str, nargs="+", required=True) + parser.add_argument("--output", type=str, default="./", help="Output directory") + parser.add_argument("--threshold", type=float, default=0.3) + parser.add_argument( + "--detector_id", type=str, default="IDEA-Research/grounding-dino-base" + ) + parser.add_argument("--segmenter_id", type=str, default="facebook/sam-vit-base") + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + object_detector, processor, segmentator = prepare_model( + device=device, detector_id=args.detector_id, segmenter_id=args.segmenter_id + ) + + image_array, detections, seg_map_pil = grounded_segmentation( + object_detector, + processor, + segmentator, + image=args.image, + labels=args.labels, + threshold=args.threshold, + polygon_refinement=True, + ) + + os.makedirs(args.output, exist_ok=True) + seg_map_pil.save(os.path.join(args.output, "segmentation.png")) diff --git a/scripts/inference_midi.py b/scripts/inference_midi.py new file mode 100644 index 0000000000000000000000000000000000000000..1a266d278cd5b5ddf9b94668267caafc27dbf52b --- /dev/null +++ b/scripts/inference_midi.py @@ -0,0 +1,184 @@ +import argparse +import os +from glob import glob +from typing import Any, List, Union + +import gradio as gr +import numpy as np +import torch +import trimesh +from huggingface_hub import snapshot_download +from PIL import Image, ImageOps +from skimage import measure + +from midi.pipelines.pipeline_midi import MIDIPipeline +from midi.utils.smoothing import smooth_gpu + + +def preprocess_image(rgb_image, seg_image): + if isinstance(rgb_image, str): + rgb_image = Image.open(rgb_image) + if isinstance(seg_image, str): + seg_image = Image.open(seg_image) + rgb_image = rgb_image.convert("RGB") + seg_image = seg_image.convert("L") + + width, height = rgb_image.size + + seg_np = np.array(seg_image) + rows, cols = np.where(seg_np > 0) + if rows.size == 0 or cols.size == 0: + return rgb_image, seg_image + + # compute the bounding box of combined instances + min_row, max_row = min(rows), max(rows) + min_col, max_col = 
min(cols), max(cols) + L = max( + max(abs(max_row - width // 2), abs(min_row - width // 2)) * 2, + max(abs(max_col - height // 2), abs(min_col - height // 2)) * 2, + ) + + # pad the image + if L > width * 0.8: + width = int(L / 4 * 5) + if L > height * 0.8: + height = int(L / 4 * 5) + rgb_new = Image.new("RGB", (width, height), (255, 255, 255)) + seg_new = Image.new("L", (width, height), 0) + x_offset = (width - rgb_image.size[0]) // 2 + y_offset = (height - rgb_image.size[1]) // 2 + rgb_new.paste(rgb_image, (x_offset, y_offset)) + seg_new.paste(seg_image, (x_offset, y_offset)) + + # pad to the square + max_dim = max(width, height) + rgb_new = ImageOps.expand( + rgb_new, border=(0, 0, max_dim - width, max_dim - height), fill="white" + ) + seg_new = ImageOps.expand( + seg_new, border=(0, 0, max_dim - width, max_dim - height), fill=0 + ) + + return rgb_new, seg_new + + +def split_rgb_mask(rgb_image, seg_image): + if isinstance(rgb_image, str): + rgb_image = Image.open(rgb_image) + if isinstance(seg_image, str): + seg_image = Image.open(seg_image) + rgb_image = rgb_image.convert("RGB") + seg_image = seg_image.convert("L") + + rgb_array = np.array(rgb_image) + seg_array = np.array(seg_image) + + label_ids = np.unique(seg_array) + label_ids = label_ids[label_ids > 0] + + instance_rgbs, instance_masks, scene_rgbs = [], [], [] + + for segment_id in sorted(label_ids): + # Here we set the background to white + white_background = np.ones_like(rgb_array) * 255 + + mask = np.zeros_like(seg_array, dtype=np.uint8) + mask[seg_array == segment_id] = 255 + segment_rgb = white_background.copy() + segment_rgb[mask == 255] = rgb_array[mask == 255] + + segment_rgb_image = Image.fromarray(segment_rgb) + segment_mask_image = Image.fromarray(mask) + instance_rgbs.append(segment_rgb_image) + instance_masks.append(segment_mask_image) + scene_rgbs.append(rgb_image) + + return instance_rgbs, instance_masks, scene_rgbs + + +@torch.no_grad() +def run_midi( + pipe: Any, + rgb_image: Union[str, Image.Image], + seg_image: Union[str, Image.Image], + seed: int, + num_inference_steps: int = 50, + guidance_scale: float = 7.0, + do_image_padding: bool = False, +) -> trimesh.Scene: + if do_image_padding: + rgb_image, seg_image = preprocess_image(rgb_image, seg_image) + instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image) + + num_instances = len(instance_rgbs) + outputs = pipe( + image=instance_rgbs, + mask=instance_masks, + image_scene=scene_rgbs, + attention_kwargs={"num_instances": num_instances}, + generator=torch.Generator(device=pipe.device).manual_seed(seed), + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + decode_progressive=True, + return_dict=False, + ) + + # marching cubes + trimeshes = [] + for _, (logits_, grid_size, bbox_size, bbox_min, bbox_max) in enumerate( + zip(*outputs) + ): + grid_logits = logits_.view(grid_size) + grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1) + torch.cuda.empty_cache() + vertices, faces, normals, _ = measure.marching_cubes( + grid_logits.float().cpu().numpy(), 0, method="lewiner" + ) + vertices = vertices / grid_size * bbox_size + bbox_min + + # Trimesh + mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces)) + trimeshes.append(mesh) + + # compose the output meshes + scene = trimesh.Scene(trimeshes) + + return scene + + +if __name__ == "__main__": + device = "cuda" + dtype = torch.bfloat16 + + parser = argparse.ArgumentParser() + parser.add_argument("--rgb", type=str, required=True) + 
parser.add_argument("--seg", type=str, required=True) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num-inference-steps", type=int, default=50) + parser.add_argument("--guidance-scale", type=float, default=7.0) + parser.add_argument("--do-image-padding", action="store_true") + parser.add_argument("--output-dir", type=str, default="./") + args = parser.parse_args() + + local_dir = "pretrained_weights/MIDI-3D" + snapshot_download(repo_id="VAST-AI/MIDI-3D", local_dir=local_dir) + pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(device, dtype) + pipe.init_custom_adapter( + set_self_attn_module_names=[ + "blocks.8", + "blocks.9", + "blocks.10", + "blocks.11", + "blocks.12", + ] + ) + + run_midi( + pipe, + rgb_image=args.rgb, + seg_image=args.seg, + seed=args.seed, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + do_image_padding=args.do_image_padding, + ).export(os.path.join(args.output_dir, "output.glb"))