| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Modules and code used in the core part of AlphaFold. | |
| The structure generation code is in 'folding.py'. | |
| """ | |
| import functools | |
| from alphafold.common import residue_constants | |
| from alphafold.model import all_atom | |
| from alphafold.model import common_modules | |
| from alphafold.model import folding | |
| from alphafold.model import layer_stack | |
| from alphafold.model import lddt | |
| from alphafold.model import mapping | |
| from alphafold.model import prng | |
| from alphafold.model import quat_affine | |
| from alphafold.model import utils | |
| import haiku as hk | |
| import jax | |
| import jax.numpy as jnp | |
| from alphafold.model.r3 import Rigids, Rots, Vecs | |
| def softmax_cross_entropy(logits, labels): | |
| """Computes softmax cross entropy given logits and one-hot class labels.""" | |
| loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1) | |
| return jnp.asarray(loss) | |
| def sigmoid_cross_entropy(logits, labels): | |
| """Computes sigmoid cross entropy given logits and multiple class labels.""" | |
| log_p = jax.nn.log_sigmoid(logits) | |
| # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable | |
| log_not_p = jax.nn.log_sigmoid(-logits) | |
| loss = -labels * log_p - (1. - labels) * log_not_p | |
| return jnp.asarray(loss) | |
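# Illustrative usage (not part of the original module): a minimal sketch of the
# two cross-entropy helpers above. softmax_cross_entropy expects one-hot labels
# over the class axis and returns one value per example; sigmoid_cross_entropy
# expects independent per-class {0, 1} targets and returns one value per class.
def _demo_cross_entropy():
  logits = jnp.array([[2.0, -1.0, 0.5]])                 # [batch=1, classes=3]
  one_hot_labels = jax.nn.one_hot(jnp.array([0]), 3)     # [1, 3]
  multi_labels = jnp.array([[1.0, 0.0, 1.0]])            # [1, 3]
  ce = softmax_cross_entropy(logits=logits, labels=one_hot_labels)  # shape [1]
  bce = sigmoid_cross_entropy(logits=logits, labels=multi_labels)   # shape [1, 3]
  return ce, bce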
| def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None): | |
| """Applies dropout to a tensor.""" | |
| if is_training: # and rate != 0.0: | |
| shape = list(tensor.shape) | |
| if broadcast_dim is not None: | |
| shape[broadcast_dim] = 1 | |
| keep_rate = 1.0 - rate | |
| keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape) | |
| return keep * tensor / keep_rate | |
| else: | |
| return tensor | |
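# Illustrative usage (not part of the original module): apply_dropout with a
# shared ("broadcast") mask. With broadcast_dim=0 the same mask is reused along
# the first axis, matching the per-row shared dropout used by dropout_wrapper;
# surviving entries are rescaled by 1 / keep_rate, and is_training=False is a
# no-op.
def _demo_apply_dropout():
  x = jnp.ones((4, 8, 16))
  dropped = apply_dropout(tensor=x,
                          safe_key=prng.SafeKey(jax.random.PRNGKey(0)),
                          rate=0.25, is_training=True, broadcast_dim=0)
  passthrough = apply_dropout(tensor=x,
                              safe_key=prng.SafeKey(jax.random.PRNGKey(1)),
                              rate=0.25, is_training=False)  # returned unchanged
  return dropped, passthrough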
| def dropout_wrapper(module, | |
| input_act, | |
| mask, | |
| safe_key, | |
| global_config, | |
| output_act=None, | |
| is_training=True, | |
| scale_rate=1.0, | |
| **kwargs): | |
| """Applies module + dropout + residual update.""" | |
| if output_act is None: | |
| output_act = input_act | |
| gc = global_config | |
| residual = module(input_act, mask, is_training=is_training, **kwargs) | |
| dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate | |
| if module.config.shared_dropout: | |
| if module.config.orientation == 'per_row': | |
| broadcast_dim = 0 | |
| else: | |
| broadcast_dim = 1 | |
| else: | |
| broadcast_dim = None | |
| residual = apply_dropout(tensor=residual, | |
| safe_key=safe_key, | |
| rate=dropout_rate * scale_rate, | |
| is_training=is_training, | |
| broadcast_dim=broadcast_dim) | |
| new_act = output_act + residual | |
| return new_act | |
| def create_extra_msa_feature(batch): | |
| """Expand extra_msa into 1hot and concat with other extra msa features. | |
| We do this as late as possible as the one_hot extra msa can be very large. | |
| Arguments: | |
| batch: a dictionary with the following keys: | |
| * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster | |
| centre. Note that this is not one-hot encoded. | |
| * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to | |
| the left of each position in the extra MSA. | |
| * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to | |
| the left of each position in the extra MSA. | |
| Returns: | |
| Concatenated tensor of extra MSA features. | |
| """ | |
| # 23 = 20 amino acids + 'X' for unknown + gap + bert mask | |
| msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23) | |
| msa_feat = [msa_1hot, | |
| jnp.expand_dims(batch['extra_has_deletion'], axis=-1), | |
| jnp.expand_dims(batch['extra_deletion_value'], axis=-1)] | |
| return jnp.concatenate(msa_feat, axis=-1) | |
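# Illustrative usage (not part of the original module): a toy batch for
# create_extra_msa_feature. With 23 residue classes plus the two deletion
# features, each position ends up with 25 channels.
def _demo_create_extra_msa_feature():
  n_extra_seq, n_res = 5, 7
  batch = {
      'extra_msa': jnp.zeros((n_extra_seq, n_res), dtype=jnp.int32),
      'extra_has_deletion': jnp.zeros((n_extra_seq, n_res)),
      'extra_deletion_value': jnp.zeros((n_extra_seq, n_res)),
  }
  feat = create_extra_msa_feature(batch)   # [n_extra_seq, n_res, 25]
  return feat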
| class AlphaFoldIteration(hk.Module): | |
| """A single recycling iteration of AlphaFold architecture. | |
| Computes ensembled (averaged) representations from the provided features. | |
| These representations are then passed to the various heads | |
| that have been requested by the configuration file. Each head also returns a | |
| loss which is combined as a weighted sum to produce the total loss. | |
| Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 | |
| """ | |
| def __init__(self, config, global_config, name='alphafold_iteration'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, | |
| ensembled_batch, | |
| non_ensembled_batch, | |
| is_training, | |
| compute_loss=False, | |
| ensemble_representations=False, | |
| return_representations=False): | |
| num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0]) | |
| if not ensemble_representations: | |
| assert ensembled_batch['seq_length'].shape[0] == 1 | |
| def slice_batch(i): | |
| b = {k: v[i] for k, v in ensembled_batch.items()} | |
| b.update(non_ensembled_batch) | |
| return b | |
| # Compute representations for each batch element and average. | |
| evoformer_module = EmbeddingsAndEvoformer( | |
| self.config.embeddings_and_evoformer, self.global_config) | |
| batch0 = slice_batch(0) | |
| representations = evoformer_module(batch0, is_training) | |
| # MSA representations are not ensembled so | |
| # we don't pass the tensor into the loop. | |
| msa_representation = representations['msa'] | |
| del representations['msa'] | |
| # Average the representations (except MSA) over the batch dimension. | |
| if ensemble_representations: | |
| def body(x): | |
| """Add one element to the representations ensemble.""" | |
| i, current_representations = x | |
| feats = slice_batch(i) | |
| representations_update = evoformer_module( | |
| feats, is_training) | |
| new_representations = {} | |
| for k in current_representations: | |
| new_representations[k] = ( | |
| current_representations[k] + representations_update[k]) | |
| return i+1, new_representations | |
| if hk.running_init(): | |
| # When initializing the Haiku module, run one iteration of the | |
| # while_loop to initialize the Haiku modules used in `body`. | |
| _, representations = body((1, representations)) | |
| else: | |
| _, representations = hk.while_loop( | |
| lambda x: x[0] < num_ensemble, | |
| body, | |
| (1, representations)) | |
| for k in representations: | |
| if k != 'msa': | |
| representations[k] /= num_ensemble.astype(representations[k].dtype) | |
| representations['msa'] = msa_representation | |
| batch = batch0 # We are not ensembled from here on. | |
| if jnp.issubdtype(ensembled_batch['aatype'].dtype, jnp.integer): | |
| _, num_residues = ensembled_batch['aatype'].shape | |
| else: | |
| _, num_residues, _ = ensembled_batch['aatype'].shape | |
| if self.config.use_struct: | |
| struct_module = folding.StructureModule | |
| else: | |
| struct_module = folding.dummy | |
| heads = {} | |
| for head_name, head_config in sorted(self.config.heads.items()): | |
| if not head_config.weight: | |
| continue # Do not instantiate zero-weight heads. | |
| head_factory = { | |
| 'masked_msa': MaskedMsaHead, | |
| 'distogram': DistogramHead, | |
| 'structure_module': functools.partial(struct_module, compute_loss=compute_loss), | |
| 'predicted_lddt': PredictedLDDTHead, | |
| 'predicted_aligned_error': PredictedAlignedErrorHead, | |
| 'experimentally_resolved': ExperimentallyResolvedHead, | |
| }[head_name] | |
| heads[head_name] = (head_config, | |
| head_factory(head_config, self.global_config)) | |
| total_loss = 0. | |
| ret = {} | |
| ret['representations'] = representations | |
| def loss(module, head_config, ret, name, filter_ret=True): | |
| if filter_ret: | |
| value = ret[name] | |
| else: | |
| value = ret | |
| loss_output = module.loss(value, batch) | |
| ret[name].update(loss_output) | |
| loss = head_config.weight * ret[name]['loss'] | |
| return loss | |
| for name, (head_config, module) in heads.items(): | |
| # Skip PredictedLDDTHead and PredictedAlignedErrorHead until | |
| # StructureModule is executed. | |
| if name in ('predicted_lddt', 'predicted_aligned_error'): | |
| continue | |
| else: | |
| ret[name] = module(representations, batch, is_training) | |
| if 'representations' in ret[name]: | |
| # Extra representations from the head. Used by the structure module | |
| # to provide activations for the PredictedLDDTHead. | |
| representations.update(ret[name].pop('representations')) | |
| if compute_loss: | |
| total_loss += loss(module, head_config, ret, name) | |
| if self.config.use_struct: | |
| if self.config.heads.get('predicted_lddt.weight', 0.0): | |
| # Add PredictedLDDTHead after StructureModule executes. | |
| name = 'predicted_lddt' | |
| # Feed all previous results to give access to structure_module result. | |
| head_config, module = heads[name] | |
| ret[name] = module(representations, batch, is_training) | |
| if compute_loss: | |
| total_loss += loss(module, head_config, ret, name, filter_ret=False) | |
| if ('predicted_aligned_error' in self.config.heads | |
| and self.config.heads.get('predicted_aligned_error.weight', 0.0)): | |
| # Add PredictedAlignedErrorHead after StructureModule executes. | |
| name = 'predicted_aligned_error' | |
| # Feed all previous results to give access to structure_module result. | |
| head_config, module = heads[name] | |
| ret[name] = module(representations, batch, is_training) | |
| if compute_loss: | |
| total_loss += loss(module, head_config, ret, name, filter_ret=False) | |
| if compute_loss: | |
| return ret, total_loss | |
| else: | |
| return ret | |
| class AlphaFold(hk.Module): | |
| """AlphaFold model with recycling. | |
| Jumper et al. (2021) Suppl. Alg. 2 "Inference" | |
| """ | |
| def __init__(self, config, name='alphafold'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = config.global_config | |
| def __call__( | |
| self, | |
| batch, | |
| is_training, | |
| compute_loss=False, | |
| ensemble_representations=False, | |
| return_representations=False): | |
| """Run the AlphaFold model. | |
| Arguments: | |
| batch: Dictionary with inputs to the AlphaFold model. | |
| is_training: Whether the system is in training or inference mode. | |
| compute_loss: Whether to compute losses (requires extra features | |
| to be present in the batch and knowing the true structure). | |
| ensemble_representations: Whether to use ensembling of representations. | |
| return_representations: Whether to also return the intermediate | |
| representations. | |
| Returns: | |
| When compute_loss is True: | |
| a tuple of loss and output of AlphaFoldIteration. | |
| When compute_loss is False: | |
| just output of AlphaFoldIteration. | |
| The output of AlphaFoldIteration is a nested dictionary containing | |
| predictions from the various heads. | |
| """ | |
| if "scale_rate" not in batch: | |
| batch["scale_rate"] = jnp.ones((1,)) | |
| impl = AlphaFoldIteration(self.config, self.global_config) | |
| if jnp.issubdtype(batch['aatype'].dtype, jnp.integer): | |
| batch_size, num_residues = batch['aatype'].shape | |
| else: | |
| batch_size, num_residues, _ = batch['aatype'].shape | |
| def get_prev(ret): | |
| new_prev = { | |
| 'prev_msa_first_row': ret['representations']['msa_first_row'], | |
| 'prev_pair': ret['representations']['pair'], | |
| 'prev_dgram': ret["distogram"]["logits"], | |
| } | |
| if self.config.use_struct: | |
| new_prev.update({'prev_pos': ret['structure_module']['final_atom_positions'], | |
| 'prev_plddt': ret["predicted_lddt"]["logits"]}) | |
| if "predicted_aligned_error" in ret: | |
| new_prev["prev_pae"] = ret["predicted_aligned_error"]["logits"] | |
| if not self.config.backprop_recycle: | |
| for k in ["prev_pos","prev_msa_first_row","prev_pair"]: | |
| if k in new_prev: | |
| new_prev[k] = jax.lax.stop_gradient(new_prev[k]) | |
| return new_prev | |
| def do_call(prev, | |
| recycle_idx, | |
| compute_loss=compute_loss): | |
| if self.config.resample_msa_in_recycling: | |
| num_ensemble = batch_size // (self.config.num_recycle + 1) | |
| def slice_recycle_idx(x): | |
| start = recycle_idx * num_ensemble | |
| size = num_ensemble | |
| return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0) | |
| ensembled_batch = jax.tree_map(slice_recycle_idx, batch) | |
| else: | |
| num_ensemble = batch_size | |
| ensembled_batch = batch | |
| non_ensembled_batch = jax.tree_map(lambda x: x, prev) | |
| return impl(ensembled_batch=ensembled_batch, | |
| non_ensembled_batch=non_ensembled_batch, | |
| is_training=is_training, | |
| compute_loss=compute_loss, | |
| ensemble_representations=ensemble_representations) | |
| emb_config = self.config.embeddings_and_evoformer | |
| prev = { | |
| 'prev_msa_first_row': jnp.zeros([num_residues, emb_config.msa_channel]), | |
| 'prev_pair': jnp.zeros([num_residues, num_residues, emb_config.pair_channel]), | |
| 'prev_dgram': jnp.zeros([num_residues, num_residues, 64]), | |
| } | |
| if self.config.use_struct: | |
| prev.update({'prev_pos': jnp.zeros([num_residues, residue_constants.atom_type_num, 3]), | |
| 'prev_plddt': jnp.zeros([num_residues, 50]), | |
| 'prev_pae': jnp.zeros([num_residues, num_residues, 64])}) | |
| for k in ["pos","msa_first_row","pair","dgram"]: | |
| if f"init_{k}" in batch: prev[f"prev_{k}"] = batch[f"init_{k}"][0] | |
| if self.config.num_recycle: | |
| if 'num_iter_recycling' in batch: | |
| # Training time: num_iter_recycling is in batch. | |
| # The value for each ensemble batch is the same, so arbitrarily taking | |
| # 0-th. | |
| num_iter = batch['num_iter_recycling'][0] | |
| # Add insurance that we will not run more | |
| # recyclings than the model is configured to run. | |
| num_iter = jnp.minimum(num_iter, self.config.num_recycle) | |
| else: | |
| # Eval mode or tests: use the maximum number of iterations. | |
| num_iter = self.config.num_recycle | |
| def add_prev(p,p_): | |
| p_["prev_dgram"] += p["prev_dgram"] | |
| if self.config.use_struct: | |
| p_["prev_plddt"] += p["prev_plddt"] | |
| p_["prev_pae"] += p["prev_pae"] | |
| return p_ | |
| ############################################################## | |
| def body(p, i): | |
| p_ = get_prev(do_call(p, recycle_idx=i, compute_loss=False)) | |
| if self.config.add_prev: | |
| p_ = add_prev(p, p_) | |
| return p_, None | |
| if hk.running_init(): | |
| prev,_ = body(prev, 0) | |
| else: | |
| prev,_ = hk.scan(body, prev, jnp.arange(num_iter)) | |
| ############################################################## | |
| else: | |
| num_iter = 0 | |
| ret = do_call(prev=prev, recycle_idx=num_iter) | |
| if self.config.add_prev: | |
| prev_ = get_prev(ret) | |
| if compute_loss: | |
| ret = ret[0], [ret[1]] | |
| if not return_representations: | |
| del (ret[0] if compute_loss else ret)['representations'] # pytype: disable=unsupported-operands | |
| if self.config.add_prev and num_iter > 0: | |
| prev_ = add_prev(prev, prev_) | |
| ret["distogram"]["logits"] = prev_["prev_dgram"]/(num_iter+1) | |
| if self.config.use_struct: | |
| ret["predicted_lddt"]["logits"] = prev_["prev_plddt"]/(num_iter+1) | |
| if "predicted_aligned_error" in ret: | |
| ret["predicted_aligned_error"]["logits"] = prev_["prev_pae"]/(num_iter+1) | |
| return ret | |
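# Illustrative sketch (not part of the original module) of how AlphaFold above
# is typically wrapped with Haiku for inference. The config object and its
# `.model` attribute are assumptions here; in the AlphaFold codebase such a
# config comes from alphafold.model.config.
def _make_alphafold_fns(model_config):
  def _forward(batch, is_training):
    return AlphaFold(model_config.model)(
        batch,
        is_training=is_training,
        compute_loss=False,
        ensemble_representations=False)
  transformed = hk.transform(_forward)
  # params = transformed.init(jax.random.PRNGKey(0), batch, is_training=False)
  # outputs = transformed.apply(params, jax.random.PRNGKey(0), batch,
  #                             is_training=False)
  return transformed.init, transformed.apply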
| class TemplatePairStack(hk.Module): | |
| """Pair stack for the templates. | |
| Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack" | |
| """ | |
| def __init__(self, config, global_config, name='template_pair_stack'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, pair_act, pair_mask, is_training, safe_key=None, scale_rate=1.0): | |
| """Builds TemplatePairStack module. | |
| Arguments: | |
| pair_act: Pair activations for single template, shape [N_res, N_res, c_t]. | |
| pair_mask: Pair mask, shape [N_res, N_res]. | |
| is_training: Whether the module is in training mode. | |
| safe_key: Safe key object encapsulating the random number generation key. | |
| scale_rate: Multiplier applied to the configured dropout rate. | |
| Returns: | |
| Updated pair_act, shape [N_res, N_res, c_t]. | |
| """ | |
| if safe_key is None: | |
| safe_key = prng.SafeKey(hk.next_rng_key()) | |
| gc = self.global_config | |
| c = self.config | |
| if not c.num_block: | |
| return pair_act | |
| def block(x): | |
| """One block of the template pair stack.""" | |
| pair_act, safe_key = x | |
| dropout_wrapper_fn = functools.partial( | |
| dropout_wrapper, is_training=is_training, global_config=gc, scale_rate=scale_rate) | |
| safe_key, *sub_keys = safe_key.split(6) | |
| sub_keys = iter(sub_keys) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleAttention(c.triangle_attention_starting_node, gc, | |
| name='triangle_attention_starting_node'), | |
| pair_act, | |
| pair_mask, | |
| next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleAttention(c.triangle_attention_ending_node, gc, | |
| name='triangle_attention_ending_node'), | |
| pair_act, | |
| pair_mask, | |
| next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleMultiplication(c.triangle_multiplication_outgoing, gc, | |
| name='triangle_multiplication_outgoing'), | |
| pair_act, | |
| pair_mask, | |
| next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleMultiplication(c.triangle_multiplication_incoming, gc, | |
| name='triangle_multiplication_incoming'), | |
| pair_act, | |
| pair_mask, | |
| next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| Transition(c.pair_transition, gc, name='pair_transition'), | |
| pair_act, | |
| pair_mask, | |
| next(sub_keys)) | |
| return pair_act, safe_key | |
| if gc.use_remat: | |
| block = hk.remat(block) | |
| res_stack = layer_stack.layer_stack(c.num_block)(block) | |
| pair_act, safe_key = res_stack((pair_act, safe_key)) | |
| return pair_act | |
| class Transition(hk.Module): | |
| """Transition layer. | |
| Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" | |
| Jumper et al. (2021) Suppl. Alg. 15 "PairTransition" | |
| """ | |
| def __init__(self, config, global_config, name='transition_block'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, act, mask, is_training=True): | |
| """Builds Transition module. | |
| Arguments: | |
| act: A tensor of queries of size [batch_size, N_res, N_channel]. | |
| mask: A tensor denoting the mask of size [batch_size, N_res]. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| A float32 tensor of size [batch_size, N_res, N_channel]. | |
| """ | |
| _, _, nc = act.shape | |
| num_intermediate = int(nc * self.config.num_intermediate_factor) | |
| mask = jnp.expand_dims(mask, axis=-1) | |
| act = hk.LayerNorm( | |
| axis=[-1], | |
| create_scale=True, | |
| create_offset=True, | |
| name='input_layer_norm')( | |
| act) | |
| transition_module = hk.Sequential([ | |
| common_modules.Linear( | |
| num_intermediate, | |
| initializer='relu', | |
| name='transition1'), jax.nn.relu, | |
| common_modules.Linear( | |
| nc, | |
| initializer=utils.final_init(self.global_config), | |
| name='transition2') | |
| ]) | |
| act = mapping.inference_subbatch( | |
| transition_module, | |
| self.global_config.subbatch_size, | |
| batched_args=[act], | |
| nonbatched_args=[], | |
| low_memory=not is_training) | |
| return act | |
| def glorot_uniform(): | |
| return hk.initializers.VarianceScaling(scale=1.0, | |
| mode='fan_avg', | |
| distribution='uniform') | |
| class Attention(hk.Module): | |
| """Multihead attention.""" | |
| def __init__(self, config, global_config, output_dim, name='attention'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| self.output_dim = output_dim | |
| def __call__(self, q_data, m_data, bias, nonbatched_bias=None): | |
| """Builds Attention module. | |
| Arguments: | |
| q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. | |
| m_data: A tensor of memories from which the keys and values are | |
| projected, shape [batch_size, N_keys, m_channels]. | |
| bias: A bias for the attention, broadcastable to shape | |
| [batch_size, num_head, N_queries, N_keys]. | |
| nonbatched_bias: Shared bias, shape [N_queries, N_keys]. | |
| Returns: | |
| A float32 tensor of shape [batch_size, N_queries, output_dim]. | |
| """ | |
| # Sensible default for when the config keys are missing | |
| key_dim = self.config.get('key_dim', int(q_data.shape[-1])) | |
| value_dim = self.config.get('value_dim', int(m_data.shape[-1])) | |
| num_head = self.config.num_head | |
| assert key_dim % num_head == 0 | |
| assert value_dim % num_head == 0 | |
| key_dim = key_dim // num_head | |
| value_dim = value_dim // num_head | |
| q_weights = hk.get_parameter( | |
| 'query_w', shape=(q_data.shape[-1], num_head, key_dim), | |
| init=glorot_uniform()) | |
| k_weights = hk.get_parameter( | |
| 'key_w', shape=(m_data.shape[-1], num_head, key_dim), | |
| init=glorot_uniform()) | |
| v_weights = hk.get_parameter( | |
| 'value_w', shape=(m_data.shape[-1], num_head, value_dim), | |
| init=glorot_uniform()) | |
| q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) | |
| k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights) | |
| v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights) | |
| logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias | |
| if nonbatched_bias is not None: | |
| logits += jnp.expand_dims(nonbatched_bias, axis=0) | |
| weights = jax.nn.softmax(logits) | |
| weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v) | |
| if self.global_config.zero_init: | |
| init = hk.initializers.Constant(0.0) | |
| else: | |
| init = glorot_uniform() | |
| if self.config.gating: | |
| gating_weights = hk.get_parameter( | |
| 'gating_w', | |
| shape=(q_data.shape[-1], num_head, value_dim), | |
| init=hk.initializers.Constant(0.0)) | |
| gating_bias = hk.get_parameter( | |
| 'gating_b', | |
| shape=(num_head, value_dim), | |
| init=hk.initializers.Constant(1.0)) | |
| gate_values = jnp.einsum('bqc, chv->bqhv', q_data, | |
| gating_weights) + gating_bias | |
| gate_values = jax.nn.sigmoid(gate_values) | |
| weighted_avg *= gate_values | |
| o_weights = hk.get_parameter( | |
| 'output_w', shape=(num_head, value_dim, self.output_dim), | |
| init=init) | |
| o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), | |
| init=hk.initializers.Constant(0.0)) | |
| output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias | |
| return output | |
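# Illustrative usage (not part of the original module): running the Attention
# module under hk.transform with a minimal hand-built config. The config keys
# below (num_head, gating, zero_init) mirror the fields read by Attention; the
# concrete values are arbitrary.
def _demo_attention():
  import ml_collections  # used only for this illustrative config
  cfg = ml_collections.ConfigDict({'num_head': 4, 'gating': True})
  gcfg = ml_collections.ConfigDict({'zero_init': True})

  def forward(q_data, m_data, bias):
    return Attention(cfg, gcfg, output_dim=32)(q_data, m_data, bias)

  fwd = hk.transform(forward)
  q_data = jnp.zeros((2, 8, 32))   # [batch, N_queries, q_channels]
  m_data = jnp.zeros((2, 10, 32))  # [batch, N_keys, m_channels]
  mask = jnp.ones((2, 10))
  bias = (1e9 * (mask - 1.))[:, None, None, :]  # broadcast over heads/queries
  params = fwd.init(jax.random.PRNGKey(0), q_data, m_data, bias)
  out = fwd.apply(params, None, q_data, m_data, bias)  # [2, 8, 32]
  return out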
| class GlobalAttention(hk.Module): | |
| """Global attention. | |
| Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 | |
| """ | |
| def __init__(self, config, global_config, output_dim, name='attention'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| self.output_dim = output_dim | |
| def __call__(self, q_data, m_data, q_mask, bias): | |
| """Builds GlobalAttention module. | |
| Arguments: | |
| q_data: A tensor of queries with size [batch_size, N_queries, | |
| q_channels] | |
| m_data: A tensor of memories from which the keys and values are | |
| projected. Size [batch_size, N_keys, m_channels] | |
| q_mask: A binary mask for q_data with zeros in the padded sequence | |
| elements and ones otherwise. Size [batch_size, N_queries, q_channels] | |
| (or broadcastable to this shape). | |
| bias: A bias for the attention (unused here; the bias is recomputed from | |
| q_mask below). | |
| Returns: | |
| A float32 tensor of size [batch_size, N_queries, output_dim]. | |
| """ | |
| # Sensible default for when the config keys are missing | |
| key_dim = self.config.get('key_dim', int(q_data.shape[-1])) | |
| value_dim = self.config.get('value_dim', int(m_data.shape[-1])) | |
| num_head = self.config.num_head | |
| assert key_dim % num_head == 0 | |
| assert value_dim % num_head == 0 | |
| key_dim = key_dim // num_head | |
| value_dim = value_dim // num_head | |
| q_weights = hk.get_parameter( | |
| 'query_w', shape=(q_data.shape[-1], num_head, key_dim), | |
| init=glorot_uniform()) | |
| k_weights = hk.get_parameter( | |
| 'key_w', shape=(m_data.shape[-1], key_dim), | |
| init=glorot_uniform()) | |
| v_weights = hk.get_parameter( | |
| 'value_w', shape=(m_data.shape[-1], value_dim), | |
| init=glorot_uniform()) | |
| v = jnp.einsum('bka,ac->bkc', m_data, v_weights) | |
| q_avg = utils.mask_mean(q_mask, q_data, axis=1) | |
| q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5) | |
| k = jnp.einsum('bka,ac->bkc', m_data, k_weights) | |
| bias = (1e9 * (q_mask[:, None, :, 0] - 1.)) | |
| logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias | |
| weights = jax.nn.softmax(logits) | |
| weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v) | |
| if self.global_config.zero_init: | |
| init = hk.initializers.Constant(0.0) | |
| else: | |
| init = glorot_uniform() | |
| o_weights = hk.get_parameter( | |
| 'output_w', shape=(num_head, value_dim, self.output_dim), | |
| init=init) | |
| o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), | |
| init=hk.initializers.Constant(0.0)) | |
| if self.config.gating: | |
| gating_weights = hk.get_parameter( | |
| 'gating_w', | |
| shape=(q_data.shape[-1], num_head, value_dim), | |
| init=hk.initializers.Constant(0.0)) | |
| gating_bias = hk.get_parameter( | |
| 'gating_b', | |
| shape=(num_head, value_dim), | |
| init=hk.initializers.Constant(1.0)) | |
| gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) | |
| gate_values = jax.nn.sigmoid(gate_values + gating_bias) | |
| weighted_avg = weighted_avg[:, None] * gate_values | |
| output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias | |
| else: | |
| output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias | |
| output = output[:, None] | |
| return output | |
| class MSARowAttentionWithPairBias(hk.Module): | |
| """MSA per-row attention biased by the pair representation. | |
| Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" | |
| """ | |
| def __init__(self, config, global_config, | |
| name='msa_row_attention_with_pair_bias'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, | |
| msa_act, | |
| msa_mask, | |
| pair_act, | |
| is_training=False): | |
| """Builds MSARowAttentionWithPairBias module. | |
| Arguments: | |
| msa_act: [N_seq, N_res, c_m] MSA representation. | |
| msa_mask: [N_seq, N_res] mask of non-padded regions. | |
| pair_act: [N_res, N_res, c_z] pair representation. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Update to msa_act, shape [N_seq, N_res, c_m]. | |
| """ | |
| c = self.config | |
| assert len(msa_act.shape) == 3 | |
| assert len(msa_mask.shape) == 2 | |
| assert c.orientation == 'per_row' | |
| bias = (1e9 * (msa_mask - 1.))[:, None, None, :] | |
| assert len(bias.shape) == 4 | |
| msa_act = hk.LayerNorm( | |
| axis=[-1], create_scale=True, create_offset=True, name='query_norm')( | |
| msa_act) | |
| pair_act = hk.LayerNorm( | |
| axis=[-1], | |
| create_scale=True, | |
| create_offset=True, | |
| name='feat_2d_norm')( | |
| pair_act) | |
| init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1])) | |
| weights = hk.get_parameter( | |
| 'feat_2d_weights', | |
| shape=(pair_act.shape[-1], c.num_head), | |
| init=hk.initializers.RandomNormal(stddev=init_factor)) | |
| nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) | |
| attn_mod = Attention( | |
| c, self.global_config, msa_act.shape[-1]) | |
| msa_act = mapping.inference_subbatch( | |
| attn_mod, | |
| self.global_config.subbatch_size, | |
| batched_args=[msa_act, msa_act, bias], | |
| nonbatched_args=[nonbatched_bias], | |
| low_memory=not is_training) | |
| return msa_act | |
| class MSAColumnAttention(hk.Module): | |
| """MSA per-column attention. | |
| Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" | |
| """ | |
| def __init__(self, config, global_config, name='msa_column_attention'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, | |
| msa_act, | |
| msa_mask, | |
| is_training=False): | |
| """Builds MSAColumnAttention module. | |
| Arguments: | |
| msa_act: [N_seq, N_res, c_m] MSA representation. | |
| msa_mask: [N_seq, N_res] mask of non-padded regions. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Update to msa_act, shape [N_seq, N_res, c_m] | |
| """ | |
| c = self.config | |
| assert len(msa_act.shape) == 3 | |
| assert len(msa_mask.shape) == 2 | |
| assert c.orientation == 'per_column' | |
| msa_act = jnp.swapaxes(msa_act, -2, -3) | |
| msa_mask = jnp.swapaxes(msa_mask, -1, -2) | |
| bias = (1e9 * (msa_mask - 1.))[:, None, None, :] | |
| assert len(bias.shape) == 4 | |
| msa_act = hk.LayerNorm( | |
| axis=[-1], create_scale=True, create_offset=True, name='query_norm')( | |
| msa_act) | |
| attn_mod = Attention( | |
| c, self.global_config, msa_act.shape[-1]) | |
| msa_act = mapping.inference_subbatch( | |
| attn_mod, | |
| self.global_config.subbatch_size, | |
| batched_args=[msa_act, msa_act, bias], | |
| nonbatched_args=[], | |
| low_memory=not is_training) | |
| msa_act = jnp.swapaxes(msa_act, -2, -3) | |
| return msa_act | |
| class MSAColumnGlobalAttention(hk.Module): | |
| """MSA per-column global attention. | |
| Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" | |
| """ | |
| def __init__(self, config, global_config, name='msa_column_global_attention'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, | |
| msa_act, | |
| msa_mask, | |
| is_training=False): | |
| """Builds MSAColumnGlobalAttention module. | |
| Arguments: | |
| msa_act: [N_seq, N_res, c_m] MSA representation. | |
| msa_mask: [N_seq, N_res] mask of non-padded regions. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Update to msa_act, shape [N_seq, N_res, c_m]. | |
| """ | |
| c = self.config | |
| assert len(msa_act.shape) == 3 | |
| assert len(msa_mask.shape) == 2 | |
| assert c.orientation == 'per_column' | |
| msa_act = jnp.swapaxes(msa_act, -2, -3) | |
| msa_mask = jnp.swapaxes(msa_mask, -1, -2) | |
| bias = (1e9 * (msa_mask - 1.))[:, None, None, :] | |
| assert len(bias.shape) == 4 | |
| msa_act = hk.LayerNorm( | |
| axis=[-1], create_scale=True, create_offset=True, name='query_norm')( | |
| msa_act) | |
| attn_mod = GlobalAttention( | |
| c, self.global_config, msa_act.shape[-1], | |
| name='attention') | |
| # [N_seq, N_res, 1] | |
| msa_mask = jnp.expand_dims(msa_mask, axis=-1) | |
| msa_act = mapping.inference_subbatch( | |
| attn_mod, | |
| self.global_config.subbatch_size, | |
| batched_args=[msa_act, msa_act, msa_mask, bias], | |
| nonbatched_args=[], | |
| low_memory=not is_training) | |
| msa_act = jnp.swapaxes(msa_act, -2, -3) | |
| return msa_act | |
| class TriangleAttention(hk.Module): | |
| """Triangle Attention. | |
| Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode" | |
| Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode" | |
| """ | |
| def __init__(self, config, global_config, name='triangle_attention'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, pair_act, pair_mask, is_training=False): | |
| """Builds TriangleAttention module. | |
| Arguments: | |
| pair_act: [N_res, N_res, c_z] pair activations tensor | |
| pair_mask: [N_res, N_res] mask of non-padded regions in the tensor. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Update to pair_act, shape [N_res, N_res, c_z]. | |
| """ | |
| c = self.config | |
| assert len(pair_act.shape) == 3 | |
| assert len(pair_mask.shape) == 2 | |
| assert c.orientation in ['per_row', 'per_column'] | |
| if c.orientation == 'per_column': | |
| pair_act = jnp.swapaxes(pair_act, -2, -3) | |
| pair_mask = jnp.swapaxes(pair_mask, -1, -2) | |
| bias = (1e9 * (pair_mask - 1.))[:, None, None, :] | |
| assert len(bias.shape) == 4 | |
| pair_act = hk.LayerNorm( | |
| axis=[-1], create_scale=True, create_offset=True, name='query_norm')( | |
| pair_act) | |
| init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1])) | |
| weights = hk.get_parameter( | |
| 'feat_2d_weights', | |
| shape=(pair_act.shape[-1], c.num_head), | |
| init=hk.initializers.RandomNormal(stddev=init_factor)) | |
| nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) | |
| attn_mod = Attention( | |
| c, self.global_config, pair_act.shape[-1]) | |
| pair_act = mapping.inference_subbatch( | |
| attn_mod, | |
| self.global_config.subbatch_size, | |
| batched_args=[pair_act, pair_act, bias], | |
| nonbatched_args=[nonbatched_bias], | |
| low_memory=not is_training) | |
| if c.orientation == 'per_column': | |
| pair_act = jnp.swapaxes(pair_act, -2, -3) | |
| return pair_act | |
| class MaskedMsaHead(hk.Module): | |
| """Head to predict MSA at the masked locations. | |
| The MaskedMsaHead employs a BERT-style objective to reconstruct a masked | |
| version of the full MSA, based on a linear projection of | |
| the MSA representation. | |
| Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" | |
| """ | |
| def __init__(self, config, global_config, name='masked_msa_head'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, representations, batch, is_training): | |
| """Builds MaskedMsaHead module. | |
| Arguments: | |
| representations: Dictionary of representations, must contain: | |
| * 'msa': MSA representation, shape [N_seq, N_res, c_m]. | |
| batch: Batch, unused. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Dictionary containing: | |
| * 'logits': logits of shape [N_seq, N_res, N_aatype] with | |
| (unnormalized) log probabilities of the predicted aatype at each position. | |
| """ | |
| del batch | |
| logits = common_modules.Linear( | |
| self.config.num_output, | |
| initializer=utils.final_init(self.global_config), | |
| name='logits')( | |
| representations['msa']) | |
| return dict(logits=logits) | |
| def loss(self, value, batch): | |
| errors = softmax_cross_entropy( | |
| labels=jax.nn.one_hot(batch['true_msa'], num_classes=23), | |
| logits=value['logits']) | |
| loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) / | |
| (1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1)))) | |
| return {'loss': loss} | |
| class PredictedLDDTHead(hk.Module): | |
| """Head to predict the per-residue LDDT to be used as a confidence measure. | |
| Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" | |
| Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca" | |
| """ | |
| def __init__(self, config, global_config, name='predicted_lddt_head'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, representations, batch, is_training): | |
| """Builds ExperimentallyResolvedHead module. | |
| Arguments: | |
| representations: Dictionary of representations, must contain: | |
| * 'structure_module': Single representation from the structure module, | |
| shape [N_res, c_s]. | |
| batch: Batch, unused. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Dictionary containing : | |
| * 'logits': logits of shape [N_res, N_bins] with | |
| (unnormalized) log probabilities of binned predicted lDDT. | |
| """ | |
| act = representations['structure_module'] | |
| act = hk.LayerNorm( | |
| axis=[-1], | |
| create_scale=True, | |
| create_offset=True, | |
| name='input_layer_norm')( | |
| act) | |
| act = common_modules.Linear( | |
| self.config.num_channels, | |
| initializer='relu', | |
| name='act_0')( | |
| act) | |
| act = jax.nn.relu(act) | |
| act = common_modules.Linear( | |
| self.config.num_channels, | |
| initializer='relu', | |
| name='act_1')( | |
| act) | |
| act = jax.nn.relu(act) | |
| logits = common_modules.Linear( | |
| self.config.num_bins, | |
| initializer=utils.final_init(self.global_config), | |
| name='logits')( | |
| act) | |
| # Shape (batch_size, num_res, num_bins) | |
| return dict(logits=logits) | |
| def loss(self, value, batch): | |
| # Shape (num_res, 37, 3) | |
| pred_all_atom_pos = value['structure_module']['final_atom_positions'] | |
| # Shape (num_res, 37, 3) | |
| true_all_atom_pos = batch['all_atom_positions'] | |
| # Shape (num_res, 37) | |
| all_atom_mask = batch['all_atom_mask'] | |
| # Shape (num_res,) | |
| lddt_ca = lddt.lddt( | |
| # Shape (batch_size, num_res, 3) | |
| predicted_points=pred_all_atom_pos[None, :, 1, :], | |
| # Shape (batch_size, num_res, 3) | |
| true_points=true_all_atom_pos[None, :, 1, :], | |
| # Shape (batch_size, num_res, 1) | |
| true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32), | |
| cutoff=15., | |
| per_residue=True)[0] | |
| lddt_ca = jax.lax.stop_gradient(lddt_ca) | |
| num_bins = self.config.num_bins | |
| bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32) | |
| # protect against out of range for lddt_ca == 1 | |
| bin_index = jnp.minimum(bin_index, num_bins - 1) | |
| lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins) | |
| # Shape (num_res, num_channel) | |
| logits = value['predicted_lddt']['logits'] | |
| errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) | |
| # Shape (num_res,) | |
| mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']] | |
| mask_ca = mask_ca.astype(jnp.float32) | |
| loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8) | |
| if self.config.filter_by_resolution: | |
| # NMR & distillation have resolution = 0 | |
| loss *= ((batch['resolution'] >= self.config.min_resolution) | |
| & (batch['resolution'] <= self.config.max_resolution)).astype( | |
| jnp.float32) | |
| output = {'loss': loss} | |
| return output | |
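# Illustrative post-processing (not part of the original module): one common
# way to turn the per-residue pLDDT logits produced by the head above into a
# confidence in [0, 100], i.e. the expectation over equally sized bins on
# [0, 1] scaled by 100. The exact post-processing used downstream lives
# outside this file.
def _plddt_from_logits(logits):
  num_bins = logits.shape[-1]
  bin_width = 1.0 / num_bins
  bin_centers = (jnp.arange(num_bins) + 0.5) * bin_width   # [num_bins]
  probs = jax.nn.softmax(logits, axis=-1)                  # [N_res, num_bins]
  return 100.0 * jnp.sum(probs * bin_centers, axis=-1)     # [N_res]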
| class PredictedAlignedErrorHead(hk.Module): | |
| """Head to predict the distance errors in the backbone alignment frames. | |
| Can be used to compute predicted TM-Score. | |
| Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" | |
| """ | |
| def __init__(self, config, global_config, | |
| name='predicted_aligned_error_head'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, representations, batch, is_training): | |
| """Builds PredictedAlignedErrorHead module. | |
| Arguments: | |
| representations: Dictionary of representations, must contain: | |
| * 'pair': pair representation, shape [N_res, N_res, c_z]. | |
| batch: Batch, unused. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Dictionary containing: | |
| * logits: logits for aligned error, shape [N_res, N_res, N_bins]. | |
| * bin_breaks: array containing bin breaks, shape [N_bins - 1]. | |
| """ | |
| act = representations['pair'] | |
| # Shape (num_res, num_res, num_bins) | |
| logits = common_modules.Linear( | |
| self.config.num_bins, | |
| initializer=utils.final_init(self.global_config), | |
| name='logits')(act) | |
| # Shape (num_bins,) | |
| breaks = jnp.linspace( | |
| 0., self.config.max_error_bin, self.config.num_bins - 1) | |
| return dict(logits=logits, breaks=breaks) | |
| def loss(self, value, batch): | |
| # Shape (num_res, 7) | |
| predicted_affine = quat_affine.QuatAffine.from_tensor( | |
| value['structure_module']['final_affines']) | |
| # Shape (num_res, 7) | |
| true_affine = quat_affine.QuatAffine.from_tensor( | |
| batch['backbone_affine_tensor']) | |
| # Shape (num_res) | |
| mask = batch['backbone_affine_mask'] | |
| # Shape (num_res, num_res) | |
| square_mask = mask[:, None] * mask[None, :] | |
| num_bins = self.config.num_bins | |
| # (1, num_bins - 1) | |
| breaks = value['predicted_aligned_error']['breaks'] | |
| # (1, num_bins) | |
| logits = value['predicted_aligned_error']['logits'] | |
| # Compute the squared error for each alignment. | |
| def _local_frame_points(affine): | |
| points = [jnp.expand_dims(x, axis=-2) for x in affine.translation] | |
| return affine.invert_point(points, extra_dims=1) | |
| error_dist2_xyz = [ | |
| jnp.square(a - b) | |
| for a, b in zip(_local_frame_points(predicted_affine), | |
| _local_frame_points(true_affine))] | |
| error_dist2 = sum(error_dist2_xyz) | |
| # Shape (num_res, num_res) | |
| # First num_res are alignment frames, second num_res are the residues. | |
| error_dist2 = jax.lax.stop_gradient(error_dist2) | |
| sq_breaks = jnp.square(breaks) | |
| true_bins = jnp.sum(( | |
| error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1) | |
| errors = softmax_cross_entropy( | |
| labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits) | |
| loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) / | |
| (1e-8 + jnp.sum(square_mask, axis=(-2, -1)))) | |
| if self.config.filter_by_resolution: | |
| # NMR & distillation have resolution = 0 | |
| loss *= ((batch['resolution'] >= self.config.min_resolution) | |
| & (batch['resolution'] <= self.config.max_resolution)).astype( | |
| jnp.float32) | |
| output = {'loss': loss} | |
| return output | |
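# Illustrative post-processing (not part of the original module): a sketch of
# turning the aligned-error logits and bin breaks returned by the head above
# into an expected aligned error per residue pair. The bin-center construction
# (midpoints plus one extrapolated center for the open-ended last bin) is an
# assumption about the downstream convention.
def _expected_aligned_error(logits, breaks):
  step = breaks[1] - breaks[0]
  bin_centers = breaks + step / 2.0                        # [num_bins - 1]
  bin_centers = jnp.concatenate(
      [bin_centers, bin_centers[-1:] + step], axis=0)      # [num_bins]
  probs = jax.nn.softmax(logits, axis=-1)                  # [N_res, N_res, num_bins]
  return jnp.sum(probs * bin_centers, axis=-1)             # [N_res, N_res]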
| class ExperimentallyResolvedHead(hk.Module): | |
| """Predicts if an atom is experimentally resolved in a high-res structure. | |
| Only trained on high-resolution X-ray crystals & cryo-EM. | |
| Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction' | |
| """ | |
| def __init__(self, config, global_config, | |
| name='experimentally_resolved_head'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, representations, batch, is_training): | |
| """Builds ExperimentallyResolvedHead module. | |
| Arguments: | |
| representations: Dictionary of representations, must contain: | |
| * 'single': Single representation, shape [N_res, c_s]. | |
| batch: Batch, unused. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Dictionary containing: | |
| * 'logits': logits of shape [N_res, 37], | |
| log probability that an atom is resolved in atom37 representation, | |
| can be converted to probability by applying sigmoid. | |
| """ | |
| logits = common_modules.Linear( | |
| 37, # atom_exists.shape[-1] | |
| initializer=utils.final_init(self.global_config), | |
| name='logits')(representations['single']) | |
| return dict(logits=logits) | |
| def loss(self, value, batch): | |
| logits = value['logits'] | |
| assert len(logits.shape) == 2 | |
| # Does the atom appear in the amino acid? | |
| atom_exists = batch['atom37_atom_exists'] | |
| # Is the atom resolved in the experiment? Subset of atom_exists, | |
| # *except for OXT* | |
| all_atom_mask = batch['all_atom_mask'].astype(jnp.float32) | |
| xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) | |
| loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists)) | |
| if self.config.filter_by_resolution: | |
| # NMR & distillation examples have resolution = 0. | |
| loss *= ((batch['resolution'] >= self.config.min_resolution) | |
| & (batch['resolution'] <= self.config.max_resolution)).astype( | |
| jnp.float32) | |
| output = {'loss': loss} | |
| return output | |
| class TriangleMultiplication(hk.Module): | |
| """Triangle multiplication layer ("outgoing" or "incoming"). | |
| Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" | |
| Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" | |
| """ | |
| def __init__(self, config, global_config, name='triangle_multiplication'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, act, mask, is_training=True): | |
| """Builds TriangleMultiplication module. | |
| Arguments: | |
| act: Pair activations, shape [N_res, N_res, c_z] | |
| mask: Pair mask, shape [N_res, N_res]. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Outputs, same shape/type as act. | |
| """ | |
| del is_training | |
| c = self.config | |
| gc = self.global_config | |
| mask = mask[..., None] | |
| act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, | |
| name='layer_norm_input')(act) | |
| input_act = act | |
| left_projection = common_modules.Linear( | |
| c.num_intermediate_channel, | |
| name='left_projection') | |
| left_proj_act = mask * left_projection(act) | |
| right_projection = common_modules.Linear( | |
| c.num_intermediate_channel, | |
| name='right_projection') | |
| right_proj_act = mask * right_projection(act) | |
| left_gate_values = jax.nn.sigmoid(common_modules.Linear( | |
| c.num_intermediate_channel, | |
| bias_init=1., | |
| initializer=utils.final_init(gc), | |
| name='left_gate')(act)) | |
| right_gate_values = jax.nn.sigmoid(common_modules.Linear( | |
| c.num_intermediate_channel, | |
| bias_init=1., | |
| initializer=utils.final_init(gc), | |
| name='right_gate')(act)) | |
| left_proj_act *= left_gate_values | |
| right_proj_act *= right_gate_values | |
| # "Outgoing" edges equation: 'ikc,jkc->ijc' | |
| # "Incoming" edges equation: 'kjc,kic->ijc' | |
| # Note on the Suppl. Alg. 11 & 12 notation: | |
| # For the "outgoing" edges, a = left_proj_act and b = right_proj_act | |
| # For the "incoming" edges, it's swapped: | |
| # b = left_proj_act and a = right_proj_act | |
| act = jnp.einsum(c.equation, left_proj_act, right_proj_act) | |
| act = hk.LayerNorm( | |
| axis=[-1], | |
| create_scale=True, | |
| create_offset=True, | |
| name='center_layer_norm')( | |
| act) | |
| output_channel = int(input_act.shape[-1]) | |
| act = common_modules.Linear( | |
| output_channel, | |
| initializer=utils.final_init(gc), | |
| name='output_projection')(act) | |
| gate_values = jax.nn.sigmoid(common_modules.Linear( | |
| output_channel, | |
| bias_init=1., | |
| initializer=utils.final_init(gc), | |
| name='gating_linear')(input_act)) | |
| act *= gate_values | |
| return act | |
| class DistogramHead(hk.Module): | |
| """Head to predict a distogram. | |
| Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" | |
| """ | |
| def __init__(self, config, global_config, name='distogram_head'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, representations, batch, is_training): | |
| """Builds DistogramHead module. | |
| Arguments: | |
| representations: Dictionary of representations, must contain: | |
| * 'pair': pair representation, shape [N_res, N_res, c_z]. | |
| batch: Batch, unused. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Dictionary containing: | |
| * logits: logits for distogram, shape [N_res, N_res, N_bins]. | |
| * bin_breaks: array containing bin breaks, shape [N_bins - 1,]. | |
| """ | |
| half_logits = common_modules.Linear( | |
| self.config.num_bins, | |
| initializer=utils.final_init(self.global_config), | |
| name='half_logits')( | |
| representations['pair']) | |
| logits = half_logits + jnp.swapaxes(half_logits, -2, -3) | |
| breaks = jnp.linspace(self.config.first_break, self.config.last_break, | |
| self.config.num_bins - 1) | |
| return dict(logits=logits, bin_edges=breaks) | |
| def loss(self, value, batch): | |
| return _distogram_log_loss(value['logits'], value['bin_edges'], | |
| batch, self.config.num_bins) | |
| def _distogram_log_loss(logits, bin_edges, batch, num_bins): | |
| """Log loss of a distogram.""" | |
| assert len(logits.shape) == 3 | |
| positions = batch['pseudo_beta'] | |
| mask = batch['pseudo_beta_mask'] | |
| assert positions.shape[-1] == 3 | |
| sq_breaks = jnp.square(bin_edges) | |
| dist2 = jnp.sum( | |
| jnp.square( | |
| jnp.expand_dims(positions, axis=-2) - | |
| jnp.expand_dims(positions, axis=-3)), | |
| axis=-1, | |
| keepdims=True) | |
| true_bins = jnp.sum(dist2 > sq_breaks, axis=-1) | |
| errors = softmax_cross_entropy( | |
| labels=jax.nn.one_hot(true_bins, num_bins), logits=logits) | |
| square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1) | |
| avg_error = ( | |
| jnp.sum(errors * square_mask, axis=(-2, -1)) / | |
| (1e-6 + jnp.sum(square_mask, axis=(-2, -1)))) | |
| dist2 = dist2[..., 0] | |
| return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2)) | |
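# Illustrative usage (not part of the original module): calling
# _distogram_log_loss on a toy batch. The bin edge values are arbitrary here;
# in practice they come from the distogram head config (first_break/last_break).
def _demo_distogram_loss():
  n_res, num_bins = 6, 64
  logits = jnp.zeros((n_res, n_res, num_bins))
  bin_edges = jnp.linspace(2.0, 22.0, num_bins - 1)
  batch = {
      'pseudo_beta': jnp.zeros((n_res, 3)),
      'pseudo_beta_mask': jnp.ones((n_res,)),
  }
  out = _distogram_log_loss(logits, bin_edges, batch, num_bins)
  # out['loss'] is a scalar; out['true_dist'] has shape [n_res, n_res].
  return out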
| class OuterProductMean(hk.Module): | |
| """Computes mean outer product. | |
| Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean" | |
| """ | |
| def __init__(self, | |
| config, | |
| global_config, | |
| num_output_channel, | |
| name='outer_product_mean'): | |
| super().__init__(name=name) | |
| self.global_config = global_config | |
| self.config = config | |
| self.num_output_channel = num_output_channel | |
| def __call__(self, act, mask, is_training=True): | |
| """Builds OuterProductMean module. | |
| Arguments: | |
| act: MSA representation, shape [N_seq, N_res, c_m]. | |
| mask: MSA mask, shape [N_seq, N_res]. | |
| is_training: Whether the module is in training mode. | |
| Returns: | |
| Update to pair representation, shape [N_res, N_res, c_z]. | |
| """ | |
| gc = self.global_config | |
| c = self.config | |
| mask = mask[..., None] | |
| act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act) | |
| left_act = mask * common_modules.Linear( | |
| c.num_outer_channel, | |
| initializer='linear', | |
| name='left_projection')( | |
| act) | |
| right_act = mask * common_modules.Linear( | |
| c.num_outer_channel, | |
| initializer='linear', | |
| name='right_projection')( | |
| act) | |
| if gc.zero_init: | |
| init_w = hk.initializers.Constant(0.0) | |
| else: | |
| init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in') | |
| output_w = hk.get_parameter( | |
| 'output_w', | |
| shape=(c.num_outer_channel, c.num_outer_channel, | |
| self.num_output_channel), | |
| init=init_w) | |
| output_b = hk.get_parameter( | |
| 'output_b', shape=(self.num_output_channel,), | |
| init=hk.initializers.Constant(0.0)) | |
| def compute_chunk(left_act): | |
| # This is equivalent to | |
| # | |
| # act = jnp.einsum('abc,ade->dceb', left_act, right_act) | |
| # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b | |
| # | |
| # but faster. | |
| left_act = jnp.transpose(left_act, [0, 2, 1]) | |
| act = jnp.einsum('acb,ade->dceb', left_act, right_act) | |
| act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b | |
| return jnp.transpose(act, [1, 0, 2]) | |
| act = mapping.inference_subbatch( | |
| compute_chunk, | |
| c.chunk_size, | |
| batched_args=[left_act], | |
| nonbatched_args=[], | |
| low_memory=True, | |
| input_subbatch_dim=1, | |
| output_subbatch_dim=0) | |
| epsilon = 1e-3 | |
| norm = jnp.einsum('abc,adc->bdc', mask, mask) | |
| act /= epsilon + norm | |
| return act | |
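# Standalone check (illustrative only) of the einsum identity quoted in the
# compute_chunk comment above: transposing left_act and rearranging the two
# contractions reproduces the naive double einsum.
def _check_outer_product_identity():
  k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
  left_act = jax.random.normal(k1, (4, 6, 5))    # [N_seq, N_res, c]
  right_act = jax.random.normal(k2, (4, 6, 5))   # [N_seq, N_res, c]
  output_w = jax.random.normal(k3, (5, 5, 7))    # [c, c, num_output_channel]
  naive = jnp.einsum('abc,ade->dceb', left_act, right_act)
  naive = jnp.einsum('dceb,cef->bdf', naive, output_w)
  fast = jnp.einsum('acb,ade->dceb',
                    jnp.transpose(left_act, [0, 2, 1]), right_act)
  fast = jnp.transpose(jnp.einsum('dceb,cef->dbf', fast, output_w), [1, 0, 2])
  return jnp.allclose(naive, fast, atol=1e-5)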
| def dgram_from_positions(positions, num_bins, min_bin, max_bin): | |
| """Compute distogram from amino acid positions. | |
| Arguments: | |
| positions: [N_res, 3] Position coordinates. | |
| num_bins: The number of bins in the distogram. | |
| min_bin: The left edge of the first bin. | |
| max_bin: The left edge of the final bin. The final bin catches | |
| everything larger than `max_bin`. | |
| Returns: | |
| Distogram with the specified number of bins. | |
| """ | |
| def squared_difference(x, y): | |
| return jnp.square(x - y) | |
| lower_breaks = jnp.linspace(min_bin, max_bin, num_bins) | |
| lower_breaks = jnp.square(lower_breaks) | |
| upper_breaks = jnp.concatenate([lower_breaks[1:],jnp.array([1e8], dtype=jnp.float32)], axis=-1) | |
| dist2 = jnp.sum( | |
| squared_difference( | |
| jnp.expand_dims(positions, axis=-2), | |
| jnp.expand_dims(positions, axis=-3)), | |
| axis=-1, keepdims=True) | |
| return ((dist2 > lower_breaks).astype(jnp.float32) * (dist2 < upper_breaks).astype(jnp.float32)) | |
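# Illustrative usage (not part of the original module): three points 0, 5 and
# 10 Angstrom apart along x. Each off-diagonal pair selects exactly one
# distance bin; pairs closer than min_bin (here the diagonal) select none, and
# pairs beyond max_bin fall into the open-ended last bin. The bin parameters
# below are example values, not a claim about any particular config.
def _demo_dgram_from_positions():
  positions = jnp.array([[0., 0., 0.], [5., 0., 0.], [10., 0., 0.]])
  dgram = dgram_from_positions(positions, num_bins=15,
                               min_bin=3.25, max_bin=20.75)
  return dgram   # [3, 3, 15]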
| def dgram_from_positions_soft(positions, num_bins, min_bin, max_bin, temp=2.0): | |
| '''Soft (differentiable) positions-to-distogram converter.''' | |
| lower_breaks = jnp.append(-1e8,jnp.linspace(min_bin, max_bin, num_bins)) | |
| upper_breaks = jnp.append(lower_breaks[1:],1e8) | |
| dist = jnp.sqrt(jnp.square(positions[...,:,None,:] - positions[...,None,:,:]).sum(-1,keepdims=True) + 1e-8) | |
| o = jax.nn.sigmoid((dist - lower_breaks)/temp) * jax.nn.sigmoid((upper_breaks - dist)/temp) | |
| o = o/(o.sum(-1,keepdims=True) + 1e-8) | |
| return o[...,1:] | |
| def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): | |
| """Create pseudo beta features.""" | |
| ca_idx = residue_constants.atom_order['CA'] | |
| cb_idx = residue_constants.atom_order['CB'] | |
| if jnp.issubdtype(aatype.dtype, jnp.integer): | |
| is_gly = jnp.equal(aatype, residue_constants.restype_order['G']) | |
| is_gly_tile = jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]) | |
| pseudo_beta = jnp.where(is_gly_tile, all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) | |
| if all_atom_masks is not None: | |
| pseudo_beta_mask = jnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) | |
| pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32) | |
| return pseudo_beta, pseudo_beta_mask | |
| else: | |
| return pseudo_beta | |
| else: | |
| is_gly = aatype[...,residue_constants.restype_order['G']] | |
| ca_pos = all_atom_positions[...,ca_idx,:] | |
| cb_pos = all_atom_positions[...,cb_idx,:] | |
| pseudo_beta = is_gly[...,None] * ca_pos + (1-is_gly[...,None]) * cb_pos | |
| if all_atom_masks is not None: | |
| ca_mask = all_atom_masks[...,ca_idx] | |
| cb_mask = all_atom_masks[...,cb_idx] | |
| pseudo_beta_mask = is_gly * ca_mask + (1-is_gly) * cb_mask | |
| return pseudo_beta, pseudo_beta_mask | |
| else: | |
| return pseudo_beta | |
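# Illustrative usage (not part of the original module): with integer aatype,
# pseudo_beta_fn picks the CA position for glycine and the CB position for
# every other residue type.
def _demo_pseudo_beta():
  aatype = jnp.array([residue_constants.restype_order['G'],
                      residue_constants.restype_order['A']])
  positions = jnp.zeros((2, residue_constants.atom_type_num, 3))
  masks = jnp.ones((2, residue_constants.atom_type_num))
  pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(aatype, positions, masks)
  return pseudo_beta, pseudo_beta_mask   # [2, 3], [2]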
| class EvoformerIteration(hk.Module): | |
| """Single iteration (block) of Evoformer stack. | |
| Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10 | |
| """ | |
| def __init__(self, config, global_config, is_extra_msa, | |
| name='evoformer_iteration'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| self.is_extra_msa = is_extra_msa | |
| def __call__(self, activations, masks, is_training=True, safe_key=None, scale_rate=1.0): | |
| """Builds EvoformerIteration module. | |
| Arguments: | |
| activations: Dictionary containing activations: | |
| * 'msa': MSA activations, shape [N_seq, N_res, c_m]. | |
| * 'pair': pair activations, shape [N_res, N_res, c_z]. | |
| masks: Dictionary of masks: | |
| * 'msa': MSA mask, shape [N_seq, N_res]. | |
| * 'pair': pair mask, shape [N_res, N_res]. | |
| is_training: Whether the module is in training mode. | |
| safe_key: prng.SafeKey encapsulating rng key. | |
| scale_rate: Multiplier applied to the configured dropout rate. | |
| Returns: | |
| Outputs, same shape/type as act. | |
| """ | |
| c = self.config | |
| gc = self.global_config | |
| msa_act, pair_act = activations['msa'], activations['pair'] | |
| if safe_key is None: | |
| safe_key = prng.SafeKey(hk.next_rng_key()) | |
| msa_mask, pair_mask = masks['msa'], masks['pair'] | |
| dropout_wrapper_fn = functools.partial( | |
| dropout_wrapper, | |
| is_training=is_training, | |
| global_config=gc, | |
| scale_rate=scale_rate) | |
| safe_key, *sub_keys = safe_key.split(10) | |
| sub_keys = iter(sub_keys) | |
| msa_act = dropout_wrapper_fn( | |
| MSARowAttentionWithPairBias( | |
| c.msa_row_attention_with_pair_bias, gc, | |
| name='msa_row_attention_with_pair_bias'), | |
| msa_act, | |
| msa_mask, | |
| safe_key=next(sub_keys), | |
| pair_act=pair_act) | |
| if not self.is_extra_msa: | |
| attn_mod = MSAColumnAttention( | |
| c.msa_column_attention, gc, name='msa_column_attention') | |
| else: | |
| attn_mod = MSAColumnGlobalAttention( | |
| c.msa_column_attention, gc, name='msa_column_global_attention') | |
| msa_act = dropout_wrapper_fn( | |
| attn_mod, | |
| msa_act, | |
| msa_mask, | |
| safe_key=next(sub_keys)) | |
| msa_act = dropout_wrapper_fn( | |
| Transition(c.msa_transition, gc, name='msa_transition'), | |
| msa_act, | |
| msa_mask, | |
| safe_key=next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| OuterProductMean( | |
| config=c.outer_product_mean, | |
| global_config=self.global_config, | |
| num_output_channel=int(pair_act.shape[-1]), | |
| name='outer_product_mean'), | |
| msa_act, | |
| msa_mask, | |
| safe_key=next(sub_keys), | |
| output_act=pair_act) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleMultiplication(c.triangle_multiplication_outgoing, gc, | |
| name='triangle_multiplication_outgoing'), | |
| pair_act, | |
| pair_mask, | |
| safe_key=next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleMultiplication(c.triangle_multiplication_incoming, gc, | |
| name='triangle_multiplication_incoming'), | |
| pair_act, | |
| pair_mask, | |
| safe_key=next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleAttention(c.triangle_attention_starting_node, gc, | |
| name='triangle_attention_starting_node'), | |
| pair_act, | |
| pair_mask, | |
| safe_key=next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| TriangleAttention(c.triangle_attention_ending_node, gc, | |
| name='triangle_attention_ending_node'), | |
| pair_act, | |
| pair_mask, | |
| safe_key=next(sub_keys)) | |
| pair_act = dropout_wrapper_fn( | |
| Transition(c.pair_transition, gc, name='pair_transition'), | |
| pair_act, | |
| pair_mask, | |
| safe_key=next(sub_keys)) | |
| return {'msa': msa_act, 'pair': pair_act} | |
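| # Minimal usage sketch (not part of the model; assumes `cfg` is the standard | |
| # AlphaFold model config with the sub-module fields referenced above): | |
| # | |
| #   def block_fn(activations, masks): | |
| #     block = EvoformerIteration(cfg.model.embeddings_and_evoformer.evoformer, | |
| #                                cfg.model.global_config, is_extra_msa=False) | |
| #     return block(activations, masks, is_training=False) | |
| #   block_fn = hk.transform(block_fn) | |
| #   params = block_fn.init(jax.random.PRNGKey(0), activations, masks) | |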
| class EmbeddingsAndEvoformer(hk.Module): | |
| """Embeds the input data and runs Evoformer. | |
| Produces the MSA, single and pair representations. | |
| Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 | |
| """ | |
| def __init__(self, config, global_config, name='evoformer'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, batch, is_training, safe_key=None): | |
| c = self.config | |
| gc = self.global_config | |
| if safe_key is None: | |
| safe_key = prng.SafeKey(hk.next_rng_key()) | |
| # Embed clustered MSA. | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 | |
| # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" | |
| preprocess_1d = common_modules.Linear( | |
| c.msa_channel, name='preprocess_1d')( | |
| batch['target_feat']) | |
| preprocess_msa = common_modules.Linear( | |
| c.msa_channel, name='preprocess_msa')( | |
| batch['msa_feat']) | |
| msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa | |
| left_single = common_modules.Linear( | |
| c.pair_channel, name='left_single')( | |
| batch['target_feat']) | |
| right_single = common_modules.Linear( | |
| c.pair_channel, name='right_single')( | |
| batch['target_feat']) | |
| pair_activations = left_single[:, None] + right_single[None] | |
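| # Outer sum of the left/right single embeddings initialises the | |
| # [N_res, N_res, pair_channel] pair representation. | |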
| mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] | |
| # Inject previous outputs for recycling. | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 | |
| # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" | |
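| # One of 'prev_pos' (predicted coordinates) or 'prev_dgram' (predicted | |
| # distogram) is expected to be present in the batch; the resulting distogram | |
| # is embedded into the pair activations below. | |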
| if "prev_pos" in batch: | |
| # use predicted position input | |
| prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) | |
| if c.backprop_dgram: | |
| dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos) | |
| else: | |
| dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) | |
| elif 'prev_dgram' in batch: | |
| # use predicted distogram input (from Sergey) | |
| dgram = jax.nn.softmax(batch["prev_dgram"]) | |
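| # Pool the 64-bin predicted distogram down to the coarser binning used by the | |
| # recycling embedder (assumed 15 bins, matching c.prev_pos): groups of four | |
| # consecutive bins are summed, and the shortest-distance bins fall into | |
| # column 0, which is zeroed out. | |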
| dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0, jnp.arange(15)), 4), 15).at[:, 0].set(0) | |
| dgram = dgram @ dgram_map | |
| pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) | |
| if c.recycle_features: | |
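| # Recycle the layer-normalized first MSA row and pair representation | |
| # produced by the previous recycling iteration. | |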
| if 'prev_msa_first_row' in batch: | |
| prev_msa_first_row = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, name='prev_msa_first_row_norm')(batch['prev_msa_first_row']) | |
| msa_activations = msa_activations.at[0].add(prev_msa_first_row) | |
| if 'prev_pair' in batch: | |
| pair_activations += hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, name='prev_pair_norm')(batch['prev_pair']) | |
| # Relative position encoding. | |
| # Jumper et al. (2021) Suppl. Alg. 4 "relpos" | |
| # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" | |
| if c.max_relative_feature: | |
| # Add one-hot-encoded clipped residue distances to the pair activations. | |
| if "rel_pos" in batch: | |
| rel_pos = batch['rel_pos'] | |
| else: | |
| if "offset" in batch: | |
| offset = batch['offset'] | |
| else: | |
| pos = batch['residue_index'] | |
| offset = pos[:, None] - pos[None, :] | |
| rel_pos = jax.nn.one_hot( | |
| jnp.clip( | |
| offset + c.max_relative_feature, | |
| a_min=0, | |
| a_max=2 * c.max_relative_feature), | |
| 2 * c.max_relative_feature + 1) | |
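| # E.g. with max_relative_feature = 32 this yields 65 bins: an offset of -40 | |
| # clips to bin 0, an offset of 0 maps to bin 32, and an offset of +40 clips | |
| # to bin 64. | |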
| pair_activations += common_modules.Linear(c.pair_channel, name='pair_activiations')(rel_pos) | |
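| # NB: the Linear module name 'pair_activiations' above is misspelled in the | |
| # original AlphaFold source; it is kept verbatim so pretrained parameters | |
| # load correctly. | |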
| # Embed templates into the pair activations. | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 | |
| if c.template.enabled: | |
| template_batch = {k: batch[k] for k in batch if k.startswith('template_')} | |
| template_pair_representation = TemplateEmbedding(c.template, gc)( | |
| pair_activations, | |
| template_batch, | |
| mask_2d, | |
| is_training=is_training, | |
| scale_rate=batch["scale_rate"]) | |
| pair_activations += template_pair_representation | |
| # Embed extra MSA features. | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 | |
| extra_msa_feat = create_extra_msa_feature(batch) | |
| extra_msa_activations = common_modules.Linear( | |
| c.extra_msa_channel, | |
| name='extra_msa_activations')( | |
| extra_msa_feat) | |
| # Extra MSA Stack. | |
| # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" | |
| extra_msa_stack_input = { | |
| 'msa': extra_msa_activations, | |
| 'pair': pair_activations, | |
| } | |
| extra_msa_stack_iteration = EvoformerIteration( | |
| c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') | |
| def extra_msa_stack_fn(x): | |
| act, safe_key = x | |
| safe_key, safe_subkey = safe_key.split() | |
| extra_evoformer_output = extra_msa_stack_iteration( | |
| activations=act, | |
| masks={ | |
| 'msa': batch['extra_msa_mask'], | |
| 'pair': mask_2d | |
| }, | |
| is_training=is_training, | |
| safe_key=safe_subkey, scale_rate=batch["scale_rate"]) | |
| return (extra_evoformer_output, safe_key) | |
| if gc.use_remat: | |
| extra_msa_stack_fn = hk.remat(extra_msa_stack_fn) | |
| extra_msa_stack = layer_stack.layer_stack( | |
| c.extra_msa_stack_num_block)( | |
| extra_msa_stack_fn) | |
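| # layer_stack creates `extra_msa_stack_num_block` sets of block parameters | |
| # (stacked along a leading axis) and applies the block that many times via a | |
| # scan, avoiding Python-level unrolling of the loop. | |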
| extra_msa_output, safe_key = extra_msa_stack( | |
| (extra_msa_stack_input, safe_key)) | |
| pair_activations = extra_msa_output['pair'] | |
| evoformer_input = { | |
| 'msa': msa_activations, | |
| 'pair': pair_activations, | |
| } | |
| evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d} | |
| #################################################################### | |
| #################################################################### | |
| # Append num_templ rows to msa_activations with template embeddings. | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8 | |
| if c.template.enabled and c.template.embed_torsion_angles: | |
| if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): | |
| num_templ, num_res = batch['template_aatype'].shape | |
| # Embed the templates aatypes. | |
| aatype = batch['template_aatype'] | |
| aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) | |
| else: | |
| num_templ, num_res, _ = batch['template_aatype'].shape | |
| aatype = batch['template_aatype'].argmax(-1) | |
| aatype_one_hot = batch['template_aatype'] | |
| # Embed the templates aatype, torsion angles and masks. | |
| # Shape (templates, residues, msa_channels) | |
| ret = all_atom.atom37_to_torsion_angles( | |
| aatype=aatype, | |
| all_atom_pos=batch['template_all_atom_positions'], | |
| all_atom_mask=batch['template_all_atom_masks'], | |
| # Ensure consistent behaviour during testing: | |
| placeholder_for_undefined=not gc.zero_init) | |
| template_features = jnp.concatenate([ | |
| aatype_one_hot, | |
| jnp.reshape(ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]), | |
| jnp.reshape(ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]), | |
| ret['torsion_angles_mask']], axis=-1) | |
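| # Per-residue template features: 22 (aatype one-hot) + 14 + 14 (sin/cos of | |
| # the 7 torsion angles and their alternatives) + 7 (torsion masks) = 57 | |
| # channels. | |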
| template_activations = common_modules.Linear( | |
| c.msa_channel, | |
| initializer='relu', | |
| name='template_single_embedding')(template_features) | |
| template_activations = jax.nn.relu(template_activations) | |
| template_activations = common_modules.Linear( | |
| c.msa_channel, | |
| initializer='relu', | |
| name='template_projection')(template_activations) | |
| # Concatenate the templates to the msa. | |
| evoformer_input['msa'] = jnp.concatenate([evoformer_input['msa'], template_activations], axis=0) | |
| # Concatenate templates masks to the msa masks. | |
| # Use mask from the psi angle, as it only depends on the backbone atoms | |
| # from a single residue. | |
| torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2] | |
| torsion_angle_mask = torsion_angle_mask.astype(evoformer_masks['msa'].dtype) | |
| evoformer_masks['msa'] = jnp.concatenate([evoformer_masks['msa'], torsion_angle_mask], axis=0) | |
| #################################################################### | |
| #################################################################### | |
| # Main trunk of the network | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18 | |
| evoformer_iteration = EvoformerIteration( | |
| c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') | |
| def evoformer_fn(x): | |
| act, safe_key = x | |
| safe_key, safe_subkey = safe_key.split() | |
| evoformer_output = evoformer_iteration( | |
| activations=act, | |
| masks=evoformer_masks, | |
| is_training=is_training, | |
| safe_key=safe_subkey, scale_rate=batch["scale_rate"]) | |
| return (evoformer_output, safe_key) | |
| if gc.use_remat: | |
| evoformer_fn = hk.remat(evoformer_fn) | |
| evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(evoformer_fn) | |
| evoformer_output, safe_key = evoformer_stack((evoformer_input, safe_key)) | |
| msa_activations = evoformer_output['msa'] | |
| pair_activations = evoformer_output['pair'] | |
| single_activations = common_modules.Linear( | |
| c.seq_channel, name='single_activations')(msa_activations[0]) | |
| num_sequences = batch['msa_feat'].shape[0] | |
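| # Output shapes (N_seq = clustered MSA sequences, N_res = residues): | |
| #   'single':        [N_res, c.seq_channel] | |
| #   'pair':          [N_res, N_res, c.pair_channel] | |
| #   'msa':           [N_seq, N_res, c.msa_channel]  (template rows removed) | |
| #   'msa_first_row': [N_res, c.msa_channel] | |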
| output = { | |
| 'single': single_activations, | |
| 'pair': pair_activations, | |
| # Crop away template rows such that they are not used in MaskedMsaHead. | |
| 'msa': msa_activations[:num_sequences, :, :], | |
| 'msa_first_row': msa_activations[0], | |
| } | |
| return output | |
| #################################################################### | |
| #################################################################### | |
| class SingleTemplateEmbedding(hk.Module): | |
| """Embeds a single template. | |
| Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11 | |
| """ | |
| def __init__(self, config, global_config, name='single_template_embedding'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, query_embedding, batch, mask_2d, is_training, scale_rate=1.0): | |
| """Build the single template embedding. | |
| Arguments: | |
| query_embedding: Query pair representation, shape [N_res, N_res, c_z]. | |
| batch: A batch of template features (note the template dimension has been | |
| stripped out as this module only runs over a single template). | |
| mask_2d: Padding mask (Note: this doesn't care if a template exists, | |
| unlike the template_pseudo_beta_mask). | |
| is_training: Whether the module is in training mode. | |
| scale_rate: Multiplier applied to the dropout rate in the template pair stack. | |
| Returns: | |
| A template embedding [N_res, N_res, c_z]. | |
| """ | |
| assert mask_2d.dtype == query_embedding.dtype | |
| dtype = query_embedding.dtype | |
| num_res = batch['template_aatype'].shape[0] | |
| num_channels = (self.config.template_pair_stack | |
| .triangle_attention_ending_node.value_dim) | |
| template_mask = batch['template_pseudo_beta_mask'] | |
| template_mask_2d = template_mask[:, None] * template_mask[None, :] | |
| template_mask_2d = template_mask_2d.astype(dtype) | |
| if "template_dgram" in batch: | |
| template_dgram = batch["template_dgram"] | |
| else: | |
| if self.config.backprop_dgram: | |
| template_dgram = dgram_from_positions_soft(batch['template_pseudo_beta'], | |
| temp=self.config.backprop_dgram_temp, | |
| **self.config.dgram_features) | |
| else: | |
| template_dgram = dgram_from_positions(batch['template_pseudo_beta'], | |
| **self.config.dgram_features) | |
| template_dgram = template_dgram.astype(dtype) | |
| to_concat = [template_dgram, template_mask_2d[:, :, None]] | |
| if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): | |
| aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype) | |
| else: | |
| aatype = batch['template_aatype'] | |
| to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1])) | |
| to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1])) | |
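| # Tile the per-residue aatype one-hot along both residue axes so that the | |
| # pair feature at (i, j) carries the amino-acid identities of residue j | |
| # (first tile) and residue i (second tile). | |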
| # Backbone affine mask: whether the residue has C, CA, N | |
| # (the template mask defined above only considers pseudo CB). | |
| n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')] | |
| template_mask = ( | |
| batch['template_all_atom_masks'][..., n] * | |
| batch['template_all_atom_masks'][..., ca] * | |
| batch['template_all_atom_masks'][..., c]) | |
| template_mask_2d = template_mask[:, None] * template_mask[None, :] | |
| # compute unit_vector (not used by default) | |
| if self.config.use_template_unit_vector: | |
| rot, trans = quat_affine.make_transform_from_reference( | |
| n_xyz=batch['template_all_atom_positions'][:, n], | |
| ca_xyz=batch['template_all_atom_positions'][:, ca], | |
| c_xyz=batch['template_all_atom_positions'][:, c]) | |
| affines = quat_affine.QuatAffine( | |
| quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True), | |
| translation=trans, | |
| rotation=rot, | |
| unstack_inputs=True) | |
| points = [jnp.expand_dims(x, axis=-2) for x in affines.translation] | |
| affine_vec = affines.invert_point(points, extra_dims=1) | |
| inv_distance_scalar = jax.lax.rsqrt(1e-6 + sum([jnp.square(x) for x in affine_vec])) | |
| inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) | |
| unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] | |
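| # Each unit vector is the CA-CA displacement of a residue pair expressed in | |
| # the local backbone frame of one of the residues, normalized to unit length | |
| # and zeroed where the backbone mask is missing. | |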
| else: | |
| unit_vector = [jnp.zeros((num_res, num_res, 1))] * 3 | |
| unit_vector = [x.astype(dtype) for x in unit_vector] | |
| to_concat.extend(unit_vector) | |
| template_mask_2d = template_mask_2d.astype(dtype) | |
| to_concat.append(template_mask_2d[..., None]) | |
| act = jnp.concatenate(to_concat, axis=-1) | |
| # Mask out non-template regions so we don't get arbitrary values in the | |
| # distogram for these regions. | |
| act *= template_mask_2d[..., None] | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9 | |
| act = common_modules.Linear( | |
| num_channels, | |
| initializer='relu', | |
| name='embedding2d')(act) | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11 | |
| act = TemplatePairStack( | |
| self.config.template_pair_stack, self.global_config)(act, mask_2d, is_training, scale_rate=scale_rate) | |
| act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, name='output_layer_norm')(act) | |
| return act | |
| class TemplateEmbedding(hk.Module): | |
| """Embeds a set of templates. | |
| Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 | |
| Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" | |
| """ | |
| def __init__(self, config, global_config, name='template_embedding'): | |
| super().__init__(name=name) | |
| self.config = config | |
| self.global_config = global_config | |
| def __call__(self, query_embedding, template_batch, mask_2d, is_training, scale_rate=1.0): | |
| """Build TemplateEmbedding module. | |
| Arguments: | |
| query_embedding: Query pair representation, shape [N_res, N_res, c_z]. | |
| template_batch: A batch of template features. | |
| mask_2d: Padding mask (Note: this doesn't care if a template exists, | |
| unlike the template_pseudo_beta_mask). | |
| is_training: Whether the module is in training mode. | |
| scale_rate: Multiplier passed to each SingleTemplateEmbedding to scale its dropout rate. | |
| Returns: | |
| A template embedding [N_res, N_res, c_z]. | |
| """ | |
| num_templates = template_batch['template_mask'].shape[0] | |
| num_channels = (self.config.template_pair_stack | |
| .triangle_attention_ending_node.value_dim) | |
| num_res = query_embedding.shape[0] | |
| dtype = query_embedding.dtype | |
| template_mask = template_batch['template_mask'] | |
| template_mask = template_mask.astype(dtype) | |
| query_num_channels = query_embedding.shape[-1] | |
| # Make sure the weights are shared across templates by constructing the | |
| # embedder here. | |
| # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 | |
| template_embedder = SingleTemplateEmbedding(self.config, self.global_config) | |
| def map_fn(batch): | |
| return template_embedder(query_embedding, batch, mask_2d, is_training, scale_rate=scale_rate) | |
| template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(template_batch) | |
| # Cross attend from the query to the templates along the residue | |
| # dimension by flattening everything else into the batch dimension. | |
| # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" | |
| flat_query = jnp.reshape(query_embedding, [num_res * num_res, 1, query_num_channels]) | |
| flat_templates = jnp.reshape( | |
| jnp.transpose(template_pair_representation, [1, 2, 0, 3]), | |
| [num_res * num_res, num_templates, num_channels]) | |
| bias = (1e9 * (template_mask[None, None, None, :] - 1.)) | |
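| # Templates that are absent (template_mask == 0) receive a large negative | |
| # attention bias and are effectively ignored by the softmax. | |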
| template_pointwise_attention_module = Attention( | |
| self.config.attention, self.global_config, query_num_channels) | |
| nonbatched_args = [bias] | |
| batched_args = [flat_query, flat_templates] | |
| embedding = mapping.inference_subbatch( | |
| template_pointwise_attention_module, | |
| self.config.subbatch_size, | |
| batched_args=batched_args, | |
| nonbatched_args=nonbatched_args, | |
| low_memory=not is_training) | |
| embedding = jnp.reshape(embedding, [num_res, num_res, query_num_channels]) | |
| # No gradients if no templates. | |
| embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype) | |
| return embedding | |
| #################################################################### | |