|
from dataclasses import dataclass

import importlib.util
import os
import sys

import torch
import trimesh
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList
from einops import rearrange

from modules.bbox_gen.models.image_encoder import DINOv2ImageEncoder
from modules.bbox_gen.config import parse_structured
from modules.bbox_gen.models.bboxopt import BBoxOPT, BBoxOPTConfig
from modules.bbox_gen.utils.bbox_tokenizer import BoundsTokenizerDiag
from modules.bbox_gen.models.bbox_gen_models import (
    GroupEmbedding,
    MultiModalProjector,
    MeshDecodeLogitsProcessor,
    SparseStructureEncoder,
)

# Make the vendored PartField package importable before importing from it.
current_dir = os.path.dirname(os.path.abspath(__file__))
modules_dir = os.path.dirname(os.path.dirname(current_dir))
partfield_dir = os.path.join(modules_dir, 'PartField')
if partfield_dir not in sys.path:
    sys.path.insert(0, partfield_dir)

from partfield.config import default_argument_parser, setup

|
class BboxGen(nn.Module):
    """Bounding-box generator: fuses DINOv2 image tokens with a PartField
    point-cloud feature volume into a sequence of "voxel tokens", then
    autoregressively decodes part bounding boxes with an OPT-style causal LM."""

    @dataclass
    class Config:
        # Point-cloud / feature encoder settings.
        encoder_dim_feat: int = 3
        encoder_dim: int = 64
        encoder_heads: int = 4
        encoder_token_num: int = 256
        encoder_qkv_bias: bool = False
        encoder_use_ln_post: bool = True
        encoder_use_checkpoint: bool = False
        encoder_num_embed_freqs: int = 8
        encoder_embed_include_pi: bool = False
        encoder_init_scale: float = 0.25
        encoder_random_fps: bool = True
        encoder_learnable_query: bool = False
        encoder_layers: int = 4
        group_embedding_dim: int = 64

        # Decoder (OPT-style causal LM) settings.
        vocab_size: int = 518
        decoder_hidden_size: int = 1536
        decoder_num_hidden_layers: int = 24
        decoder_ffn_dim: int = 6144
        decoder_heads: int = 16
        decoder_use_flash_attention: bool = True
        decoder_gradient_checkpointing: bool = True

        # Token layout: 64 coordinate bins plus BOS/EOS/PAD special ids.
        bins: int = 64
        BOS_id: int = 64
        EOS_id: int = 65
        PAD_id: int = 66
        max_length: int = 2187
        voxel_token_length: int = 1886
        voxel_token_placeholder: int = -1

        # Maximum number of part groups handled by the group embedding.
        max_group_size: int = 50

        # Path to the pretrained PartField encoder checkpoint.
        partfield_encoder_path: str = ""
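
        # Note: max_length appears to budget one BOS plus up to max_group_size
        # boxes at 6 coordinate tokens each, on top of the voxel tokens:
        # 1886 + 1 + 50 * 6 = 2187.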
|
|
|
    cfg: Config

    def __init__(self, cfg):
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)

        # Image branch: DINOv2 features concatenated with per-group embeddings,
        # projected to the decoder's hidden size.
        self.image_encoder = DINOv2ImageEncoder(
            model_name="facebook/dinov2-with-registers-large",
        )
        self.image_projector = MultiModalProjector(
            in_features=(1024 + self.cfg.group_embedding_dim),
            out_features=self.cfg.decoder_hidden_size,
        )
        self.group_embedding = GroupEmbedding(
            max_group_size=self.cfg.max_group_size,
            hidden_size=self.cfg.group_embedding_dim,
        )

        self.decoder_config = BBoxOPTConfig(
            vocab_size=self.cfg.vocab_size,
            hidden_size=self.cfg.decoder_hidden_size,
            num_hidden_layers=self.cfg.decoder_num_hidden_layers,
            ffn_dim=self.cfg.decoder_ffn_dim,
            max_position_embeddings=self.cfg.max_length,
            num_attention_heads=self.cfg.decoder_heads,
            pad_token_id=self.cfg.PAD_id,
            bos_token_id=self.cfg.BOS_id,
            eos_token_id=self.cfg.EOS_id,
            use_cache=True,
            init_std=0.02,
        )
        if self.cfg.decoder_use_flash_attention:
            self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
                self.decoder_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            )
        else:
            self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
                self.decoder_config,
            )
        if self.cfg.decoder_gradient_checkpointing:
            self.decoder.gradient_checkpointing_enable()

        # Constrain sampling to the bbox token grammar (two corners per box).
        self.logits_processor = LogitsProcessorList()
        self.logits_processor.append(MeshDecodeLogitsProcessor(
            bins=self.cfg.bins,
            BOS_id=self.cfg.BOS_id,
            EOS_id=self.cfg.EOS_id,
            PAD_id=self.cfg.PAD_id,
            vertices_num=2,
        ))
        self.tokenizer = BoundsTokenizerDiag(
            bins=self.cfg.bins,
            BOS_id=self.cfg.BOS_id,
            EOS_id=self.cfg.EOS_id,
            PAD_id=self.cfg.PAD_id,
        )

        self._load_partfield_encoder()

        # Dense 64^3 volume (448 PartField channels + 3 xyz = 451) -> latent tokens.
        self.partfield_voxel_encoder = SparseStructureEncoder(
            in_channels=451,
            channels=[448, 448, 448, 1024],
            latent_channels=448,
            num_res_blocks=1,
            num_res_blocks_middle=1,
            norm_type="layer",
        )
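
        # The voxel-token sequence fed to the decoder concatenates the projected
        # image tokens with the flattened latent-volume tokens, and its total
        # length must equal cfg.voxel_token_length (1886). Assuming 518x518
        # inputs, DINOv2-L/14 with registers yields 1374 tokens (37*37 patches
        # + 1 CLS + 4 registers) and an 8^3 latent grid yields 512 more:
        # 1374 + 512 = 1886.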
|
|
|
|
|
    def _load_partfield_encoder(self):
        """Build the PartField encoder from its vendored source and load frozen weights."""
        model_spec = importlib.util.spec_from_file_location(
            "partfield.partfield_encoder",
            os.path.join(partfield_dir, "partfield", "partfield_encoder.py"),
        )
        model_module = importlib.util.module_from_spec(model_spec)
        model_spec.loader.exec_module(model_module)
        Model = model_module.Model

        # Reuse PartField's own CLI config pipeline to construct the model.
        parser = default_argument_parser()
        args = ["-c", os.path.join(partfield_dir, "configs/final/demo.yaml"),
                "--opts", "continue_ckpt", self.cfg.partfield_encoder_path]
        parsed_args = parser.parse_args(args)
        cfg = setup(parsed_args, freeze=False)

        self.partfield_encoder = Model(cfg)
        self.partfield_encoder.eval()
        weights = torch.load(self.cfg.partfield_encoder_path)["state_dict"]
        self.partfield_encoder.load_state_dict(weights)
        # Keep the encoder frozen; it is used as a fixed feature extractor.
        for param in self.partfield_encoder.parameters():
            param.requires_grad = False
        print("PartField encoder loaded")
|
|
|
    def _prepare_lm_inputs(self, voxel_token, input_ids):
        """Splice voxel tokens into the embedded sequence wherever input_ids
        holds the placeholder id; embed all other ids through the token table."""
        inputs_embeds = torch.zeros(
            input_ids.shape[0], input_ids.shape[1], self.cfg.decoder_hidden_size,
            device=input_ids.device, dtype=voxel_token.dtype,
        )
        voxel_token_mask = (input_ids == self.cfg.voxel_token_placeholder)
        inputs_embeds[voxel_token_mask] = voxel_token.view(-1, self.cfg.decoder_hidden_size)
        inputs_embeds[~voxel_token_mask] = self.decoder.get_input_embeddings()(
            input_ids[~voxel_token_mask]
        ).to(dtype=inputs_embeds.dtype)

        # PAD positions are masked out of attention.
        attention_mask = (input_ids != self.cfg.PAD_id)
        return inputs_embeds, attention_mask.long()
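
    # Example: with voxel_token_placeholder == -1 and BOS_id == 64, a row of
    # input_ids such as [-1, -1, ..., -1, 64, 3, 17, 42, ...] has its
    # placeholder slots overwritten with the voxel tokens, while the remaining
    # ids (BOS, coordinate bins in [0, 64), EOS/PAD) go through the decoder's
    # embedding table. The specific bin values here are illustrative.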
|
|
|
    def _build_voxel_tokens(self, batch):
        """Shared encoder path for forward() and generate(): fuse image tokens
        with a PartField feature volume into one voxel-token sequence."""
        # Image tokens: DINOv2 features plus group embeddings, zero-padded to
        # the full image-token length before projection.
        image_latents = self.image_encoder(batch['images'])
        masks = batch['masks']
        masks_emb = self.group_embedding(masks)
        masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c')
        group_emb = torch.zeros(
            (image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]),
            device=image_latents.device, dtype=image_latents.dtype,
        )
        group_emb[:, :masks_emb.shape[1], :] = masks_emb
        image_latents = torch.cat([image_latents, group_emb], dim=-1)
        image_latents = self.image_projector(image_latents)

        # Rotate points into the PartField encoder's frame
        # ((x, y, z) -> (x, z, -y)) and rescale by 2 * 0.9 before encoding.
        points = batch['points'][..., :3]
        rot_matrix = torch.tensor(
            [[1, 0, 0], [0, 0, -1], [0, 1, 0]],
            device=points.device, dtype=points.dtype,
        )
        rot_points = torch.matmul(points, rot_matrix)
        rot_points = rot_points * (2 * 0.9)

        # Scatter per-point PartField features into a dense 64^3 volume.
        partfield_feat = self.partfield_encoder.encode(rot_points)
        feat_volume = torch.zeros(
            (points.shape[0], 448, 64, 64, 64),
            device=partfield_feat.device, dtype=partfield_feat.dtype,
        )
        whole_voxel_index = batch['whole_voxel_index']
        batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
        batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points)
        batch_flat = batch_indices.flatten()
        x_flat = whole_voxel_index[..., 0].flatten()
        y_flat = whole_voxel_index[..., 1].flatten()
        z_flat = whole_voxel_index[..., 2].flatten()
        partfield_feat_flat = partfield_feat.reshape(-1, 448)
        feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat

        # Append raw xyz coordinates as three extra channels (448 + 3 = 451).
        xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
        xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
        feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)

        # Encode the volume and flatten it to a token sequence.
        feat_volume = self.partfield_voxel_encoder(feat_volume)
        feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')

        return torch.cat([image_latents, feat_volume], dim=1)

    def forward(self, batch):
        voxel_token = self._build_voxel_tokens(batch)
        inputs_embeds, attention_mask = self._prepare_lm_inputs(voxel_token, batch['input_ids'])
        output = self.decoder(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
        )
        return {
            "logits": output.logits,
        }
|
|
|
    def gen_mesh_from_bounds(self, bounds, random_color):
        """Turn (N, 2, 3) min/max bounds into a trimesh Scene of colored boxes."""
        bboxes = []
        for j in range(bounds.shape[0]):
            bbox = trimesh.primitives.Box(bounds=bounds[j])
            bbox.visual.vertex_colors = random_color[j]
            bboxes.append(bbox)
        mesh = trimesh.Scene(bboxes)
        return mesh
|
|
|
    def generate(self, batch):
        """Sample box-corner token sequences and decode them into continuous
        bounds in [-0.5, 0.5]^3."""
        voxel_token = self._build_voxel_tokens(batch)

        mesh_names = []
        bboxes = []

        output = self.decoder.generate(
            inputs_embeds=voxel_token,
            max_new_tokens=self.cfg.max_length - voxel_token.shape[1],
            logits_processor=self.logits_processor,
            do_sample=True,
            top_k=5,
            top_p=0.95,
            temperature=0.5,
            use_cache=True,
        )

        for i in range(output.shape[0]):
            bounds = self.tokenizer.decode(output[i].detach().cpu().numpy(), coord_rg=(-0.5, 0.5))
            mesh_names.append("topk=5")
            bboxes.append(bounds)

        return {
            'mesh_names': mesh_names,
            'bboxes': bboxes,
        }
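

# For reference, a minimal standalone sketch of what gen_mesh_from_bounds
# produces: axis-aligned boxes assembled into a trimesh.Scene. The dummy
# bounds and colors below are illustrative, not values the model emits.
if __name__ == "__main__":
    import numpy as np

    dummy_bounds = np.array([
        [[-0.4, -0.4, -0.4], [0.0, 0.0, 0.0]],  # box 1: (min, max) corners
        [[0.0, 0.0, 0.0], [0.4, 0.4, 0.4]],     # box 2
    ])
    dummy_colors = [[255, 0, 0, 255], [0, 0, 255, 255]]  # RGBA per box
    boxes = []
    for b, c in zip(dummy_bounds, dummy_colors):
        box = trimesh.primitives.Box(bounds=b)
        box.visual.vertex_colors = c
        boxes.append(box)
    trimesh.Scene(boxes).export("dummy_bboxes.glb")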
|
|
|
|
|
|
|
|