Spaces:
Runtime error
Runtime error
NICP for SMPL-X completion
Browse files- README.md +2 -2
- apps/avatarizer.py +109 -8
- apps/infer.py +11 -9
- lib/common/local_affine.py +136 -0
- lib/common/train_util.py +1 -1
- lib/dataset/mesh_util.py +23 -1
- lib/smplx/lbs.py +72 -0
README.md
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
|
| 5 |
<h1 align="center">ECON: Explicit Clothed humans Obtained from Normals</h1>
|
| 6 |
<p align="center">
|
| 7 |
-
<a href="
|
| 8 |
·
|
| 9 |
<a href="https://ps.is.tuebingen.mpg.de/person/jyang"><strong>Jinlong Yang</strong></a>
|
| 10 |
·
|
|
@@ -28,7 +28,7 @@
|
|
| 28 |
<img src='https://img.shields.io/badge/Paper-PDF (coming soon)-green?style=for-the-badge&logo=arXiv&logoColor=green' alt='Paper PDF'>
|
| 29 |
</a>
|
| 30 |
<a href='https://xiuyuliang.cn/econ/'>
|
| 31 |
-
<img src='https://img.shields.io/badge/ECON-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=
|
| 32 |
<a href="https://discord.gg/Vqa7KBGRyk"><img src="https://img.shields.io/discord/940240966844035082?color=7289DA&labelColor=4a64bd&logo=discord&logoColor=white&style=for-the-badge"></a>
|
| 33 |
<a href="https://youtu.be/j5hw4tsWpoY"><img alt="youtube views" title="Subscribe to my YouTube channel" src="https://img.shields.io/youtube/views/j5hw4tsWpoY?logo=youtube&labelColor=ce4630&style=for-the-badge"/></a>
|
| 34 |
</p>
|
|
|
|
| 4 |
|
| 5 |
<h1 align="center">ECON: Explicit Clothed humans Obtained from Normals</h1>
|
| 6 |
<p align="center">
|
| 7 |
+
<a href="http://xiuyuliang.cn/"><strong>Yuliang Xiu</strong></a>
|
| 8 |
·
|
| 9 |
<a href="https://ps.is.tuebingen.mpg.de/person/jyang"><strong>Jinlong Yang</strong></a>
|
| 10 |
·
|
|
|
|
| 28 |
<img src='https://img.shields.io/badge/Paper-PDF (coming soon)-green?style=for-the-badge&logo=arXiv&logoColor=green' alt='Paper PDF'>
|
| 29 |
</a>
|
| 30 |
<a href='https://xiuyuliang.cn/econ/'>
|
| 31 |
+
<img src='https://img.shields.io/badge/ECON-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=white' alt='Project Page'></a>
|
| 32 |
<a href="https://discord.gg/Vqa7KBGRyk"><img src="https://img.shields.io/discord/940240966844035082?color=7289DA&labelColor=4a64bd&logo=discord&logoColor=white&style=for-the-badge"></a>
|
| 33 |
<a href="https://youtu.be/j5hw4tsWpoY"><img alt="youtube views" title="Subscribe to my YouTube channel" src="https://img.shields.io/youtube/views/j5hw4tsWpoY?logo=youtube&labelColor=ce4630&style=for-the-badge"/></a>
|
| 34 |
</p>
|
apps/avatarizer.py
CHANGED
|
@@ -3,12 +3,27 @@ import trimesh
|
|
| 3 |
import torch
|
| 4 |
import os.path as osp
|
| 5 |
import lib.smplx as smplx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from lib.dataset.mesh_util import SMPLX
|
|
|
|
| 7 |
|
| 8 |
smplx_container = SMPLX()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
for key in smplx_param.keys():
|
| 14 |
smplx_param[key] = smplx_param[key].cpu().view(1, -1)
|
|
@@ -28,20 +43,106 @@ smpl_model = smplx.create(
|
|
| 28 |
smpl_out = smpl_model(
|
| 29 |
body_pose=smplx_param["body_pose"],
|
| 30 |
global_orient=smplx_param["global_orient"],
|
| 31 |
-
# transl=smplx_param["transl"],
|
| 32 |
betas=smplx_param["betas"],
|
| 33 |
expression=smplx_param["expression"],
|
| 34 |
jaw_pose=smplx_param["jaw_pose"],
|
| 35 |
left_hand_pose=smplx_param["left_hand_pose"],
|
| 36 |
right_hand_pose=smplx_param["right_hand_pose"],
|
| 37 |
return_verts=True,
|
|
|
|
| 38 |
return_joint_transformation=True,
|
| 39 |
return_vertex_transformation=True)
|
| 40 |
|
| 41 |
smpl_verts = smpl_out.vertices.detach()[0]
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
trimesh.Trimesh(
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
import os.path as osp
|
| 5 |
import lib.smplx as smplx
|
| 6 |
+
from pytorch3d.ops import SubdivideMeshes
|
| 7 |
+
from pytorch3d.structures import Meshes
|
| 8 |
+
|
| 9 |
+
from lib.smplx.lbs import general_lbs
|
| 10 |
+
from lib.dataset.mesh_util import keep_largest, poisson
|
| 11 |
+
from scipy.spatial import cKDTree
|
| 12 |
from lib.dataset.mesh_util import SMPLX
|
| 13 |
+
from lib.common.local_affine import register
|
| 14 |
|
| 15 |
smplx_container = SMPLX()
|
| 16 |
+
device = torch.device("cuda:0")
|
| 17 |
+
|
| 18 |
+
prefix = "./results/github/econ/obj/304e9c4798a8c3967de7c74c24ef2e38"
|
| 19 |
+
smpl_path = f"{prefix}_smpl_00.npy"
|
| 20 |
+
econ_path = f"{prefix}_0_full.obj"
|
| 21 |
|
| 22 |
+
smplx_param = np.load(smpl_path, allow_pickle=True).item()
|
| 23 |
+
econ_obj = trimesh.load(econ_path)
|
| 24 |
+
econ_obj.vertices *= np.array([1.0, -1.0, -1.0])
|
| 25 |
+
econ_obj.vertices /= smplx_param["scale"].cpu().numpy()
|
| 26 |
+
econ_obj.vertices -= smplx_param["transl"].cpu().numpy()
|
| 27 |
|
| 28 |
for key in smplx_param.keys():
|
| 29 |
smplx_param[key] = smplx_param[key].cpu().view(1, -1)
|
|
|
|
| 43 |
smpl_out = smpl_model(
|
| 44 |
body_pose=smplx_param["body_pose"],
|
| 45 |
global_orient=smplx_param["global_orient"],
|
|
|
|
| 46 |
betas=smplx_param["betas"],
|
| 47 |
expression=smplx_param["expression"],
|
| 48 |
jaw_pose=smplx_param["jaw_pose"],
|
| 49 |
left_hand_pose=smplx_param["left_hand_pose"],
|
| 50 |
right_hand_pose=smplx_param["right_hand_pose"],
|
| 51 |
return_verts=True,
|
| 52 |
+
return_full_pose=True,
|
| 53 |
return_joint_transformation=True,
|
| 54 |
return_vertex_transformation=True)
|
| 55 |
|
| 56 |
smpl_verts = smpl_out.vertices.detach()[0]
|
| 57 |
+
smpl_tree = cKDTree(smpl_verts.cpu().numpy())
|
| 58 |
+
dist, idx = smpl_tree.query(econ_obj.vertices, k=5)
|
| 59 |
+
|
| 60 |
+
if not osp.exists(f"{prefix}_econ_cano.obj") or not osp.exists(f"{prefix}_smpl_cano.obj"):
|
| 61 |
+
|
| 62 |
+
# canonicalize for ECON
|
| 63 |
+
econ_verts = torch.tensor(econ_obj.vertices).float()
|
| 64 |
+
inv_mat = torch.inverse(smpl_out.vertex_transformation.detach()[0][idx[:, 0]])
|
| 65 |
+
homo_coord = torch.ones_like(econ_verts)[..., :1]
|
| 66 |
+
econ_cano_verts = inv_mat @ torch.cat([econ_verts, homo_coord], dim=1).unsqueeze(-1)
|
| 67 |
+
econ_cano_verts = econ_cano_verts[:, :3, 0].cpu()
|
| 68 |
+
econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces)
|
| 69 |
+
|
| 70 |
+
# canonicalize for SMPL-X
|
| 71 |
+
inv_mat = torch.inverse(smpl_out.vertex_transformation.detach()[0])
|
| 72 |
+
homo_coord = torch.ones_like(smpl_verts)[..., :1]
|
| 73 |
+
smpl_cano_verts = inv_mat @ torch.cat([smpl_verts, homo_coord], dim=1).unsqueeze(-1)
|
| 74 |
+
smpl_cano_verts = smpl_cano_verts[:, :3, 0].cpu()
|
| 75 |
+
smpl_cano = trimesh.Trimesh(smpl_cano_verts, smpl_model.faces, maintain_orders=True, process=False)
|
| 76 |
+
smpl_cano.export(f"{prefix}_smpl_cano.obj")
|
| 77 |
+
|
| 78 |
+
# remove hands from ECON for next registeration
|
| 79 |
+
econ_cano_body = econ_cano.copy()
|
| 80 |
+
mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
|
| 81 |
+
econ_cano_body.update_faces(mano_mask[econ_cano.faces].all(axis=1))
|
| 82 |
+
econ_cano_body.remove_unreferenced_vertices()
|
| 83 |
+
econ_cano_body = keep_largest(econ_cano_body)
|
| 84 |
+
|
| 85 |
+
# remove SMPL-X hand and face
|
| 86 |
+
register_mask = ~np.isin(
|
| 87 |
+
np.arange(smpl_cano_verts.shape[0]),
|
| 88 |
+
np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid]))
|
| 89 |
+
register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy()
|
| 90 |
+
smpl_cano_body = smpl_cano.copy()
|
| 91 |
+
smpl_cano_body.update_faces(register_mask[smpl_cano.faces].all(axis=1))
|
| 92 |
+
smpl_cano_body.remove_unreferenced_vertices()
|
| 93 |
+
smpl_cano_body = keep_largest(smpl_cano_body)
|
| 94 |
+
|
| 95 |
+
# upsample the smpl_cano_body and do registeration
|
| 96 |
+
smpl_cano_body = Meshes(
|
| 97 |
+
verts=[torch.tensor(smpl_cano_body.vertices).float()],
|
| 98 |
+
faces=[torch.tensor(smpl_cano_body.faces).long()],
|
| 99 |
+
).to(device)
|
| 100 |
+
sm = SubdivideMeshes(smpl_cano_body)
|
| 101 |
+
smpl_cano_body = register(econ_cano_body, sm(smpl_cano_body), device)
|
| 102 |
+
|
| 103 |
+
# remove over-streched+hand faces from ECON
|
| 104 |
+
econ_cano_body = econ_cano.copy()
|
| 105 |
+
edge_before = np.sqrt(
|
| 106 |
+
((econ_obj.vertices[econ_cano.edges[:, 0]] - econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
|
| 107 |
+
edge_after = np.sqrt(
|
| 108 |
+
((econ_cano.vertices[econ_cano.edges[:, 0]] - econ_cano.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
|
| 109 |
+
edge_diff = edge_after / edge_before.clip(1e-2)
|
| 110 |
+
streched_mask = np.unique(econ_cano.edges[edge_diff > 6])
|
| 111 |
+
mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
|
| 112 |
+
mano_mask[streched_mask] = False
|
| 113 |
+
econ_cano_body.update_faces(mano_mask[econ_cano.faces].all(axis=1))
|
| 114 |
+
econ_cano_body.remove_unreferenced_vertices()
|
| 115 |
+
|
| 116 |
+
# stitch the registered SMPL-X body and floating hands to ECON
|
| 117 |
+
econ_cano_tree = cKDTree(econ_cano.vertices)
|
| 118 |
+
dist, idx = econ_cano_tree.query(smpl_cano_body.vertices, k=1)
|
| 119 |
+
smpl_cano_body.update_faces((dist > 0.02)[smpl_cano_body.faces].all(axis=1))
|
| 120 |
+
smpl_cano_body.remove_unreferenced_vertices()
|
| 121 |
+
|
| 122 |
+
smpl_hand = smpl_cano.copy()
|
| 123 |
+
smpl_hand.update_faces(smplx_container.mano_vertex_mask.numpy()[smpl_hand.faces].all(axis=1))
|
| 124 |
+
smpl_hand.remove_unreferenced_vertices()
|
| 125 |
+
econ_cano = sum([smpl_hand, smpl_cano_body, econ_cano_body])
|
| 126 |
+
econ_cano = poisson(econ_cano, f"{prefix}_econ_cano.obj")
|
| 127 |
+
else:
|
| 128 |
+
econ_cano = trimesh.load(f"{prefix}_econ_cano.obj")
|
| 129 |
+
smpl_cano = trimesh.load(f"{prefix}_smpl_cano.obj", maintain_orders=True, process=False)
|
| 130 |
+
|
| 131 |
+
smpl_tree = cKDTree(smpl_cano.vertices)
|
| 132 |
+
dist, idx = smpl_tree.query(econ_cano.vertices, k=2)
|
| 133 |
+
knn_weights = np.exp(-dist**2)
|
| 134 |
+
knn_weights /= knn_weights.sum(axis=1, keepdims=True)
|
| 135 |
+
econ_J_regressor = (smpl_model.J_regressor[:, idx] * knn_weights[None]).sum(axis=-1)
|
| 136 |
+
econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T
|
| 137 |
+
econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True)
|
| 138 |
+
econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
|
| 139 |
+
|
| 140 |
+
posed_econ_verts, _ = general_lbs(
|
| 141 |
+
pose=smpl_out.full_pose,
|
| 142 |
+
v_template=torch.tensor(econ_cano.vertices).unsqueeze(0),
|
| 143 |
+
J_regressor=econ_J_regressor,
|
| 144 |
+
parents=smpl_model.parents,
|
| 145 |
+
lbs_weights=econ_lbs_weights)
|
| 146 |
|
| 147 |
+
econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_cano.faces)
|
| 148 |
+
econ_pose.export(f"{prefix}_econ_pose.obj")
|
apps/infer.py
CHANGED
|
@@ -37,6 +37,7 @@ from lib.common.train_util import init_loss, load_normal_networks, load_networks
|
|
| 37 |
from lib.common.BNI import BNI
|
| 38 |
from lib.common.BNI_utils import save_normal_tensor
|
| 39 |
from lib.dataset.TestDataset import TestDataset
|
|
|
|
| 40 |
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
|
| 41 |
from lib.dataset.mesh_util import *
|
| 42 |
from lib.common.voxelize import VoxelGrid
|
|
@@ -156,8 +157,8 @@ if __name__ == "__main__":
|
|
| 156 |
|
| 157 |
N_body, N_pose = optimed_pose.shape[:2]
|
| 158 |
|
| 159 |
-
smpl_path =
|
| 160 |
-
|
| 161 |
if osp.exists(smpl_path):
|
| 162 |
|
| 163 |
smpl_verts_lst = []
|
|
@@ -182,6 +183,7 @@ if __name__ == "__main__":
|
|
| 182 |
|
| 183 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
| 184 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
|
|
|
| 185 |
else:
|
| 186 |
# smpl optimization
|
| 187 |
loop_smpl = tqdm(range(args.loop_smpl))
|
|
@@ -447,15 +449,15 @@ if __name__ == "__main__":
|
|
| 447 |
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
|
| 448 |
SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
| 449 |
)
|
| 450 |
-
|
| 451 |
-
#
|
| 452 |
-
|
| 453 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
| 454 |
faces=[torch.tensor(side_mesh.faces).long()],
|
| 455 |
-
)
|
| 456 |
-
sm = SubdivideMeshes(
|
| 457 |
-
|
| 458 |
-
|
| 459 |
|
| 460 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
| 461 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
|
|
|
| 37 |
from lib.common.BNI import BNI
|
| 38 |
from lib.common.BNI_utils import save_normal_tensor
|
| 39 |
from lib.dataset.TestDataset import TestDataset
|
| 40 |
+
from lib.common.local_affine import register
|
| 41 |
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
|
| 42 |
from lib.dataset.mesh_util import *
|
| 43 |
from lib.common.voxelize import VoxelGrid
|
|
|
|
| 157 |
|
| 158 |
N_body, N_pose = optimed_pose.shape[:2]
|
| 159 |
|
| 160 |
+
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
|
| 161 |
+
|
| 162 |
if osp.exists(smpl_path):
|
| 163 |
|
| 164 |
smpl_verts_lst = []
|
|
|
|
| 183 |
|
| 184 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
| 185 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
| 186 |
+
|
| 187 |
else:
|
| 188 |
# smpl optimization
|
| 189 |
loop_smpl = tqdm(range(args.loop_smpl))
|
|
|
|
| 449 |
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
|
| 450 |
SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
| 451 |
)
|
| 452 |
+
|
| 453 |
+
#register side_mesh to BNI surfaces
|
| 454 |
+
side_mesh = Meshes(
|
| 455 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
| 456 |
faces=[torch.tensor(side_mesh.faces).long()],
|
| 457 |
+
).to(device)
|
| 458 |
+
sm = SubdivideMeshes(side_mesh)
|
| 459 |
+
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
|
| 460 |
+
|
| 461 |
|
| 462 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
| 463 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
lib/common/local_affine.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
# This file is part of the pytorch-nicp,
|
| 4 |
+
# and is released under the "MIT License Agreement". Please see the LICENSE
|
| 5 |
+
# file that should have been included as part of this package.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import trimesh
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from pytorch3d.structures import Meshes
|
| 12 |
+
from pytorch3d.loss import chamfer_distance
|
| 13 |
+
from lib.dataset.mesh_util import update_mesh_shape_prior_losses
|
| 14 |
+
from lib.common.train_util import init_loss
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# reference: https://github.com/wuhaozhe/pytorch-nicp
|
| 18 |
+
class LocalAffine(nn.Module):
|
| 19 |
+
|
| 20 |
+
def __init__(self, num_points, batch_size=1, edges=None):
|
| 21 |
+
'''
|
| 22 |
+
specify the number of points, the number of points should be constant across the batch
|
| 23 |
+
and the edges torch.Longtensor() with shape N * 2
|
| 24 |
+
the local affine operator supports batch operation
|
| 25 |
+
batch size must be constant
|
| 26 |
+
add additional pooling on top of w matrix
|
| 27 |
+
'''
|
| 28 |
+
super(LocalAffine, self).__init__()
|
| 29 |
+
self.A = nn.Parameter(torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1))
|
| 30 |
+
self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(batch_size, num_points, 1, 1))
|
| 31 |
+
self.edges = edges
|
| 32 |
+
self.num_points = num_points
|
| 33 |
+
|
| 34 |
+
def stiffness(self):
|
| 35 |
+
'''
|
| 36 |
+
calculate the stiffness of local affine transformation
|
| 37 |
+
f norm get infinity gradient when w is zero matrix,
|
| 38 |
+
'''
|
| 39 |
+
if self.edges is None:
|
| 40 |
+
raise Exception("edges cannot be none when calculate stiff")
|
| 41 |
+
idx1 = self.edges[:, 0]
|
| 42 |
+
idx2 = self.edges[:, 1]
|
| 43 |
+
affine_weight = torch.cat((self.A, self.b), dim=3)
|
| 44 |
+
w1 = torch.index_select(affine_weight, dim=1, index=idx1)
|
| 45 |
+
w2 = torch.index_select(affine_weight, dim=1, index=idx2)
|
| 46 |
+
w_diff = (w1 - w2)**2
|
| 47 |
+
w_rigid = (torch.linalg.det(self.A) - 1.0)**2
|
| 48 |
+
return w_diff, w_rigid
|
| 49 |
+
|
| 50 |
+
def forward(self, x):
|
| 51 |
+
'''
|
| 52 |
+
x should have shape of B * N * 3
|
| 53 |
+
'''
|
| 54 |
+
x = x.unsqueeze(3)
|
| 55 |
+
out_x = torch.matmul(self.A, x)
|
| 56 |
+
out_x = out_x + self.b
|
| 57 |
+
stiffness, rigid = self.stiffness()
|
| 58 |
+
out_x.squeeze_(3)
|
| 59 |
+
return out_x, stiffness, rigid
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def trimesh2meshes(mesh):
|
| 63 |
+
'''
|
| 64 |
+
convert trimesh mesh to pytorch3d mesh
|
| 65 |
+
'''
|
| 66 |
+
verts = torch.from_numpy(mesh.vertices).float()
|
| 67 |
+
faces = torch.from_numpy(mesh.faces).long()
|
| 68 |
+
mesh = Meshes(verts.unsqueeze(0), faces.unsqueeze(0))
|
| 69 |
+
return mesh
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def register(target_mesh, src_mesh, device):
|
| 73 |
+
|
| 74 |
+
# define local_affine deform verts
|
| 75 |
+
tgt_mesh = trimesh2meshes(target_mesh).to(device)
|
| 76 |
+
src_verts = src_mesh.verts_padded().clone()
|
| 77 |
+
|
| 78 |
+
local_affine_model = LocalAffine(src_mesh.verts_padded().shape[1],
|
| 79 |
+
src_mesh.verts_padded().shape[0], src_mesh.edges_packed()).to(device)
|
| 80 |
+
|
| 81 |
+
optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}], lr=1e-2, amsgrad=True)
|
| 82 |
+
scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 83 |
+
optimizer_cloth,
|
| 84 |
+
mode="min",
|
| 85 |
+
factor=0.1,
|
| 86 |
+
verbose=0,
|
| 87 |
+
min_lr=1e-5,
|
| 88 |
+
patience=5,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
losses = init_loss()
|
| 92 |
+
|
| 93 |
+
loop_cloth = tqdm(range(200))
|
| 94 |
+
|
| 95 |
+
for i in loop_cloth:
|
| 96 |
+
|
| 97 |
+
optimizer_cloth.zero_grad()
|
| 98 |
+
|
| 99 |
+
deformed_verts, stiffness, rigid = local_affine_model(src_verts)
|
| 100 |
+
src_mesh = src_mesh.update_padded(deformed_verts)
|
| 101 |
+
|
| 102 |
+
# losses for laplacian, edge, normal consistency
|
| 103 |
+
update_mesh_shape_prior_losses(src_mesh, losses)
|
| 104 |
+
|
| 105 |
+
losses["cloth"]["value"] = chamfer_distance(
|
| 106 |
+
x=src_mesh.verts_padded(),
|
| 107 |
+
y=tgt_mesh.verts_padded())[0]
|
| 108 |
+
|
| 109 |
+
losses["stiffness"]["value"] = torch.mean(stiffness)
|
| 110 |
+
losses["rigid"]["value"] = torch.mean(rigid)
|
| 111 |
+
|
| 112 |
+
# Weighted sum of the losses
|
| 113 |
+
cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
|
| 114 |
+
pbar_desc = "Register SMPL-X towards ECON --- "
|
| 115 |
+
|
| 116 |
+
for k in losses.keys():
|
| 117 |
+
if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
|
| 118 |
+
cloth_loss = cloth_loss + \
|
| 119 |
+
losses[k]["value"] * losses[k]["weight"]
|
| 120 |
+
pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "
|
| 121 |
+
|
| 122 |
+
pbar_desc += f"Total: {cloth_loss:.5f}"
|
| 123 |
+
loop_cloth.set_description(pbar_desc)
|
| 124 |
+
|
| 125 |
+
# update params
|
| 126 |
+
cloth_loss.backward(retain_graph=True)
|
| 127 |
+
optimizer_cloth.step()
|
| 128 |
+
scheduler_cloth.step(cloth_loss)
|
| 129 |
+
|
| 130 |
+
final = trimesh.Trimesh(
|
| 131 |
+
src_mesh.verts_packed().detach().squeeze(0).cpu(),
|
| 132 |
+
src_mesh.faces_packed().detach().squeeze(0).cpu(),
|
| 133 |
+
process=False,
|
| 134 |
+
maintains_order=True)
|
| 135 |
+
|
| 136 |
+
return final
|
lib/common/train_util.py
CHANGED
|
@@ -32,7 +32,7 @@ def init_loss():
|
|
| 32 |
losses = {
|
| 33 |
# Cloth: Normal_recon - Normal_pred
|
| 34 |
"cloth": {
|
| 35 |
-
"weight":
|
| 36 |
"value": 0.0
|
| 37 |
},
|
| 38 |
# Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
|
|
|
|
| 32 |
losses = {
|
| 33 |
# Cloth: Normal_recon - Normal_pred
|
| 34 |
"cloth": {
|
| 35 |
+
"weight": 1e3,
|
| 36 |
"value": 0.0
|
| 37 |
},
|
| 38 |
# Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
|
lib/dataset/mesh_util.py
CHANGED
|
@@ -552,7 +552,7 @@ def poisson_remesh(obj_path):
|
|
| 552 |
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
|
| 553 |
# ms.apply_coord_laplacian_smoothing()
|
| 554 |
ms.save_current_mesh(obj_path)
|
| 555 |
-
ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
|
| 556 |
polished_mesh = trimesh.load_mesh(obj_path)
|
| 557 |
|
| 558 |
return polished_mesh
|
|
@@ -1013,6 +1013,15 @@ def clean_floats(mesh):
|
|
| 1013 |
return sum(clean_mesh_lst)
|
| 1014 |
|
| 1015 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1016 |
def mesh_move(mesh_lst, step, scale=1.0):
|
| 1017 |
|
| 1018 |
trans = np.array([1.0, 0.0, 0.0]) * step
|
|
@@ -1036,3 +1045,16 @@ def rescale_smpl(fitted_path, scale=100, translate=(0, 0, 0)):
|
|
| 1036 |
fitted_body.apply_transform(resize_matrix)
|
| 1037 |
|
| 1038 |
return np.array(fitted_body.vertices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
|
| 553 |
# ms.apply_coord_laplacian_smoothing()
|
| 554 |
ms.save_current_mesh(obj_path)
|
| 555 |
+
# ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
|
| 556 |
polished_mesh = trimesh.load_mesh(obj_path)
|
| 557 |
|
| 558 |
return polished_mesh
|
|
|
|
| 1013 |
return sum(clean_mesh_lst)
|
| 1014 |
|
| 1015 |
|
| 1016 |
+
def keep_largest(mesh):
|
| 1017 |
+
mesh_lst = mesh.split(only_watertight=False)
|
| 1018 |
+
keep_mesh = mesh_lst[0]
|
| 1019 |
+
for mesh in mesh_lst:
|
| 1020 |
+
if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]:
|
| 1021 |
+
keep_mesh = mesh
|
| 1022 |
+
return keep_mesh
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
def mesh_move(mesh_lst, step, scale=1.0):
|
| 1026 |
|
| 1027 |
trans = np.array([1.0, 0.0, 0.0]) * step
|
|
|
|
| 1045 |
fitted_body.apply_transform(resize_matrix)
|
| 1046 |
|
| 1047 |
return np.array(fitted_body.vertices)
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
def get_joint_mesh(joints, radius=2.0):
|
| 1051 |
+
|
| 1052 |
+
ball = trimesh.creation.icosphere(radius=radius)
|
| 1053 |
+
combined = None
|
| 1054 |
+
for joint in joints:
|
| 1055 |
+
ball_new = trimesh.Trimesh(vertices=ball.vertices + joint, faces=ball.faces, process=False)
|
| 1056 |
+
if combined is None:
|
| 1057 |
+
combined = ball_new
|
| 1058 |
+
else:
|
| 1059 |
+
combined = sum([combined, ball_new])
|
| 1060 |
+
return combined
|
lib/smplx/lbs.py
CHANGED
|
@@ -194,6 +194,7 @@ def lbs(
|
|
| 194 |
# 3. Add pose blend shapes
|
| 195 |
# N x J x 3 x 3
|
| 196 |
ident = torch.eye(3, dtype=dtype, device=device)
|
|
|
|
| 197 |
if pose2rot:
|
| 198 |
rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
|
| 199 |
|
|
@@ -229,6 +230,77 @@ def lbs(
|
|
| 229 |
return verts, J_transformed
|
| 230 |
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
|
| 233 |
"""Calculates the 3D joint locations from the vertices
|
| 234 |
|
|
|
|
| 194 |
# 3. Add pose blend shapes
|
| 195 |
# N x J x 3 x 3
|
| 196 |
ident = torch.eye(3, dtype=dtype, device=device)
|
| 197 |
+
|
| 198 |
if pose2rot:
|
| 199 |
rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
|
| 200 |
|
|
|
|
| 230 |
return verts, J_transformed
|
| 231 |
|
| 232 |
|
| 233 |
+
def general_lbs(
|
| 234 |
+
pose: Tensor,
|
| 235 |
+
v_template: Tensor,
|
| 236 |
+
J_regressor: Tensor,
|
| 237 |
+
parents: Tensor,
|
| 238 |
+
lbs_weights: Tensor,
|
| 239 |
+
pose2rot: bool = True,
|
| 240 |
+
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
| 241 |
+
"""Performs Linear Blend Skinning with the given shape and pose parameters
|
| 242 |
+
|
| 243 |
+
Parameters
|
| 244 |
+
----------
|
| 245 |
+
pose : torch.tensor Bx(J + 1) * 3
|
| 246 |
+
The pose parameters in axis-angle format
|
| 247 |
+
v_template torch.tensor BxVx3
|
| 248 |
+
The template mesh that will be deformed
|
| 249 |
+
J_regressor : torch.tensor JxV
|
| 250 |
+
The regressor array that is used to calculate the joints from
|
| 251 |
+
the position of the vertices
|
| 252 |
+
parents: torch.tensor J
|
| 253 |
+
The array that describes the kinematic tree for the model
|
| 254 |
+
lbs_weights: torch.tensor N x V x (J + 1)
|
| 255 |
+
The linear blend skinning weights that represent how much the
|
| 256 |
+
rotation matrix of each part affects each vertex
|
| 257 |
+
pose2rot: bool, optional
|
| 258 |
+
Flag on whether to convert the input pose tensor to rotation
|
| 259 |
+
matrices. The default value is True. If False, then the pose tensor
|
| 260 |
+
should already contain rotation matrices and have a size of
|
| 261 |
+
Bx(J + 1)x9
|
| 262 |
+
dtype: torch.dtype, optional
|
| 263 |
+
|
| 264 |
+
Returns
|
| 265 |
+
-------
|
| 266 |
+
verts: torch.tensor BxVx3
|
| 267 |
+
The vertices of the mesh after applying the shape and pose
|
| 268 |
+
displacements.
|
| 269 |
+
joints: torch.tensor BxJx3
|
| 270 |
+
The joints of the model
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
batch_size = pose.shape[0]
|
| 274 |
+
device, dtype = pose.device, pose.dtype
|
| 275 |
+
|
| 276 |
+
# Get the joints
|
| 277 |
+
# NxJx3 array
|
| 278 |
+
J = vertices2joints(J_regressor, v_template)
|
| 279 |
+
|
| 280 |
+
if pose2rot:
|
| 281 |
+
rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
|
| 282 |
+
else:
|
| 283 |
+
rot_mats = pose.view(batch_size, -1, 3, 3)
|
| 284 |
+
|
| 285 |
+
# 4. Get the global joint location
|
| 286 |
+
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
|
| 287 |
+
|
| 288 |
+
# 5. Do skinning:
|
| 289 |
+
# W is N x V x (J + 1)
|
| 290 |
+
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
|
| 291 |
+
# (N x V x (J + 1)) x (N x (J + 1) x 16)
|
| 292 |
+
num_joints = J_regressor.shape[0]
|
| 293 |
+
T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
|
| 294 |
+
|
| 295 |
+
homogen_coord = torch.ones([batch_size, v_template.shape[1], 1], dtype=dtype, device=device)
|
| 296 |
+
v_posed_homo = torch.cat([v_template, homogen_coord], dim=2)
|
| 297 |
+
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
|
| 298 |
+
|
| 299 |
+
verts = v_homo[:, :, :3, 0]
|
| 300 |
+
|
| 301 |
+
return verts, J
|
| 302 |
+
|
| 303 |
+
|
| 304 |
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
|
| 305 |
"""Calculates the 3D joint locations from the vertices
|
| 306 |
|