# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modules and utilities for the structure module."""

import functools
from typing import Dict

from alphafold.common import residue_constants
from alphafold.model import all_atom
from alphafold.model import common_modules
from alphafold.model import prng
from alphafold.model import quat_affine
from alphafold.model import r3
from alphafold.model import utils
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np


def squared_difference(x, y):
  return jnp.square(x - y)


class InvariantPointAttention(hk.Module):
  """Invariant Point Attention module.

  The high-level idea is that this attention module works over a set of points
  and associated orientations in 3D space (e.g. protein residues).

  Each residue outputs a set of queries and keys as points in its local
  reference frame. The attention weights are then computed from the Euclidean
  distances between the queries and keys in the global frame.

  Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
  """

  def __init__(self,
               config,
               global_config,
               dist_epsilon=1e-8,
               name='invariant_point_attention'):
    """Initialize.

    Args:
      config: Structure Module Config
      global_config: Global Config of Model.
      dist_epsilon: Small value to avoid NaN in distance calculation.
      name: Haiku Module name.
    """
    super().__init__(name=name)

    self._dist_epsilon = dist_epsilon
    self._zero_initialize_last = global_config.zero_init

    self.config = config
    self.global_config = global_config

  def __call__(self, inputs_1d, inputs_2d, mask, affine):
    """Compute geometry-aware attention.

    Given a set of query residues (defined by affines and associated scalar
    features), this function computes geometry-aware attention between the
    query residues and target residues.

    The residues produce points in their local reference frame, which
    are converted into the global frame in order to compute attention via
    Euclidean distance.

    Equivalently, the target residues produce points in their local frame to be
    used as attention values, which are converted into the query residues'
    local frames.

    Args:
      inputs_1d: (N, C) 1D input embedding that is the basis for the
        scalar queries.
      inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
      mask: (N, 1) mask to indicate which elements of inputs_1d participate
        in the attention.
      affine: QuatAffine object describing the position and orientation of
        every element in inputs_1d.

    Returns:
      Transformation of the input embedding.
    """
    num_residues, _ = inputs_1d.shape

    # Improve readability by removing a large number of 'self's.
    num_head = self.config.num_head
    num_scalar_qk = self.config.num_scalar_qk
    num_point_qk = self.config.num_point_qk
    num_scalar_v = self.config.num_scalar_v
    num_point_v = self.config.num_point_v
    num_output = self.config.num_channel

    assert num_scalar_qk > 0
    assert num_point_qk > 0
    assert num_point_v > 0
    # Construct scalar queries of shape:
    # [num_query_residues, num_head, num_points]
    q_scalar = common_modules.Linear(
        num_head * num_scalar_qk, name='q_scalar')(
            inputs_1d)
    q_scalar = jnp.reshape(
        q_scalar, [num_residues, num_head, num_scalar_qk])

    # Construct scalar keys/values of shape:
    # [num_target_residues, num_head, num_points]
    kv_scalar = common_modules.Linear(
        num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')(
            inputs_1d)
    kv_scalar = jnp.reshape(kv_scalar,
                            [num_residues, num_head,
                             num_scalar_v + num_scalar_qk])
    k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)

    # Construct query points of shape:
    # [num_residues, num_head, num_point_qk]

    # First construct query points in local frame.
    q_point_local = common_modules.Linear(
        num_head * 3 * num_point_qk, name='q_point_local')(
            inputs_1d)
    q_point_local = jnp.split(q_point_local, 3, axis=-1)
    # Project query points into global frame.
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    # Reshape query points for later use.
    q_point = [
        jnp.reshape(x, [num_residues, num_head, num_point_qk])
        for x in q_point_global]

    # Construct key and value points.
    # Key points have shape [num_residues, num_head, num_point_qk]
    # Value points have shape [num_residues, num_head, num_point_v]

    # Construct key and value points in local frame.
    kv_point_local = common_modules.Linear(
        num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')(
            inputs_1d)
    kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
    # Project key and value points into global frame.
    kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
    kv_point_global = [
        jnp.reshape(x, [num_residues,
                        num_head, (num_point_qk + num_point_v)])
        for x in kv_point_global]
    # Split key and value points.
    k_point, v_point = list(
        zip(*[
            jnp.split(x, [num_point_qk,], axis=-1)
            for x in kv_point_global
        ]))
    # We assume that all queries and keys come iid from N(0, 1) distribution
    # and compute the variances of the attention logits.
    # Each scalar pair (q, k) contributes Var q*k = 1
    scalar_variance = max(num_scalar_qk, 1) * 1.
    # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
    point_variance = max(num_point_qk, 1) * 9. / 2
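    # (Derivation of the 9 / 2 above, added for clarity: with q, k ~ N(0, I_3)
    # independent, Var[0.5 * ||q||^2] = 0.25 * Var[chi^2_3] = 1.5 and
    # Var[<q, k>] = 3; the two terms are uncorrelated, so the variances add to
    # 4.5 = 9 / 2.)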

    # Allocate equal variance to scalar, point and attention 2d parts so that
    # the sum is 1.
    num_logit_terms = 3

    scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
    point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
    attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

    # Trainable per-head weights for points.
    trainable_point_weights = jax.nn.softplus(hk.get_parameter(
        'trainable_point_weights', shape=[num_head],
        # softplus^{-1} (1)
        init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
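    # (The constant init is softplus^{-1}(1): softplus(log(e - 1)) = 1, so each
    # head's point weight starts at 1 and stays positive during training.)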
    point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)

    v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]
    q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
    k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
    dist2 = [
        squared_difference(qx[:, :, None, :], kx[:, None, :, :])
        for qx, kx in zip(q_point, k_point)
    ]
    dist2 = sum(dist2)

    attn_qk_point = -0.5 * jnp.sum(
        point_weights[:, None, None, :] * dist2, axis=-1)

    v = jnp.swapaxes(v_scalar, -2, -3)
    q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
    k = jnp.swapaxes(k_scalar, -2, -3)
    attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
    attn_logits = attn_qk_scalar + attn_qk_point

    attention_2d = common_modules.Linear(
        num_head, name='attention_2d')(
            inputs_2d)

    attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
    attention_2d = attention_2d_weights * attention_2d
    attn_logits += attention_2d

    mask_2d = mask * jnp.swapaxes(mask, -1, -2)
    attn_logits -= 1e5 * (1. - mask_2d)

    # [num_head, num_query_residues, num_target_residues]
    attn = jax.nn.softmax(attn_logits)

    # [num_head, num_query_residues, num_head * num_scalar_v]
    result_scalar = jnp.matmul(attn, v)

    # For the point result, implement matmul manually so that it will be a
    # float32 on TPU. This is equivalent to
    # result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx)
    #                        for vx in v_point]
    # but on the TPU, doing the multiply and reduce_sum ensures the
    # computation happens in float32 instead of bfloat16.
    result_point_global = [jnp.sum(
        attn[:, :, :, None] * vx[:, None, :, :],
        axis=-2) for vx in v_point]

    # [num_query_residues, num_head, num_head * num_(scalar|point)_v]
    result_scalar = jnp.swapaxes(result_scalar, -2, -3)
    result_point_global = [
        jnp.swapaxes(x, -2, -3)
        for x in result_point_global]

    # Features used in the linear output projection. Should have the size
    # [num_query_residues, ?]
    output_features = []

    result_scalar = jnp.reshape(
        result_scalar, [num_residues, num_head * num_scalar_v])
    output_features.append(result_scalar)

    result_point_global = [
        jnp.reshape(r, [num_residues, num_head * num_point_v])
        for r in result_point_global]
    result_point_local = affine.invert_point(result_point_global, extra_dims=1)
    output_features.extend(result_point_local)

    output_features.append(jnp.sqrt(self._dist_epsilon +
                                    jnp.square(result_point_local[0]) +
                                    jnp.square(result_point_local[1]) +
                                    jnp.square(result_point_local[2])))

    # Dimensions: h = heads, i and j = residues,
    # c = inputs_2d channels
    # Contraction happens over the second residue dimension, similarly to how
    # the usual attention is performed.
    result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
    num_out = num_head * result_attention_over_2d.shape[-1]
    output_features.append(
        jnp.reshape(result_attention_over_2d,
                    [num_residues, num_out]))

    final_init = 'zeros' if self._zero_initialize_last else 'linear'

    final_act = jnp.concatenate(output_features, axis=-1)

    return common_modules.Linear(
        num_output,
        initializer=final_init,
        name='output_projection')(final_act)
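
# A minimal usage sketch for InvariantPointAttention (illustrative only, not
# part of the original module). The config fields mirror the ones read in
# __call__ above; the sizes and shapes are arbitrary toy values.
#
#   def _ipa_fn(inputs_1d, inputs_2d, mask, affine_tensor):
#     cfg = ml_collections.ConfigDict(dict(
#         num_head=4, num_scalar_qk=16, num_point_qk=4,
#         num_scalar_v=16, num_point_v=8, num_channel=128))
#     gcfg = ml_collections.ConfigDict(dict(zero_init=True))
#     affine = quat_affine.QuatAffine.from_tensor(affine_tensor)
#     return InvariantPointAttention(cfg, gcfg)(
#         inputs_1d, inputs_2d, mask, affine)
#
#   ipa = hk.transform(_ipa_fn)
#   n = 8  # toy number of residues
#   params = ipa.init(
#       jax.random.PRNGKey(0),
#       jnp.zeros([n, 384]),     # inputs_1d: (N, C)
#       jnp.zeros([n, n, 128]),  # inputs_2d: (N, M, C')
#       jnp.ones([n, 1]),        # mask: (N, 1)
#       jnp.tile(jnp.array([1., 0., 0., 0., 0., 0., 0.]), [n, 1]))  # (N, 7)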


class FoldIteration(hk.Module):
  """A single iteration of the main structure module loop.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21

  First, each residue attends to all residues using InvariantPointAttention.
  Then, we apply transition layers to update the hidden representations.
  Finally, we use the hidden representations to produce an update to the
  affine of each residue.
  """

  def __init__(self, config, global_config,
               name='fold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               activations,
               sequence_mask,
               update_affine,
               is_training,
               initial_act,
               safe_key=None,
               static_feat_2d=None,
               aatype=None,
               scale_rate=1.0):
    c = self.config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    def safe_dropout_fn(tensor, safe_key):
      return prng.safe_dropout(
          tensor=tensor,
          safe_key=safe_key,
          rate=c.dropout * scale_rate,
          is_deterministic=self.global_config.deterministic,
          is_training=is_training)

    affine = quat_affine.QuatAffine.from_tensor(activations['affine'])

    act = activations['act']
    attention_module = InvariantPointAttention(self.config, self.global_config)
    # Attention
    attn = attention_module(
        inputs_1d=act,
        inputs_2d=static_feat_2d,
        mask=sequence_mask,
        affine=affine)
    act += attn
    safe_key, *sub_keys = safe_key.split(3)
    sub_keys = iter(sub_keys)
    act = safe_dropout_fn(act, next(sub_keys))
    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='attention_layer_norm')(
            act)

    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Transition
    input_act = act
    for i in range(c.num_layer_in_transition):
      init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
      act = common_modules.Linear(
          c.num_channel,
          initializer=init,
          name='transition')(
              act)
      if i < c.num_layer_in_transition - 1:
        act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='transition_layer_norm')(act)

    if update_affine:
      # This block corresponds to
      # Jumper et al. (2021) Alg. 23 "Backbone update"
      affine_update_size = 6

      # Affine update
      affine_update = common_modules.Linear(
          affine_update_size,
          initializer=final_init,
          name='affine_update')(
              act)

      affine = affine.pre_compose(affine_update)

    sc = MultiRigidSidechain(c.sidechain, self.global_config)(
        affine.scale_translation(c.position_scale), [act, initial_act], aatype)

    outputs = {'affine': affine.to_tensor(), 'sc': sc}

    # affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)

    new_activations = {
        'act': act,
        'affine': affine.to_tensor()
    }
    return new_activations, outputs


def generate_affines(representations, batch, config, global_config,
                     is_training, safe_key):
  """Generate predicted affines for a single chain.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

  This is the main part of the structure module - it iteratively applies
  folding to produce a set of predicted residue positions.

  Args:
    representations: Representations dictionary.
    batch: Batch dictionary.
    config: Config for the structure module.
    global_config: Global config.
    is_training: Whether the model is being trained.
    safe_key: A prng.SafeKey object that wraps a PRNG key.

  Returns:
    A dictionary containing residue affines and sidechain positions.
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]

  act = hk.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
      name='single_layer_norm')(
          representations['single'])

  initial_act = act
  act = common_modules.Linear(
      c.num_channel, name='initial_projection')(
          act)

  affine = generate_new_affine(sequence_mask)

  fold_iteration = FoldIteration(
      c, global_config, name='fold_iteration')

  assert len(batch['seq_mask'].shape) == 1

  activations = {'act': act,
                 'affine': affine.to_tensor(),
                 }

  act_2d = hk.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
      name='pair_layer_norm')(
          representations['pair'])

  def fold_iter(act, key):
    act, out = fold_iteration(
        act,
        initial_act=initial_act,
        static_feat_2d=act_2d,
        safe_key=prng.SafeKey(key),
        sequence_mask=sequence_mask,
        update_affine=True,
        is_training=is_training,
        aatype=batch['aatype'],
        scale_rate=batch['scale_rate'])
    return act, out
  keys = jax.random.split(safe_key.get(), c.num_layer)
  activations, output = hk.scan(fold_iter, activations, keys)
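  # hk.scan applies fold_iter c.num_layer times with a single, shared set of
  # FoldIteration parameters (the iterations share weights) and stacks each
  # iteration's outputs along a new leading axis, so output['affine'] has
  # shape (num_layer, N, 7).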

  # Include the activations in the output dict for use by the LDDT-Head.
  output['act'] = activations['act']

  return output


class dummy(hk.Module):
  """A no-op module with the same call signature as StructureModule."""

  def __init__(self, config, global_config, compute_loss=True):
    super().__init__(name='dummy')

  def __call__(self, representations, batch, is_training, safe_key=None):
    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())
    return {}


class StructureModule(hk.Module):
  """StructureModule as a network head.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
  """

  def __init__(self, config, global_config, compute_loss=True,
               name='structure_module'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.compute_loss = compute_loss

  def __call__(self, representations, batch, is_training,
               safe_key=None):
    c = self.config
    ret = {}

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = generate_affines(
        representations=representations,
        batch=batch,
        config=self.config,
        global_config=self.global_config,
        is_training=is_training,
        safe_key=safe_key)

    ret['representations'] = {'structure_module': output['act']}

    ret['traj'] = output['affine'] * jnp.array([1.] * 4 +
                                               [c.position_scale] * 3)

    ret['sidechains'] = output['sc']

    atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1]
    ret['final_atom14_positions'] = atom14_pred_positions  # (N, 14, 3)
    ret['final_atom14_mask'] = batch['atom14_atom_exists']  # (N, 14)

    atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions,
                                                      batch)
    atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None]
    ret['final_atom_positions'] = atom37_pred_positions  # (N, 37, 3)
    ret['final_atom_mask'] = batch['atom37_atom_exists']  # (N, 37)
    ret['final_affines'] = ret['traj'][-1]

    return ret

  def loss(self, value, batch):
    ret = {'loss': 0.}

    ret['metrics'] = {}
    # If requested, compute in-graph metrics.
    if self.config.compute_in_graph_metrics:
      atom14_pred_positions = value['final_atom14_positions']
      # Compute renaming and violations.
      value.update(compute_renamed_ground_truth(batch, atom14_pred_positions))
      value['violations'] = find_structural_violations(
          batch, atom14_pred_positions, self.config)

      # Several violation metrics:
      violation_metrics = compute_violation_metrics(
          batch=batch,
          atom14_pred_positions=atom14_pred_positions,
          violations=value['violations'])
      ret['metrics'].update(violation_metrics)

    backbone_loss(ret, batch, value, self.config)

    if 'renamed_atom14_gt_positions' not in value:
      value.update(compute_renamed_ground_truth(
          batch, value['final_atom14_positions']))
    sc_loss = sidechain_loss(batch, value, self.config)

    ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
                   self.config.sidechain.weight_frac * sc_loss['loss'])
    ret['sidechain_fape'] = sc_loss['fape']

    supervised_chi_loss(ret, batch, value, self.config)

    if self.config.structural_violation_loss_weight:
      if 'violations' not in value:
        value['violations'] = find_structural_violations(
            batch, value['final_atom14_positions'], self.config)
      structural_violation_loss(ret, batch, value, self.config)

    return ret


def compute_renamed_ground_truth(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
  """Find optimal renaming of ground truth based on the predicted positions.

  Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"

  This renamed ground truth is then used for all losses,
  such that each loss moves the atoms in the same direction.
  Shape (N).

  Args:
    batch: Dictionary containing:
      * atom14_gt_positions: Ground truth positions.
      * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
      * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
          renaming swaps.
      * atom14_gt_exists: Mask for which atoms exist in ground truth.
      * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
          after renaming.
      * atom14_atom_exists: Mask for whether each atom is part of the given
          amino acid type.
    atom14_pred_positions: Array of atom positions in global frame with shape
      (N, 14, 3).

  Returns:
    Dictionary containing:
      alt_naming_is_better: Array with 1.0 where alternative swap is better.
      renamed_atom14_gt_positions: Array of optimal ground truth positions
        after renaming swaps are performed.
      renamed_atom14_gt_exists: Mask after renaming swap is performed.
  """
  alt_naming_is_better = all_atom.find_optimal_renaming(
      atom14_gt_positions=batch['atom14_gt_positions'],
      atom14_alt_gt_positions=batch['atom14_alt_gt_positions'],
      atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'],
      atom14_gt_exists=batch['atom14_gt_exists'],
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'])
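
  # alt_naming_is_better is a per-residue 0/1 indicator, so the convex
  # combinations below simply select, for each residue, either the original or
  # the alternative (renamed) ground truth.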
  renamed_atom14_gt_positions = (
      (1. - alt_naming_is_better[:, None, None])
      * batch['atom14_gt_positions']
      + alt_naming_is_better[:, None, None]
      * batch['atom14_alt_gt_positions'])

  renamed_atom14_gt_mask = (
      (1. - alt_naming_is_better[:, None]) * batch['atom14_gt_exists']
      + alt_naming_is_better[:, None] * batch['atom14_alt_gt_exists'])

  return {
      'alt_naming_is_better': alt_naming_is_better,  # (N)
      'renamed_atom14_gt_positions': renamed_atom14_gt_positions,  # (N, 14, 3)
      'renamed_atom14_gt_exists': renamed_atom14_gt_mask,  # (N, 14)
  }


def backbone_loss(ret, batch, value, config):
  """Backbone FAPE Loss.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17

  Args:
    ret: Dictionary to write outputs into, needs to contain 'loss'.
    batch: Batch, needs to contain 'backbone_affine_tensor' and
      'backbone_affine_mask', or 'all_atom_positions' and 'all_atom_mask'
      from which ground-truth backbone frames are built.
    value: Dictionary containing structure module output, needs to contain
      'traj', a trajectory of rigids.
    config: Configuration of loss, should contain 'fape.clamp_distance' and
      'fape.loss_unit_distance'.
  """
  affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj'])
  rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory)

  if 'backbone_affine_tensor' in batch:
    gt_affine = quat_affine.QuatAffine.from_tensor(
        batch['backbone_affine_tensor'])
    backbone_mask = batch['backbone_affine_mask']
  else:
    # Build ground-truth backbone frames from the N, CA and C atom positions
    # when precomputed affines are not provided.
    n_xyz = batch['all_atom_positions'][..., 0, :]
    ca_xyz = batch['all_atom_positions'][..., 1, :]
    c_xyz = batch['all_atom_positions'][..., 2, :]
    rot, trans = quat_affine.make_transform_from_reference(n_xyz, ca_xyz, c_xyz)
    gt_affine = quat_affine.QuatAffine(quaternion=None,
                                       translation=trans,
                                       rotation=rot,
                                       unstack_inputs=True)
    backbone_mask = batch['all_atom_mask'][..., 0]

  gt_rigid = r3.rigids_from_quataffine(gt_affine)

  fape_loss_fn = functools.partial(
      all_atom.frame_aligned_point_error,
      l1_clamp_distance=config.fape.clamp_distance,
      length_scale=config.fape.loss_unit_distance)

  fape_loss_fn = jax.vmap(fape_loss_fn, (0, None, None, 0, None, None))
  fape_loss = fape_loss_fn(rigid_trajectory, gt_rigid, backbone_mask,
                           rigid_trajectory.trans, gt_rigid.trans,
                           backbone_mask)

  if 'use_clamped_fape' in batch:
    # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details"
    use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32)
    unclamped_fape_loss_fn = functools.partial(
        all_atom.frame_aligned_point_error,
        l1_clamp_distance=None,
        length_scale=config.fape.loss_unit_distance)
    unclamped_fape_loss_fn = jax.vmap(unclamped_fape_loss_fn,
                                      (0, None, None, 0, None, None))
    fape_loss_unclamped = unclamped_fape_loss_fn(rigid_trajectory, gt_rigid,
                                                 backbone_mask,
                                                 rigid_trajectory.trans,
                                                 gt_rigid.trans,
                                                 backbone_mask)

    fape_loss = (fape_loss * use_clamped_fape +
                 fape_loss_unclamped * (1 - use_clamped_fape))
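
  # fape_loss has one entry per iteration of the structure module (the leading
  # trajectory axis of value['traj']); 'fape' reports the final iteration while
  # the loss term below averages over all iterations (intermediate
  # supervision of the trajectory).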
  ret['fape'] = fape_loss[-1]
  ret['loss'] += jnp.mean(fape_loss)


def sidechain_loss(batch, value, config):
  """All Atom FAPE Loss using renamed rigids."""
  # Rename Frames
  # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
  alt_naming_is_better = value['alt_naming_is_better']
  renamed_gt_frames = (
      (1. - alt_naming_is_better[:, None, None])
      * batch['rigidgroups_gt_frames']
      + alt_naming_is_better[:, None, None]
      * batch['rigidgroups_alt_gt_frames'])

  flat_gt_frames = r3.rigids_from_tensor_flat12(
      jnp.reshape(renamed_gt_frames, [-1, 12]))
  flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])

  flat_gt_positions = r3.vecs_from_tensor(
      jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
  flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])

  # Compute frame_aligned_point_error score for the final layer.
  pred_frames = value['sidechains']['frames']
  pred_positions = value['sidechains']['atom_pos']

  def _slice_last_layer_and_flatten(x):
    return jnp.reshape(x[-1], [-1])
  flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames)
  flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten,
                                     pred_positions)
  # FAPE Loss on sidechains
  fape = all_atom.frame_aligned_point_error(
      pred_frames=flat_pred_frames,
      target_frames=flat_gt_frames,
      frames_mask=flat_frames_mask,
      pred_positions=flat_pred_positions,
      target_positions=flat_gt_positions,
      positions_mask=flat_positions_mask,
      l1_clamp_distance=config.sidechain.atom_clamp_distance,
      length_scale=config.sidechain.length_scale)

  return {
      'fape': fape,
      'loss': fape}


def structural_violation_loss(ret, batch, value, config):
  """Computes loss for structural violations."""
  assert config.sidechain.weight_frac

  # Put all violation losses together into one large loss.
  violations = value['violations']
  num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32)
  ret['loss'] += (config.structural_violation_loss_weight * (
      violations['between_residues']['bonds_c_n_loss_mean'] +
      violations['between_residues']['angles_ca_c_n_loss_mean'] +
      violations['between_residues']['angles_c_n_ca_loss_mean'] +
      jnp.sum(
          violations['between_residues']['clashes_per_atom_loss_sum'] +
          violations['within_residues']['per_atom_loss_sum']) /
      (1e-6 + num_atoms)))


def find_structural_violations(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    config: ml_collections.ConfigDict
    ):
  """Computes several checks for structural violations."""

  # Compute between-residue backbone violations of bonds and angles.
  connection_violations = all_atom.between_residue_bond_loss(
      pred_atom_positions=atom14_pred_positions,
      pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
      residue_index=batch['residue_index'].astype(jnp.float32),
      aatype=batch['aatype'],
      tolerance_factor_soft=config.violation_tolerance_factor,
      tolerance_factor_hard=config.violation_tolerance_factor)

  # Compute the Van der Waals radius for every atom
  # (the first letter of the atom name is the element type).
  # Shape: (N, 14).
  atomtype_radius = [
      residue_constants.van_der_waals_radius[name[0]]
      for name in residue_constants.atom_types
  ]
  atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
      atomtype_radius, batch['residx_atom14_to_atom37'])

  # Compute the between-residue clash loss.
  between_residue_clashes = all_atom.between_residue_clash_loss(
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'],
      atom14_atom_radius=atom14_atom_radius,
      residue_index=batch['residue_index'],
      overlap_tolerance_soft=config.clash_overlap_tolerance,
      overlap_tolerance_hard=config.clash_overlap_tolerance)

  # Compute all within-residue violations (clashes,
  # bond length and angle violations).
  restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
      overlap_tolerance=config.clash_overlap_tolerance,
      bond_length_tolerance_factor=config.violation_tolerance_factor)
  atom14_dists_lower_bound = utils.batched_gather(
      restype_atom14_bounds['lower_bound'], batch['aatype'])
  atom14_dists_upper_bound = utils.batched_gather(
      restype_atom14_bounds['upper_bound'], batch['aatype'])
  within_residue_violations = all_atom.within_residue_violations(
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'],
      atom14_dists_lower_bound=atom14_dists_lower_bound,
      atom14_dists_upper_bound=atom14_dists_upper_bound,
      tighten_bounds_for_loss=0.0)

  # Combine them into a single per-residue violation mask (used later for
  # LDDT).
  per_residue_violations_mask = jnp.max(jnp.stack([
      connection_violations['per_residue_violation_mask'],
      jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
      jnp.max(within_residue_violations['per_atom_violations'],
              axis=-1)]), axis=0)

  return {
      'between_residues': {
          'bonds_c_n_loss_mean':
              connection_violations['c_n_loss_mean'],  # ()
          'angles_ca_c_n_loss_mean':
              connection_violations['ca_c_n_loss_mean'],  # ()
          'angles_c_n_ca_loss_mean':
              connection_violations['c_n_ca_loss_mean'],  # ()
          'connections_per_residue_loss_sum':
              connection_violations['per_residue_loss_sum'],  # (N)
          'connections_per_residue_violation_mask':
              connection_violations['per_residue_violation_mask'],  # (N)
          'clashes_mean_loss':
              between_residue_clashes['mean_loss'],  # ()
          'clashes_per_atom_loss_sum':
              between_residue_clashes['per_atom_loss_sum'],  # (N, 14)
          'clashes_per_atom_clash_mask':
              between_residue_clashes['per_atom_clash_mask'],  # (N, 14)
      },
      'within_residues': {
          'per_atom_loss_sum':
              within_residue_violations['per_atom_loss_sum'],  # (N, 14)
          'per_atom_violations':
              within_residue_violations['per_atom_violations'],  # (N, 14)
      },
      'total_per_residue_violations_mask':
          per_residue_violations_mask,  # (N)
  }


def compute_violation_metrics(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    violations: Dict[str, jnp.ndarray],
    ) -> Dict[str, jnp.ndarray]:
  """Compute several metrics to assess the structural violations."""

  ret = {}
  extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations(
      pred_atom_positions=atom14_pred_positions,
      pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
      residue_index=batch['residue_index'].astype(jnp.float32))
  ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations
  ret['violations_between_residue_bond'] = utils.mask_mean(
      mask=batch['seq_mask'],
      value=violations['between_residues'][
          'connections_per_residue_violation_mask'])
  ret['violations_between_residue_clash'] = utils.mask_mean(
      mask=batch['seq_mask'],
      value=jnp.max(
          violations['between_residues']['clashes_per_atom_clash_mask'],
          axis=-1))
  ret['violations_within_residue'] = utils.mask_mean(
      mask=batch['seq_mask'],
      value=jnp.max(
          violations['within_residues']['per_atom_violations'], axis=-1))
  ret['violations_per_residue'] = utils.mask_mean(
      mask=batch['seq_mask'],
      value=violations['total_per_residue_violations_mask'])
  return ret


def supervised_chi_loss(ret, batch, value, config):
  """Computes loss for direct chi angle supervision.

  Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"

  Args:
    ret: Dictionary to write outputs into, needs to contain 'loss'.
    batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'.
    value: Dictionary containing structure module output, needs to contain
      value['sidechains']['angles_sin_cos'] for angles and
      value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized
      angles.
    config: Configuration of loss, should contain 'chi_weight' and
      'angle_norm_weight'; 'chi_weight' scales the torsion term and
      'angle_norm_weight' scales the angle norm term.
  """
  eps = 1e-6

  sequence_mask = batch['seq_mask']
  num_res = sequence_mask.shape[0]
  chi_mask = batch['chi_mask'].astype(jnp.float32)
  pred_angles = jnp.reshape(
      value['sidechains']['angles_sin_cos'], [-1, num_res, 7, 2])
  pred_angles = pred_angles[:, :, 3:]
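  # (In the AlphaFold torsion convention the 7 predicted angles per residue are
  # [pre_omega, phi, psi, chi1, chi2, chi3, chi4]; the slice above keeps only
  # the four chi angles that are supervised here.)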

  residue_type_one_hot = jax.nn.one_hot(
      batch['aatype'], residue_constants.restype_num + 1,
      dtype=jnp.float32)[None]
  chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot,
                               jnp.asarray(residue_constants.chi_pi_periodic))

  true_chi = batch['chi_angles'][None]
  sin_true_chi = jnp.sin(true_chi)
  cos_true_chi = jnp.cos(true_chi)
  sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1)

  # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
  shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
  sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

  sq_chi_error = jnp.sum(
      squared_difference(sin_cos_true_chi, pred_angles), -1)
  sq_chi_error_shifted = jnp.sum(
      squared_difference(sin_cos_true_chi_shifted, pred_angles), -1)
  sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)

  sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
  ret['chi_loss'] = sq_chi_loss
  ret['loss'] += config.chi_weight * sq_chi_loss

  unnormed_angles = jnp.reshape(
      value['sidechains']['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2])
  angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
  norm_error = jnp.abs(angle_norm - 1.)
  angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None],
                                    value=norm_error)

  ret['angle_norm_loss'] = angle_norm_loss
  ret['loss'] += config.angle_norm_weight * angle_norm_loss


def generate_new_affine(sequence_mask):
  num_residues, _ = sequence_mask.shape
  quaternion = jnp.tile(
      jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4]),
      [num_residues, 1])

  translation = jnp.zeros([num_residues, 3])
  return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True)


def l2_normalize(x, axis=-1, epsilon=1e-12):
  return x / jnp.sqrt(
      jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon))
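
# Illustrative example: l2_normalize(jnp.array([3., 4.])) returns approximately
# [0.6, 0.8]; the epsilon only matters for inputs that are close to zero.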


class MultiRigidSidechain(hk.Module):
  """Class to make side chain atoms."""

  def __init__(self, config, global_config, name='rigid_sidechain'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, affine, representations_list, aatype):
    """Predict side chains using multi-rigid representations.

    Args:
      affine: The affines for each residue (translations in angstroms).
      representations_list: A list of activations to predict side chains from.
      aatype: Amino acid types.

    Returns:
      Dict containing atom positions and frames (in angstroms).
    """
    act = [
        common_modules.Linear(  # pylint: disable=g-complex-comprehension
            self.config.num_channel,
            name='input_projection')(jax.nn.relu(x))
        for x in representations_list
    ]
    # Sum the activation list (equivalent to concat then Linear).
    act = sum(act)

    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Mapping with some residual blocks.
    for _ in range(self.config.num_residual_block):
      old_act = act
      act = common_modules.Linear(
          self.config.num_channel,
          initializer='relu',
          name='resblock1')(
              jax.nn.relu(act))
      act = common_modules.Linear(
          self.config.num_channel,
          initializer=final_init,
          name='resblock2')(
              jax.nn.relu(act))
      act += old_act

    # Map activations to torsion angles. Shape: (num_res, 14).
    num_res = act.shape[0]
    unnormalized_angles = common_modules.Linear(
        14, name='unnormalized_angles')(
            jax.nn.relu(act))
    unnormalized_angles = jnp.reshape(
        unnormalized_angles, [num_res, 7, 2])
    angles = l2_normalize(unnormalized_angles, axis=-1)

    outputs = {
        'angles_sin_cos': angles,  # jnp.ndarray (N, 7, 2)
        'unnormalized_angles_sin_cos':
            unnormalized_angles,  # jnp.ndarray (N, 7, 2)
    }

    # Map torsion angles to frames.
    backb_to_global = r3.rigids_from_quataffine(affine)

    # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"

    # r3.Rigids with shape (N, 8).
    all_frames_to_global = all_atom.torsion_angles_to_frames(
        aatype,
        backb_to_global,
        angles)

    # Use frames and literature positions to create the final atom coordinates.
    # r3.Vecs with shape (N, 14).
    pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos(
        aatype, all_frames_to_global)

    outputs.update({
        'atom_pos': pred_positions,  # r3.Vecs (N, 14)
        'frames': all_frames_to_global,  # r3.Rigids (N, 8)
    })
    return outputs