jkorstad's picture
Correctly add UniRig source files
f499d3b
import os
import sys
import weakref
import torch
torch.multiprocessing.set_start_method('spawn')
import torch.nn as nn
import torch.utils.data
from functools import partial
if sys.version_info >= (3, 10):
from collections.abc import Iterator
else:
from collections import Iterator
from tensorboardX import SummaryWriter
from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import pointcept.utils.comm as comm
from pointcept.datasets import build_dataset, point_collate_fn, collate_fn
from pointcept.models import build_model
from pointcept.utils.logger import get_root_logger
from pointcept.utils.optimizer import build_optimizer
from pointcept.utils.scheduler import build_scheduler
from pointcept.utils.events import EventStorage
from pointcept.utils.registry import Registry
from sklearn.preprocessing import QuantileTransformer
from pointcept.utils.timer import Timer
TRAINERS = Registry("trainers")
from cuml.cluster.hdbscan import HDBSCAN
# from sklearn.cluster import HDBSCAN
import open3d as o3d
import matplotlib.colors as mcolors
import numpy as np
from collections import OrderedDict
import trimesh
import pointops
class TrainerBase:
def __init__(self) -> None:
self.hooks = []
self.epoch = 0
self.start_epoch = 0
self.max_epoch = 0
self.max_iter = 0
self.comm_info = dict()
self.data_iterator: Iterator = enumerate([])
self.storage: EventStorage
self.writer: SummaryWriter
self._iter_timer = Timer()
def register_hooks(self, hooks) -> None:
hooks = build_hooks(hooks)
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self.hooks.extend(hooks)
def train(self):
with EventStorage() as self.storage:
# => before train
self.before_train()
for self.epoch in range(self.start_epoch, self.max_epoch):
# => before epoch
self.before_epoch()
# => run_epoch
for (
self.comm_info["iter"],
self.comm_info["input_dict"],
) in self.data_iterator:
# => before_step
self.before_step()
# => run_step
self.run_step()
# => after_step
self.after_step()
# => after epoch
self.after_epoch()
# => after train
self.after_train()
def before_train(self):
for h in self.hooks:
h.before_train()
def before_epoch(self):
for h in self.hooks:
h.before_epoch()
def before_step(self):
for h in self.hooks:
h.before_step()
def run_step(self):
raise NotImplementedError
def after_step(self):
for h in self.hooks:
h.after_step()
def after_epoch(self):
for h in self.hooks:
h.after_epoch()
self.storage.reset_histories()
def after_train(self):
# Sync GPU before running train hooks
comm.synchronize()
for h in self.hooks:
h.after_train()
if comm.is_main_process():
self.writer.close()
@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
def __init__(self, cfg):
super(Trainer, self).__init__()
self.epoch = 0
self.start_epoch = 0
self.max_epoch = cfg.eval_epoch
self.best_metric_value = -torch.inf
self.logger = get_root_logger(
log_file=os.path.join(cfg.save_path, "train.log"),
# file_mode="a" if cfg.resume else "w",
file_mode="a",
)
self.logger.info("=> Loading config ...")
self.cfg = cfg
self.logger.info(f"Save path: {cfg.save_path}")
self.logger.info(f"Config:\n{cfg.pretty_text}")
self.logger.info("=> Building model ...")
self.model = self.build_model()
self.logger.info("=> Building val dataset & dataloader ...")
self.train_loader = self.build_train_loader()
self.logger.info("=> Building hooks ...")
self.register_hooks(self.cfg.hooks)
# !!!
self.val_scales_list = self.cfg.val_scales_list
self.mesh_voting = self.cfg.mesh_voting
self.backbone_weight_path = self.cfg.backbone_weight_path
def eval(self):
# val_data = build_dataset(self.cfg.data.val)
self.logger.info("=> Loading checkpoint & weight ...")
if self.backbone_weight_path != None:
self.logger.info("=> Loading checkpoint of pretrained backbone")
if os.path.isfile(self.backbone_weight_path):
checkpoint = torch.load(
self.backbone_weight_path,
map_location=lambda storage, loc: storage.cuda(),
)
weight = OrderedDict()
for key, value in checkpoint["state_dict"].items():
if not key.startswith("module."):
if comm.get_world_size() > 1:
key = "module." + key # xxx.xxx -> module.xxx.xxx
# Now all keys contain "module." no matter DDP or not.
# if self.keywords in key:
# key = key.replace(self.keywords, self.replacement)
if comm.get_world_size() == 1:
key = key[7:] # module.xxx.xxx -> xxx.xxx
# if key.startswith("backbone."):
# key = key[9:] # backbone.xxx.xxx -> xxx.xxx
key = "backbone." + key # xxx.xxx -> backbone.xxx.xxx
weight[key] = value
load_state_info = self.model.load_state_dict(weight, strict=False)
self.logger.info(f"Missing keys: {load_state_info[0]}")
else:
self.logger.info(f"No weight found at: {self.backbone_weight_path}")
if self.cfg.weight and os.path.isfile(self.cfg.weight):
checkpoint = torch.load(
self.cfg.weight,
map_location=lambda storage, loc: storage.cuda(),
)
load_state_info = self.model.load_state_dict(checkpoint["state_dict"], strict=False)
self.logger.info(f"Missing keys: {load_state_info[0]}")
scale_statistics = checkpoint["state_dict"]["scale_statistics"]
self.model.quantile_transformer = self._get_quantile_func(scale_statistics)
else:
self.logger.info(f"No weight found at: {self.cfg.weight}")
self.cfg.weight = "last"
self.model.eval()
save_root = os.path.join(self.cfg.save_path, "vis_pcd", os.path.splitext(os.path.basename(self.cfg.weight))[0])
os.makedirs(save_root, exist_ok=True)
group_save_root = os.path.join(self.cfg.save_path, "results", os.path.splitext(os.path.basename(self.cfg.weight))[0])
os.makedirs(group_save_root, exist_ok=True)
hex_colors = list(mcolors.CSS4_COLORS.values())
rgb_colors = np.array([mcolors.to_rgb(color) for color in hex_colors if color not in ['#000000', '#FFFFFF']])
def relative_luminance(color):
return 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]
rgb_colors = [color for color in rgb_colors if (relative_luminance(color) > 0.4 and relative_luminance(color) < 0.8)]
np.random.shuffle(rgb_colors)
input_dict = self.train_loader.val_data()
pcd_inverse = self.train_loader.pcd_inverse
if self.mesh_voting:
mesh = trimesh.load(self.train_loader.mesh_path)
if isinstance(mesh, trimesh.Scene):
mesh = mesh.dump(concatenate=True)
mesh.visual = trimesh.visual.ColorVisuals(mesh=mesh)
for scale in self.val_scales_list:
input_dict["scale"] = scale
instance_feat = self.model(input_dict).cpu().detach().numpy()
clusterer = HDBSCAN(
cluster_selection_epsilon=0.1,
min_samples=30,
min_cluster_size=30,
allow_single_cluster=False,
).fit(instance_feat)
labels = clusterer.labels_
invalid_label_mask = labels == -1
if invalid_label_mask.sum() > 0:
if invalid_label_mask.sum() == len(invalid_label_mask):
labels = np.zeros_like(labels)
else:
coord = input_dict["obj"]["coord"].cuda().contiguous().float()
valid_coord = coord[~invalid_label_mask]
valid_offset = torch.tensor(valid_coord.shape[0]).cuda()
invalid_coord = coord[invalid_label_mask]
invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda()
indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset)
indices = indices[:, 0].cpu().numpy()
labels[invalid_label_mask] = labels[~invalid_label_mask][indices]
# np.save(os.path.join(group_save_root, f"{str(scale)}.npy"), labels)
save_path = os.path.join(save_root, f"{str(scale)}.ply")
coord = input_dict["obj"]["coord"].cpu().numpy()
random_color = []
for i in range(max(labels) + 1):
random_color.append(rgb_colors[i % len(rgb_colors)])
random_color.append(np.array([0, 0, 0]))
color = [random_color[i] for i in labels]
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coord)
pcd.colors = o3d.utility.Vector3dVector(color)
o3d.io.write_point_cloud(save_path, pcd)
labels = labels[pcd_inverse]
# print(len(clusterer.labels_))
self.logger.info(f"scale_{scale} has {max(labels)+1} groups")
if self.mesh_voting:
face_index = self.train_loader.face_index
face_index = face_index[pcd_inverse]
# Compute votes for each face using NumPy's bincount function
# labels = clusterer.labels_
num_faces = len(mesh.faces)
num_labels = max(labels) + 1
votes = np.zeros((num_faces, num_labels), dtype=np.int32)
np.add.at(votes, (face_index, labels), 1)
# Find the label with most votes for each face using NumPy's argmax function
max_votes_labels = np.argmax(votes, axis=1)
# Set the label to -1 for faces that have no corresponding points
max_votes_labels[np.all(votes == 0, axis=1)] = -1
valid_mask = max_votes_labels != -1
face_centroids = mesh.triangles_center
coord = torch.tensor(face_centroids).cuda().contiguous().float()
valid_coord = coord[valid_mask]
valid_offset = torch.tensor(valid_coord.shape[0]).cuda()
invalid_coord = coord[~valid_mask]
invalid_offset = torch.tensor(invalid_coord.shape[0]).cuda()
indices, distances = pointops.knn_query(1, valid_coord, valid_offset, invalid_coord, invalid_offset)
# # the first column is the point itself
# indices = indices[:, 1].cpu().numpy()
indices = indices[:, 0].cpu().numpy()
mesh_group = max_votes_labels.copy()
mesh_group[~valid_mask] = mesh_group[valid_mask][indices]
np.save(os.path.join(group_save_root, f"mesh_{str(scale)}.npy"), mesh_group)
# Assign color to each face based on the label with most votes
for face, label in enumerate(mesh_group):
color = (random_color[label] * 255).astype(np.uint8)
color_with_alpha = np.append(color, 255) # Add alpha value
mesh.visual.face_colors[face] = color_with_alpha
# Save the new mesh
mesh_save_path = os.path.join(save_root, f"mesh_{str(scale)}.ply")
mesh.export(mesh_save_path)
def _get_quantile_func(self, scales: torch.Tensor, distribution="normal"):
"""
Use 3D scale statistics to normalize scales -- use quantile transformer.
"""
scales = scales.flatten()
max_grouping_scale = 2
scales = scales[(scales > 0) & (scales < max_grouping_scale)]
scales = scales.detach().cpu().numpy()
# Calculate quantile transformer
quantile_transformer = QuantileTransformer(output_distribution=distribution)
quantile_transformer = quantile_transformer.fit(scales.reshape(-1, 1))
def quantile_transformer_func(scales):
# This function acts as a wrapper for QuantileTransformer.
# QuantileTransformer expects a numpy array, while we have a torch tensor.
return torch.Tensor(
quantile_transformer.transform(scales.cpu().numpy())
).to(scales.device)
return quantile_transformer_func
def run_step(self):
input_dict = self.comm_info["input_dict"]
for key in input_dict.keys():
if isinstance(input_dict[key], torch.Tensor):
input_dict[key] = input_dict[key].cuda(non_blocking=True)
with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
output_dict = self.model(input_dict)
loss = output_dict["loss"]
self.optimizer.zero_grad()
if self.cfg.enable_amp:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
# When enable amp, optimizer.step call are skipped if the loss scaling factor is too large.
# Fix torch warning scheduler step before optimizer step.
scaler = self.scaler.get_scale()
self.scaler.update()
if scaler <= self.scaler.get_scale():
self.scheduler.step()
else:
loss.backward()
self.optimizer.step()
self.scheduler.step()
if self.cfg.empty_cache:
torch.cuda.empty_cache()
self.comm_info["model_output_dict"] = output_dict
def build_model(self):
model = build_model(self.cfg.model)
if self.cfg.sync_bn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# logger.info(f"Model: \n{self.model}")
self.logger.info(f"Num params: {n_parameters}")
model = create_ddp_model(
model.cuda(),
broadcast_buffers=False,
find_unused_parameters=self.cfg.find_unused_parameters,
)
return model
def build_writer(self):
writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
return writer
def build_train_loader(self):
self.cfg.data.train.split = "val"
self.cfg.data.train.oid = self.cfg.oid
self.cfg.data.train.label = self.cfg.label
train_data = build_dataset(self.cfg.data.train)
return train_data
def build_val_loader(self):
val_loader = None
if self.cfg.evaluate:
val_data = build_dataset(self.cfg.data.val)
if comm.get_world_size() > 1:
val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
else:
val_sampler = None
val_loader = torch.utils.data.DataLoader(
val_data,
batch_size=self.cfg.batch_size_val_per_gpu,
shuffle=False,
num_workers=self.cfg.num_worker_per_gpu,
pin_memory=True,
sampler=val_sampler,
collate_fn=collate_fn,
)
return val_loader
def build_optimizer(self):
return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)
def build_scheduler(self):
assert hasattr(self, "optimizer")
assert hasattr(self, "train_loader")
# self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
self.cfg.scheduler.total_steps = self.max_epoch
return build_scheduler(self.cfg.scheduler, self.optimizer)
def build_scaler(self):
scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
return scaler