Spaces:
Runtime error
Runtime error
| import torch | |
# Small positive constant added to row norms in get_CosineDistance_matrix so
# that all-zero feature rows do not cause a division by zero.
EPS = 1e-10
def get_CosineDistance_matrix(features):
    """Return the pairwise cosine-distance matrix between the rows of ``features``.

    Inputs with more than two dimensions are flattened so that each entry
    along dim 0 becomes one feature vector.

    Args:
        features: tensor whose dim 0 indexes samples. It must have at least
            two dimensions, because the norm is taken along ``dim=1``.

    Returns:
        An ``(N, N)`` tensor whose entry ``[i, j]`` is
        ``1 - <x_i, x_j> / ((|x_i| + EPS) * (|x_j| + EPS))``. Diagonal entries
        are close to 0 for non-zero rows. An all-zero row is at distance 1
        from every row, including itself.
    """
    flat = features if features.dim() <= 2 else features.reshape(features.shape[0], -1)
    # EPS keeps all-zero rows from dividing by zero; such rows become zero vectors.
    unit_rows = flat / (flat.norm(dim=1, keepdim=True) + EPS)
    similarity = unit_rows @ unit_rows.T
    # Callers want a distance, not a similarity.
    return 1.0 - similarity
def aggregatefrom_specimen_to_species(sorted_class_names_according_to_class_indx, specimen_distance_matrix, z_size, channels):
    """Average per-specimen entries into one entry per species (class name).

    Row ``k`` of ``specimen_distance_matrix`` belongs to the class name
    ``sorted_class_names_according_to_class_indx[k]``. All rows that share a
    name are averaged, and the averages are laid out in sorted order of the
    unique names.

    Args:
        sorted_class_names_according_to_class_indx: iterable of hashable,
            mutually orderable class names, one per row of
            ``specimen_distance_matrix``. It is traversed only once, so a
            one-shot iterator is accepted.
        specimen_distance_matrix: tensor indexed by specimen along dim 0.
            Each per-class mean (shape ``specimen_distance_matrix.shape[1:]``)
            is written into a ``(channels, z_size, z_size)`` slot, so it must
            match or broadcast to that shape. NOTE(review): despite the
            "distance matrix" name, the slot shape suggests per-specimen
            feature maps. An earlier version hard-coded the slot as
            (256, 16, 16). Confirm with callers.
        z_size: spatial size of each (square) output slot.
        channels: channel count of each output slot.

    Returns:
        Tensor of shape ``(num_unique_names, channels, z_size, z_size)``,
        allocated by ``torch.zeros`` with torch's default dtype and device.
        Row ``r`` holds the mean for the ``r``-th name in sorted order. Empty
        input yields a tensor with zero rows.
    """
    # Group row positions by name in a single pass (O(N)). The previous
    # version rescanned the full name list once per unique name (O(U * N)).
    rows_by_name = {}
    for position, name in enumerate(sorted_class_names_according_to_class_indx):
        rows_by_name.setdefault(name, []).append(position)

    species_names = sorted(rows_by_name)
    species_dist_matrix = torch.zeros(len(species_names), channels, z_size, z_size)
    for row, name in enumerate(species_names):
        # Mean over this species' specimens. The value is copied (and cast)
        # into the preallocated output slot.
        species_dist_matrix[row] = specimen_distance_matrix[rows_by_name[name]].mean(dim=0)
    return species_dist_matrix