Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import json | |
| import os | |
| import itertools | |
| import trimesh | |
| from matplotlib.path import Path | |
| from collections import Counter | |
| from sklearn.neighbors import KNeighborsClassifier | |
def load_segmentation(path, shape):
    """
    Load the per-item segmentation polygons from a segmentation json file.

    Only top-level entries whose key starts with ``'item'`` are read; every
    other entry of the json object is ignored.

    Arguments:
        path: path to the segmentation json file
        shape: unused; kept so existing callers keep working
    Returns:
        A list with one dictionary per item, each containing
            'type': the item's ``category_name``
            'type_id': the item's ``category_id``
            'coordinates': a list with one (N, 2) array of [x, y] rows per
                polygon of the item
    """
    with open(path, encoding="utf-8") as json_file:
        annotations = json.load(json_file)

    segmentations = []
    for key, val in annotations.items():
        if not key.startswith('item'):
            continue

        # Each item can have multiple polygons; keep them as separate arrays.
        coordinates = []
        for segmentation_coord in val['segmentation']:
            # The stored format is flat: [x1, y1, x2, y2, ...]
            x = segmentation_coord[::2]
            y = segmentation_coord[1::2]
            coordinates.append(np.vstack((x, y)).T)

        segmentations.append({
            'type': val['category_name'],
            'type_id': val['category_id'],
            'coordinates': coordinates,
        })

    return segmentations
def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Get the bodypart labels for the recon object by using the labels from the corresponding smpl object

    Each smpl vertex is labelled with the body part it is listed under in
    ``smpl_vert_segmentation.json`` (read from this module's directory), and
    every recon vertex receives the label predicted by a k-nearest-neighbour
    classifier fitted on the smpl vertices.

    Arguments:
        recon: trimesh object (fully clothed model)
        smpl: trimesh object (smpl model)
        k: number of nearest neighbours to use
    Returns:
        Returns a dictionary containing the bodypart and the corresponding indices
    """
    segmentation_path = os.path.join(
        os.path.dirname(__file__), 'smpl_vert_segmentation.json')
    # Use a context manager so the file handle is closed deterministically.
    with open(segmentation_path) as segmentation_file:
        smpl_vert_segmentation = json.load(segmentation_file)

    # One label slot per smpl vertex; vertices not listed in the json stay None.
    n = smpl.vertices.shape[0]
    y = np.array([None] * n)
    for key, val in smpl_vert_segmentation.items():
        y[val] = key

    # Honour the caller's `k` (the default of 1 matches the former hard-coded value).
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)
    y_pred = classifier.predict(recon.vertices)

    recon_labels = {}
    for key in smpl_vert_segmentation:
        recon_labels[key] = list(np.argwhere(
            y_pred == key).flatten().astype(int))

    return recon_labels
def _body_parts_to_exclude(type_id):
    """Return the body-part labels that are stripped from a garment of category `type_id`."""
    # Parts that are always removed, whatever the garment category.
    parts = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
             'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand']

    # Category ids follow https://github.com/switchablenorms/DeepFashion2
    if type_id in (1, 3, 10):
        # Short sleeve clothes
        parts += ['leftForeArm', 'rightForeArm']
    elif type_id in (5, 6, 8, 9, 12, 13):
        # No sleeves at all or lower body clothes
        parts += ['leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
    elif type_id == 7:
        # Shorts
        parts += ['leftLeg', 'rightLeg',
                  'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
    return parts


def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2d segmentation coordinates

    Every polygon of the segmentation is lifted with the pseudo-inverse of the
    projection matrix ``K[:3, :3] @ [R | t]`` and the mesh vertices whose x/y
    coordinates fall inside the lifted polygon are selected. When `smpl` is
    given, selected vertices labelled as body parts that are unlikely to belong
    to the garment (e.g. a hand in front of the torso) are dropped again. A face
    is kept as soon as at least one of its vertices is still selected.

    Arguments:
        recon: fully clothed mesh
        segmentation: dictionary with the polygons under 'coord_normalized'
            (2D coordinates, NDC according to the original documentation);
            'type_id' is additionally read when `smpl` is given
        K: intrinsic matrix of the projection
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: optional smpl mesh used to label the vertices of `recon`; when
            None, no body-part filtering is applied
    Returns:
        Returns a submesh using the segmentation coordinates, or None when no
        face is selected
    """
    seg_coord = segmentation['coord_normalized']
    mesh = trimesh.Trimesh(recon.vertices, recon.faces)
    faces = recon.faces
    num_verts = recon.vertices.shape[0]

    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P = K[:3, :3] @ extrinsic
    P_inv = np.linalg.pinv(P)

    # Each segmentation can contain multiple polygons.
    # They are checked separately and their hits are accumulated.
    keep_mask = np.zeros(num_verts, dtype=bool)
    for polygon in seg_coord:
        n = len(polygon)
        if n == 0:
            # An empty polygon encloses nothing (and cannot be stacked below).
            continue
        coords_h = np.hstack((polygon, np.ones((n, 1))))
        # Apply the inverse projection on homogeneous 2D coordinates to get
        # the corresponding 3D coordinates
        XYZ = P_inv @ coords_h[:, :, None]
        XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]

        keep_mask |= Path(XYZ[:, :2]).contains_points(recon.vertices[:, :2])

    if smpl is not None:
        recon_labels = smpl_to_recon_labels(recon, smpl)
        body_parts_to_remove = _body_parts_to_exclude(segmentation['type_id'])
        verts_to_remove = list(itertools.chain.from_iterable(
            recon_labels[part] for part in body_parts_to_remove))

        label_mask = np.zeros(num_verts, dtype=bool)
        label_mask[verts_to_remove] = True

        # Remove points that belong to other bodyparts: a selected vertex
        # that carries one of the labels above is deselected.
        keep_mask &= ~label_mask

    # Without an smpl mesh the polygon hits are used as they are.
    all_indices = np.flatnonzero(keep_mask)

    # Keep every face that references at least one selected vertex.
    mask = np.isin(faces, all_indices).any(axis=1)
    if not mask.any():
        return None

    mesh.update_faces(mask)
    mesh.remove_unreferenced_vertices()
    # mesh.rezero()
    return mesh