Spaces:
Running
on
T4
Running
on
T4
| import numpy as np | |
| from alphafold.common import residue_constants | |
def make_atom14_positions(prot):
    """Adds dense atom14 features (14 atom slots instead of 37) to `prot`.

    Per-residue-type lookup tables between the atom37 and atom14 layouts are
    built from `residue_constants`, then indexed with the residue types of
    this protein to gather ground-truth masks and positions into atom14
    order. An "alternative" ground truth is also produced, in which the atom
    pairs listed in `residue_constants.residue_atom_renaming_swaps` are
    exchanged.

    Args:
      prot: Mapping providing "aatype", "all_atom_mask" and
        "all_atom_positions". It is modified in place.
        NOTE(review): the gathers along axis 1 suggest shapes (num_res,),
        (num_res, 37) and (num_res, 37, 3) -- confirm against the caller.

    Returns:
      The same `prot` object with these entries added:
      "atom14_atom_exists", "atom14_gt_exists", "atom14_gt_positions",
      "residx_atom14_to_atom37", "residx_atom37_to_atom14",
      "atom37_atom_exists", "atom14_alt_gt_positions",
      "atom14_alt_gt_exists" and "atom14_atom_is_ambiguous".
    """
    rc = residue_constants
    aatype = prot["aatype"]

    # --- Lookup tables indexed by residue type -----------------------------
    rows_14_to_37 = []  # (restype, atom14 slot) -> atom37 index
    rows_37_to_14 = []  # (restype, atom37 index) -> atom14 slot
    rows_14_mask = []   # 1.0 where the atom14 slot holds a real atom
    for letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[rc.restype_1to3[letter]]
        # Empty names mark unused atom14 slots; they are pointed at index 0
        # and switched off through the mask.
        rows_14_to_37.append(
            [rc.atom_order[atom] if atom else 0 for atom in names14])
        slot_of = {atom: slot for slot, atom in enumerate(names14)}
        rows_37_to_14.append([slot_of.get(atom, 0) for atom in rc.atom_types])
        rows_14_mask.append([1. if atom else 0. for atom in names14])

    # Extra all-zero row serving as a dummy mapping for restype 'UNK'.
    rows_14_to_37.append([0] * 14)
    rows_37_to_14.append([0] * 37)
    rows_14_mask.append([0.] * 14)

    restype_atom14_to_atom37 = np.array(rows_14_to_37, dtype=np.int32)
    restype_atom37_to_atom14 = np.array(rows_37_to_14, dtype=np.int32)
    restype_atom14_mask = np.array(rows_14_mask, dtype=np.float32)

    # --- Per-residue atom14 gather indices and masks ------------------------
    # Shape (num_res, 14): atom37 index of every atom14 slot of this protein.
    idx_14_to_37 = restype_atom14_to_atom37[aatype]
    exists_14 = restype_atom14_mask[aatype]

    # A slot has ground truth only if it exists for the residue type AND the
    # corresponding atom37 entry is flagged in the input mask.
    gathered_mask = np.take_along_axis(
        prot["all_atom_mask"], idx_14_to_37, axis=1)
    gt_exists_14 = exists_14 * gathered_mask.astype(np.float32)

    # Gather the positions and zero out the slots without ground truth.
    gathered_positions = np.take_along_axis(
        prot["all_atom_positions"], idx_14_to_37[..., None], axis=1)
    gt_positions_14 = gt_exists_14[:, :, None] * gathered_positions

    prot["atom14_atom_exists"] = exists_14
    prot["atom14_gt_exists"] = gt_exists_14
    prot["atom14_gt_positions"] = gt_positions_14
    prot["residx_atom14_to_atom37"] = idx_14_to_37

    # Gather indices for mapping atom14 back to atom37.
    prot["residx_atom37_to_atom14"] = restype_atom37_to_atom14[aatype]

    # --- atom37 existence mask ----------------------------------------------
    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
    for row, letter in enumerate(rc.restypes):
        for atom in rc.residue_atoms[rc.restype_1to3[letter]]:
            restype_atom37_mask[row, rc.atom_order[atom]] = 1
    prot["atom37_atom_exists"] = restype_atom37_mask[aatype]

    # --- Alternative ground truth with swapped atom names -------------------
    # For residue types with ambiguous atom naming, a second set of ground
    # truth coordinates is provided in which the ambiguous names are swapped.
    resnames3 = [rc.restype_1to3[letter] for letter in rc.restypes]
    resnames3 += ["UNK"]

    # Identity by default; residue types with swaps get a permutation matrix.
    swap_matrices = {name: np.eye(14, dtype=np.float32) for name in resnames3}
    for resname, swaps in rc.residue_atom_renaming_swaps.items():
        permutation = np.arange(14)
        for atom_a, atom_b in swaps.items():
            names14 = rc.restype_name_to_atom14_names[resname]
            slot_a = names14.index(atom_a)
            slot_b = names14.index(atom_b)
            permutation[slot_a] = slot_b
            permutation[slot_b] = slot_a
        matrix = np.zeros((14, 14), dtype=np.float32)
        # One entry per row: row i selects column permutation[i].
        matrix[np.arange(14), permutation] = 1.
        swap_matrices[resname] = matrix
    renaming_matrices = np.stack([swap_matrices[name] for name in resnames3])

    # Transformation matrices for this residue sequence,
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[aatype]

    # Permuted ground truth positions, shape (num_res, 14, 3).
    prot["atom14_alt_gt_positions"] = np.einsum(
        "rac,rab->rbc", gt_positions_14, renaming_transform)

    # Permuted mask. It differs from the ground truth mask when only one atom
    # of an ambiguous pair has a ground truth position.
    prot["atom14_alt_gt_exists"] = np.einsum(
        "ra,rab->rb", gt_exists_14, renaming_transform)

    # --- Ambiguous-atom mask -------------------------------------------------
    # Table of shape (21, 14) flagging both atoms of every swappable pair.
    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swaps in rc.residue_atom_renaming_swaps.items():
        for atom_a, atom_b in swaps.items():
            row = rc.restype_order[rc.restype_3to1[resname]]
            names14 = rc.restype_name_to_atom14_names[resname]
            slot_a = names14.index(atom_a)
            slot_b = names14.index(atom_b)
            restype_atom14_is_ambiguous[row, slot_a] = 1
            restype_atom14_is_ambiguous[row, slot_b] = 1

    # Select the rows matching the given sequence.
    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[aatype]

    return prot