Spaces:
Running
on
T4
Running
on
T4
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Code to generate processed features.""" | |
| import copy | |
| from typing import List, Mapping, Tuple | |
| from alphafold.model.tf import input_pipeline | |
| from alphafold.model.tf import proteins_dataset | |
| import ml_collections | |
| import numpy as np | |
| import tensorflow.compat.v1 as tf | |
| FeatureDict = Mapping[str, np.ndarray] | |
def make_data_config(
    config: ml_collections.ConfigDict,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    """Builds the data config and feature-name list for the input pipeline.

    Args:
        config: Model config; only its `data` sub-config is read.
        num_res: Value written to `eval.crop_size` of the returned config.

    Returns:
        A `(data_config, feature_names)` pair. `data_config` is a deep copy of
        `config.data` with `eval.crop_size` set to `num_res`. `feature_names`
        is `common.unsupervised_features`, followed by
        `common.template_features` when `common.use_templates` is truthy.
    """
    # Deep copy so none of the edits below reach the caller's config.
    data_cfg = copy.deepcopy(config.data)
    common_cfg = data_cfg.common

    # The crop size is written inside `unlocked()`, as the copied config may be
    # locked against modification.
    with data_cfg.unlocked():
        data_cfg.eval.crop_size = num_res

    names = common_cfg.unsupervised_features
    if common_cfg.use_templates:
        # NOTE: if `names` is a list, `+=` extends it in place, so the copied
        # config's `unsupervised_features` grows as well. The caller's `config`
        # is unaffected because of the deep copy above.
        names += common_cfg.template_features

    return data_cfg, names
def tf_example_to_features(tf_example: tf.train.Example,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
    """Converts a `tf.train.Example` to a NumPy feature dictionary.

    The example is serialized, parsed and processed by the TF input pipeline
    in a fresh graph pinned to the CPU, then evaluated in a session.

    Args:
        tf_example: Example proto holding the raw features. It must contain a
            `seq_length` int64 feature; its first value is used as the crop
            size. The conversion below is done on a copy, so the caller's
            proto is not rewritten.
        config: Model config forwarded to `make_data_config`.
        random_seed: Graph-level TF random seed.

    Returns:
        The processed features as a dict of NumPy arrays, with object-dtype
        entries dropped.
    """
    num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
    cfg, feature_names = make_data_config(config, num_res=num_res)

    if 'deletion_matrix_int' in tf_example.features.feature:
        # Convert on a private copy: rewriting the caller's proto in place
        # would be a surprising side effect, and np_example_to_features also
        # copies its input before doing the same conversion.
        example = tf.train.Example()
        example.CopyFrom(tf_example)
        feature_map = example.features.feature

        # Replace the integer deletion matrix with a float feature stored
        # under 'deletion_matrix'.
        deletion_matrix_int = feature_map['deletion_matrix_int'].int64_list.value
        feat = tf.train.Feature(float_list=tf.train.FloatList(
            value=map(float, deletion_matrix_int)))
        feature_map['deletion_matrix'].CopyFrom(feat)
        del feature_map['deletion_matrix_int']
        tf_example = example

    tf_graph = tf.Graph()
    with tf_graph.as_default(), tf.device('/device:CPU:0'):
        tf.compat.v1.set_random_seed(random_seed)
        tensor_dict = proteins_dataset.create_tensor_dict(
            raw_data=tf_example.SerializeToString(),
            features=feature_names)
        processed_batch = input_pipeline.process_tensors_from_config(
            tensor_dict, cfg)

    # Finalizing makes the graph read-only before it is run.
    tf_graph.finalize()

    with tf.Session(graph=tf_graph) as sess:
        features = sess.run(processed_batch)

    # Keep only arrays whose dtype is not `object`.
    return {k: v for k, v in features.items() if v.dtype != 'O'}
def np_example_to_features(np_example: FeatureDict,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
    """Runs a NumPy feature dict through the TF preprocessing pipeline.

    Args:
        np_example: Raw features keyed by name. Must contain `seq_length`,
            whose first element is used as the crop size.
        config: Model config forwarded to `make_data_config`.
        random_seed: Graph-level TF random seed.

    Returns:
        The processed features as a dict of NumPy arrays, with object-dtype
        entries dropped.
    """
    # Shallow copy so the key swap below never touches the caller's mapping.
    example = dict(np_example)
    seq_len = int(example['seq_length'][0])
    data_cfg, wanted_features = make_data_config(config, num_res=seq_len)

    # Replace the integer deletion matrix with a float32 copy stored under
    # 'deletion_matrix'.
    if 'deletion_matrix_int' in example:
        int_deletions = example.pop('deletion_matrix_int')
        example['deletion_matrix'] = int_deletions.astype(np.float32)

    # Build the pipeline in its own graph, pinned to the CPU.
    graph = tf.Graph()
    with graph.as_default(), tf.device('/device:CPU:0'):
        tf.compat.v1.set_random_seed(random_seed)
        tensors = proteins_dataset.np_to_tensor_dict(
            np_example=example, features=wanted_features)
        processed = input_pipeline.process_tensors_from_config(
            tensors, data_cfg)
    graph.finalize()

    with tf.Session(graph=graph) as sess:
        outputs = sess.run(processed)

    # Keep only arrays whose dtype is not `object`.
    kept = {}
    for name, array in outputs.items():
        if array.dtype != 'O':
            kept[name] = array
    return kept