|  |  | 
					
						
						|  |  | 
					
						
						|  | import json | 
					
						
						|  | import logging | 
					
						
						|  | from typing import List, Optional | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  |  | 
					
						
						|  | from detectron2.utils.file_io import PathManager | 
					
						
						|  |  | 
					
						
						|  | from densepose.structures.mesh import create_mesh | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MeshAlignmentEvaluator: | 
					
						
						|  | """ | 
					
						
						|  | Class for evaluation of 3D mesh alignment based on the learned vertex embeddings | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]]): | 
					
						
						|  | self.embedder = embedder | 
					
						
						|  |  | 
					
						
						|  | self.mesh_names = mesh_names if mesh_names else embedder.mesh_names | 
					
						
						|  | self.logger = logging.getLogger(__name__) | 
					
						
						|  | with PathManager.open( | 
					
						
						|  | "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r" | 
					
						
						|  | ) as f: | 
					
						
						|  | self.mesh_keyvertices = json.load(f) | 
					
						
						|  |  | 
					
						
						|  | def evaluate(self): | 
					
						
						|  | ge_per_mesh = {} | 
					
						
						|  | gps_per_mesh = {} | 
					
						
						|  | for mesh_name_1 in self.mesh_names: | 
					
						
						|  | avg_errors = [] | 
					
						
						|  | avg_gps = [] | 
					
						
						|  | embeddings_1 = self.embedder(mesh_name_1) | 
					
						
						|  | keyvertices_1 = self.mesh_keyvertices[mesh_name_1] | 
					
						
						|  | keyvertex_names_1 = list(keyvertices_1.keys()) | 
					
						
						|  | keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1] | 
					
						
						|  | for mesh_name_2 in self.mesh_names: | 
					
						
						|  | if mesh_name_1 == mesh_name_2: | 
					
						
						|  | continue | 
					
						
						|  | embeddings_2 = self.embedder(mesh_name_2) | 
					
						
						|  | keyvertices_2 = self.mesh_keyvertices[mesh_name_2] | 
					
						
						|  | sim_matrix_12 = embeddings_1[keyvertex_indices_1].mm(embeddings_2.T) | 
					
						
						|  | vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(axis=1) | 
					
						
						|  | mesh_2 = create_mesh(mesh_name_2, embeddings_2.device) | 
					
						
						|  | geodists = mesh_2.geodists[ | 
					
						
						|  | vertices_2_matching_keyvertices_1, | 
					
						
						|  | [keyvertices_2[name] for name in keyvertex_names_1], | 
					
						
						|  | ] | 
					
						
						|  | Current_Mean_Distances = 0.255 | 
					
						
						|  | gps = (-(geodists**2) / (2 * (Current_Mean_Distances**2))).exp() | 
					
						
						|  | avg_errors.append(geodists.mean().item()) | 
					
						
						|  | avg_gps.append(gps.mean().item()) | 
					
						
						|  |  | 
					
						
						|  | ge_mean = torch.as_tensor(avg_errors).mean().item() | 
					
						
						|  | gps_mean = torch.as_tensor(avg_gps).mean().item() | 
					
						
						|  | ge_per_mesh[mesh_name_1] = ge_mean | 
					
						
						|  | gps_per_mesh[mesh_name_1] = gps_mean | 
					
						
						|  | ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item() | 
					
						
						|  | gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item() | 
					
						
						|  | per_mesh_metrics = { | 
					
						
						|  | "GE": ge_per_mesh, | 
					
						
						|  | "GPS": gps_per_mesh, | 
					
						
						|  | } | 
					
						
						|  | return ge_mean_global, gps_mean_global, per_mesh_metrics | 
					
						
						|  |  |