# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import torch
import torch.nn as nn
from typing import Tuple, Optional

from dockformer.model.primitives import Linear, LayerNorm
from dockformer.utils.tensor_utils import add


class StructureInputEmbedder(nn.Module):
    """
    Embeds a subset of the input features.

    Implements a merge of Algorithms 3 and 32.
    """
    def __init__(
        self,
        protein_tf_dim: int,
        ligand_tf_dim: int,
        additional_tf_dim: int,
        ligand_bond_dim: int,
        c_z: int,
        c_m: int,
        relpos_k: int,
        prot_min_bin: float,
        prot_max_bin: float,
        prot_no_bins: int,
        lig_min_bin: float,
        lig_max_bin: float,
        lig_no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
| """ | |
| Args: | |
| tf_dim: | |
| Final dimension of the target features | |
| c_z: | |
| Pair embedding dimension | |
| c_m: | |
| Single embedding dimension | |
| relpos_k: | |
| Window size used in relative positional encoding | |
| """ | |
        super(StructureInputEmbedder, self).__init__()

        self.tf_dim = protein_tf_dim + ligand_tf_dim + additional_tf_dim
        self.pair_tf_dim = ligand_bond_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(self.tf_dim, c_z)
        self.linear_tf_z_j = Linear(self.tf_dim, c_z)
        self.linear_tf_m = Linear(self.tf_dim, c_m)
        self.ligand_linear_bond_z = Linear(ligand_bond_dim, c_z)

        # RPE stuff
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

        # Recycling stuff
        self.prot_min_bin = prot_min_bin
        self.prot_max_bin = prot_max_bin
        self.prot_no_bins = prot_no_bins
        self.lig_min_bin = lig_min_bin
        self.lig_max_bin = lig_max_bin
        self.lig_no_bins = lig_no_bins
        self.inf = inf

        # The protein linear has one extra input channel for the "masked position" bin
        self.prot_recycling_linear = Linear(self.prot_no_bins + 1, self.c_z)
        self.lig_recycling_linear = Linear(self.lig_no_bins, self.c_z)

        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)
    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings

        Implements Algorithm 4.

        Args:
            ri:
                "residue_index" features of shape [*, N]
        """
        d = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(
            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
        )
        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
        d = d[..., None] - reshaped_bins
        d = torch.abs(d)
        d = torch.argmin(d, dim=-1)
        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
        d = d.to(ri.dtype)
        return self.linear_relpos(d)
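
    # A small worked example of the binning above (illustrative values only):
    # with relpos_k = 2 the boundaries are [-2, -1, 0, 1, 2], so for residue
    # indices ri = [0, 1, 4] the pairwise offsets are
    #     [[ 0, -1, -4],
    #      [ 1,  0, -3],
    #      [ 4,  3,  0]]
    # Each offset is snapped to its nearest boundary (offsets beyond +/- relpos_k
    # are clamped to the outermost bins), one-hot encoded over the 5 bins, and
    # projected to c_z by linear_relpos.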
    def _get_binned_distogram(self, x, min_bin, max_bin, no_bins, recycling_linear,
                              prot_distogram_mask=None):
        # This squared method might become problematic in FP16 mode.
        bins = torch.linspace(
            min_bin,
            max_bin,
            no_bins,
            dtype=x.dtype,
            device=x.device,
            requires_grad=False,
        )
        squared_bins = bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
        )

        # [*, N, N, no_bins]
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        if prot_distogram_mask is not None:
            # Append an extra bin for pairs whose input positions are masked out
            expanded_d = torch.cat(
                [d, torch.zeros(*d.shape[:-1], 1, device=d.device)], dim=-1
            )

            # Mark the positions where prot_distogram_mask is 1
            input_positions_mask = (prot_distogram_mask == 1).float()  # [N, crop_size]
            mask_i = input_positions_mask.unsqueeze(2)  # [N, crop_size, 1]
            mask_j = input_positions_mask.unsqueeze(1)  # [N, 1, crop_size]

            # A pair is masked if either of its positions is masked
            combined_mask = (mask_i + mask_j).clamp(max=1)  # [N, crop_size, crop_size]

            # Zero the distance bins for masked pairs...
            expanded_d[..., :-1] *= (1 - combined_mask).unsqueeze(-1)
            # ...and route those pairs to the extra last bin
            expanded_d[..., -1] += combined_mask

            d = expanded_d

        return recycling_linear(d)
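
    # Sketch of the binning performed above (illustrative summary): a pair of
    # positions whose squared distance D satisfies
    #     squared_bins[k] < D < squared_bins[k + 1]
    # is one-hot encoded into bin k (the last bin is open-ended via self.inf).
    # When prot_distogram_mask is provided, any pair involving a masked
    # position is instead routed to an extra trailing "unknown distance" bin
    # before the linear projection.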
    def forward(
        self,
        token_mask: torch.Tensor,
        protein_mask: torch.Tensor,
        ligand_mask: torch.Tensor,
        target_feat: torch.Tensor,
        ligand_bonds_feat: torch.Tensor,
        input_positions: torch.Tensor,
        protein_residue_index: torch.Tensor,
        protein_distogram_mask: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            token_mask:
                [*, N_res + N_lig_atoms] mask of valid tokens
            protein_mask:
                [*, N_res + N_lig_atoms] mask of protein tokens
            ligand_mask:
                [*, N_res + N_lig_atoms] mask of ligand atom tokens
            target_feat:
                [*, N_res + N_lig_atoms, tf_dim] target features
            ligand_bonds_feat:
                [*, N_lig_atoms, N_lig_atoms, ligand_bond_dim] ligand bond features
            input_positions:
                [*, N_res + N_lig_atoms, 3] input coordinates; the protein part
                holds the AF-predicted C_beta coordinates supplied as input
            protein_residue_index:
                [*, N_res] "residue_index" features
            protein_distogram_mask:
                Mask of protein positions whose input coordinates should be
                hidden from the distogram
        Returns:
            single_emb:
                [*, N_res + N_lig_atoms, C_m] single embedding
            pair_emb:
                [*, N_res + N_lig_atoms, N_res + N_lig_atoms, C_z] pair embedding
        """
        device = token_mask.device
        pair_protein_mask = protein_mask[..., None] * protein_mask[..., None, :]
        pair_ligand_mask = ligand_mask[..., None] * ligand_mask[..., None, :]

        # Single representation embedding - Algorithm 3
        tf_m = self.linear_tf_m(target_feat)
        tf_m = self.layer_norm_m(tf_m)  # previously this happened in the do_recycle function

        # Pair representation
        # protein pair embedding - Algorithm 3
        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(target_feat)
        tf_emb_j = self.linear_tf_z_j(target_feat)

        pair_emb = torch.zeros(*pair_protein_mask.shape, self.c_z, device=device)
        pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe)
        pair_emb = add(pair_emb, tf_emb_j[..., None, :, :], inplace=inplace_safe)

        # Apply relpos
        relpos = self.relpos(protein_residue_index.type(tf_emb_i.dtype))
        pair_emb += pair_protein_mask[..., None] * relpos
        del relpos

        # Apply ligand bonds
        ligand_bonds = self.ligand_linear_bond_z(ligand_bonds_feat)
        pair_emb += pair_ligand_mask[..., None] * ligand_bonds
        del ligand_bonds

        # Before the recycles, do z_norm; this was previously a part of the recycles
        pair_emb = self.layer_norm_z(pair_emb)

        # Apply protein recycle (distogram of the input protein structure)
        prot_distogram_embed = self._get_binned_distogram(
            input_positions,
            self.prot_min_bin,
            self.prot_max_bin,
            self.prot_no_bins,
            self.prot_recycling_linear,
            protein_distogram_mask,
        )
        pair_emb = add(
            pair_emb,
            prot_distogram_embed * pair_protein_mask.unsqueeze(-1),
            inplace_safe,
        )
        del prot_distogram_embed

        # Apply ligand recycle (distogram of the input ligand structure)
        lig_distogram_embed = self._get_binned_distogram(
            input_positions,
            self.lig_min_bin,
            self.lig_max_bin,
            self.lig_no_bins,
            self.lig_recycling_linear,
        )
        pair_emb = add(
            pair_emb,
            lig_distogram_embed * pair_ligand_mask.unsqueeze(-1),
            inplace_safe,
        )
        del lig_distogram_embed

        return tf_m, pair_emb
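
# Hypothetical usage sketch for StructureInputEmbedder (dimensions are
# illustrative and not taken from a real config):
#
#     embedder = StructureInputEmbedder(
#         protein_tf_dim=20, ligand_tf_dim=10, additional_tf_dim=2,
#         ligand_bond_dim=6, c_z=128, c_m=256, relpos_k=32,
#         prot_min_bin=3.25, prot_max_bin=20.75, prot_no_bins=15,
#         lig_min_bin=3.25, lig_max_bin=20.75, lig_no_bins=15,
#     )
#     single, pair = embedder(
#         token_mask, protein_mask, ligand_mask, target_feat,
#         ligand_bonds_feat, input_positions, protein_residue_index,
#         protein_distogram_mask,
#     )
#
# where single has shape [*, N_tokens, c_m] and pair has shape
# [*, N_tokens, N_tokens, c_z].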


class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.

    Implements Algorithm 32.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        min_bin: float,
        max_bin: float,
        no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
            c_m:
                Single channel dimension
            c_z:
                Pair embedding channel dimension
            min_bin:
                Smallest distogram bin (Angstroms)
            max_bin:
                Largest distogram bin (Angstroms)
            no_bins:
                Number of distogram bins
        """
        super(RecyclingEmbedder, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.no_bins = no_bins
        self.inf = inf

        self.linear = Linear(self.no_bins, self.c_z)
        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)
    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
        inplace_safe: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                First row of the single embedding. [*, N_res, C_m]
            z:
                [*, N_res, N_res, C_z] pair embedding
            x:
                [*, N_res, 3] predicted C_beta coordinates
        Returns:
            m:
                [*, N_res, C_m] single embedding update
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
        # [*, N, C_m]
        m_update = self.layer_norm_m(m)
        if inplace_safe:
            m.copy_(m_update)
            m_update = m

        # [*, N, N, C_z]
        z_update = self.layer_norm_z(z)
        if inplace_safe:
            z.copy_(z_update)
            z_update = z

        # This squared method might become problematic in FP16 mode.
        bins = torch.linspace(
            self.min_bin,
            self.max_bin,
            self.no_bins,
            dtype=x.dtype,
            device=x.device,
            requires_grad=False,
        )
        squared_bins = bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
        )

        # [*, N, N, no_bins]
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        # [*, N, N, C_z]
        d = self.linear(d)
        z_update = add(z_update, d, inplace_safe)

        return m_update, z_update
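
# A minimal shape sanity check for RecyclingEmbedder, using hypothetical
# dimensions; this is an illustrative sketch, not part of the model code.
if __name__ == "__main__":
    batch, n_res, c_m, c_z = 1, 8, 16, 8
    embedder = RecyclingEmbedder(
        c_m=c_m, c_z=c_z, min_bin=3.25, max_bin=20.75, no_bins=15
    )
    m = torch.randn(batch, n_res, c_m)
    z = torch.randn(batch, n_res, n_res, c_z)
    x = torch.randn(batch, n_res, 3)  # predicted C_beta coordinates
    m_update, z_update = embedder(m, z, x)
    assert m_update.shape == (batch, n_res, c_m)
    assert z_update.shape == (batch, n_res, n_res, c_z)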