# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from abc import ABC, abstractmethod
from functools import partial
from typing import Tuple, Sequence, Optional

import torch
import torch.nn as nn

from dockformer.model.primitives import Linear, LayerNorm
from dockformer.model.dropout import DropoutRowwise
from dockformer.model.single_attention import SingleRowAttentionWithPairBias
from dockformer.model.pair_transition import PairTransition
from dockformer.model.triangular_attention import (
    TriangleAttention,
)
from dockformer.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from dockformer.utils.checkpointing import checkpoint_blocks
from dockformer.utils.tensor_utils import add


class SingleRepTransition(nn.Module):
    """
    Feed-forward network applied to the single representation activations
    after attention.

    Implements Algorithm 9.
    """
    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                Single representation channel dimension
            n:
                Factor by which c_m is multiplied to obtain the hidden
                channel dimension
        """
        super(SingleRepTransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        m = self.layer_norm(m)
        m = self.linear_1(m)
        m = self.relu(m)
        m = self.linear_2(m) * mask
        return m

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_res, C_m] activation after attention
            mask:
                [*, N_res] mask
        Returns:
            m:
                [*, N_res, C_m] activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        mask = mask.unsqueeze(-1)

        m = self._transition(m, mask)

        return m
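
# A minimal usage sketch for SingleRepTransition, kept as a comment so the
# module is unaffected. The channel size and sequence length below are
# illustrative assumptions, not DockFormer's actual hyperparameters:
#
#     transition = SingleRepTransition(c_m=256, n=4)
#     m = torch.rand(1, 128, 256)          # [*, N_res, C_m]
#     mask = torch.ones(1, 128)            # [*, N_res]
#     m = m + transition(m, mask=mask)     # residual update, as in EvoformerBlock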


class PairStack(nn.Module):
    def __init__(
        self,
        c_z: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
    ):
        super(PairStack, self).__init__()

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z,
            c_hidden_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z,
            c_hidden_mul,
        )

        self.tri_att_start = TriangleAttention(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )
        self.tri_att_end = TriangleAttention(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )

        self.pair_transition = PairTransition(
            c_z,
            transition_n,
        )

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)

    def forward(self,
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        # DeepMind doesn't mask these transitions in the source, so
        # _mask_trans should be disabled to better approximate the exact
        # activations of the original.
        pair_trans_mask = pair_mask if _mask_trans else None

        tmu_update = self.tri_mul_out(
            z,
            mask=pair_mask,
            inplace_safe=inplace_safe,
            _add_with_inplace=True,
        )
        if not inplace_safe:
            z = z + self.ps_dropout_row_layer(tmu_update)
        else:
            z = tmu_update

        del tmu_update

        tmu_update = self.tri_mul_in(
            z,
            mask=pair_mask,
            inplace_safe=inplace_safe,
            _add_with_inplace=True,
        )
        if not inplace_safe:
            z = z + self.ps_dropout_row_layer(tmu_update)
        else:
            z = tmu_update

        del tmu_update

        z = add(z,
            self.ps_dropout_row_layer(
                self.tri_att_start(
                    z,
                    mask=pair_mask,
                    use_memory_efficient_kernel=False,
                    use_lma=use_lma,
                )
            ),
            inplace=inplace_safe,
        )

        z = z.transpose(-2, -3)
        if inplace_safe:
            z = z.contiguous()

        z = add(z,
            self.ps_dropout_row_layer(
                self.tri_att_end(
                    z,
                    mask=pair_mask.transpose(-1, -2),
                    use_memory_efficient_kernel=False,
                    use_lma=use_lma,
                )
            ),
            inplace=inplace_safe,
        )

        z = z.transpose(-2, -3)
        if inplace_safe:
            z = z.contiguous()

        z = add(z,
            self.pair_transition(
                z, mask=pair_trans_mask,
            ),
            inplace=inplace_safe,
        )

        return z
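
# A minimal usage sketch for PairStack, kept as a comment. The hidden sizes,
# head count, and dropout rate below are assumptions for illustration, not
# DockFormer's configured values:
#
#     pair_stack = PairStack(
#         c_z=128, c_hidden_mul=128, c_hidden_pair_att=32,
#         no_heads_pair=4, transition_n=4, pair_dropout=0.25,
#         inf=1e9, eps=1e-8,
#     )
#     z = torch.rand(1, 64, 64, 128)       # [*, N_res, N_res, C_z]
#     pair_mask = torch.ones(1, 64, 64)    # [*, N_res, N_res]
#     z = pair_stack(z, pair_mask=pair_mask)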


class EvoformerBlock(nn.Module, ABC):
    def __init__(self,
        c_m: int,
        c_z: int,
        c_hidden_single_att: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_single: int,
        no_heads_pair: int,
        transition_n: int,
        single_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
    ):
        super(EvoformerBlock, self).__init__()

        self.single_att_row = SingleRowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_single_att,
            no_heads=no_heads_single,
            inf=inf,
        )

        self.single_dropout_layer = DropoutRowwise(single_dropout)

        self.single_transition = SingleRepTransition(
            c_m=c_m,
            n=transition_n,
        )

        self.pair_stack = PairStack(
            c_z=c_z,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

    def forward(self,
        m: Optional[torch.Tensor],
        z: Optional[torch.Tensor],
        single_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        single_trans_mask = single_mask if _mask_trans else None

        input_tensors = [m, z]
        m, z = input_tensors

        z = self.pair_stack(
            z=z,
            pair_mask=pair_mask,
            use_lma=use_lma,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        m = add(m,
            self.single_dropout_layer(
                self.single_att_row(
                    m,
                    z=z,
                    mask=single_mask,
                    use_memory_efficient_kernel=False,
                    use_lma=use_lma,
                )
            ),
            inplace=inplace_safe,
        )

        m = add(m,
            self.single_transition(m, mask=single_trans_mask),
            inplace=inplace_safe,
        )

        return m, z
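
# A minimal usage sketch for EvoformerBlock, kept as a comment. Within a block
# the pair stack updates z first, then z biases the single-row attention, and
# the single transition applies the final residual update. The hyperparameters
# below are placeholders, not DockFormer's configured values:
#
#     block = EvoformerBlock(
#         c_m=256, c_z=128, c_hidden_single_att=32, c_hidden_mul=128,
#         c_hidden_pair_att=32, no_heads_single=8, no_heads_pair=4,
#         transition_n=4, single_dropout=0.15, pair_dropout=0.25,
#         inf=1e9, eps=1e-8,
#     )
#     m = torch.rand(1, 64, 256)           # [*, N_res, C_m]
#     z = torch.rand(1, 64, 64, 128)       # [*, N_res, N_res, C_z]
#     single_mask = torch.ones(1, 64)
#     pair_mask = torch.ones(1, 64, 64)
#     m, z = block(m, z, single_mask=single_mask, pair_mask=pair_mask)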


class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """
    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_single_att: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_single: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        single_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                Single representation channel dimension
            c_z:
                Pair channel dimension
            c_hidden_single_att:
                Hidden dimension in single representation attention
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_single:
                Number of heads used for single representation attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the
                SingleRepTransition hidden dimension
            single_dropout:
                Dropout rate for single representation activations
            pair_dropout:
                Dropout rate for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_single_att=c_hidden_single_att,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_single=no_heads_single,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                single_dropout=single_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def _prep_blocks(self,
        use_lma: bool,
        single_mask: Optional[torch.Tensor],
        pair_mask: Optional[torch.Tensor],
        inplace_safe: bool,
        _mask_trans: bool,
    ):
        blocks = [
            partial(
                b,
                single_mask=single_mask,
                pair_mask=pair_mask,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
                _mask_trans=_mask_trans,
            )
            for b in self.blocks
        ]

        if self.clear_cache_between_blocks:
            def block_with_cache_clear(block, *args, **kwargs):
                torch.cuda.empty_cache()
                return block(*args, **kwargs)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        return blocks
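
    # Sketch of how forward() below consumes the prepped blocks: each element
    # of `blocks` is a partial that only takes (m, z), so the whole stack can
    # be driven by checkpoint_blocks. forward() sets blocks_per_ckpt to None
    # when gradients are disabled, which skips activation checkpointing:
    #
    #     blocks = self._prep_blocks(...)
    #     m, z = checkpoint_blocks(
    #         blocks,
    #         args=(m, z),
    #         blocks_per_ckpt=self.blocks_per_ckpt,
    #     )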

    def forward(self,
        m: torch.Tensor,
        z: torch.Tensor,
        single_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        use_lma: bool = False,
        inplace_safe: bool = False,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                [*, N_res, C_m] single embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            single_mask:
                [*, N_res] single mask
            pair_mask:
                [*, N_res, N_res] pair mask
            use_lma:
                Whether to use low-memory attention during inference.
        Returns:
            m:
                [*, N_res, C_m] single embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding after linear layer
        """
        blocks = self._prep_blocks(
            use_lma=use_lma,
            single_mask=single_mask,
            pair_mask=pair_mask,
            inplace_safe=inplace_safe,
            _mask_trans=_mask_trans,
        )

        blocks_per_ckpt = self.blocks_per_ckpt
        if not torch.is_grad_enabled():
            blocks_per_ckpt = None

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=blocks_per_ckpt,
        )

        s = self.linear(m)

        return m, z, s
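
# A minimal end-to-end usage sketch for EvoformerStack, kept as a comment.
# Every hyperparameter below is an assumed placeholder for illustration, not
# DockFormer's published configuration:
#
#     stack = EvoformerStack(
#         c_m=256, c_z=128, c_hidden_single_att=32, c_hidden_mul=128,
#         c_hidden_pair_att=32, c_s=384, no_heads_single=8, no_heads_pair=4,
#         no_blocks=8, transition_n=4, single_dropout=0.15, pair_dropout=0.25,
#         blocks_per_ckpt=1, inf=1e9, eps=1e-8,
#     )
#     n_res = 64
#     m = torch.rand(1, n_res, 256)              # [*, N_res, C_m]
#     z = torch.rand(1, n_res, n_res, 128)       # [*, N_res, N_res, C_z]
#     single_mask = torch.ones(1, n_res)
#     pair_mask = torch.ones(1, n_res, n_res)
#     m, z, s = stack(m, z, single_mask=single_mask, pair_mask=pair_mask)
#     # s: [*, N_res, C_s] single embedding passed to downstream heads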