# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to run Jackhmmer from Python."""

from concurrent import futures
import glob
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request

from absl import logging

from alphafold.data.tools import utils
# Internal import (7716).
class Jackhmmer:
  """Python wrapper of the Jackhmmer binary."""

  def __init__(self,
               *,
               binary_path: str,
               database_path: str,
               n_cpu: int = 8,
               n_iter: int = 1,
               e_value: float = 0.0001,
               z_value: Optional[int] = None,
               get_tblout: bool = False,
               filter_f1: float = 0.0005,
               filter_f2: float = 0.00005,
               filter_f3: float = 0.0000005,
               incdom_e: Optional[float] = None,
               dom_e: Optional[float] = None,
               num_streamed_chunks: Optional[int] = None,
               streaming_callback: Optional[Callable[[int], None]] = None):
    """Initializes the Python Jackhmmer wrapper.

    Args:
      binary_path: The path to the jackhmmer executable.
      database_path: The path to the jackhmmer database (FASTA format).
      n_cpu: The number of CPUs to give Jackhmmer.
      n_iter: The number of Jackhmmer iterations.
      e_value: The E-value, see Jackhmmer docs for more details.
      z_value: The Z-value, see Jackhmmer docs for more details.
      get_tblout: Whether to save tblout string.
      filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
      filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
      filter_f3: Forward pre-filter, set to >1.0 to turn off.
      incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
        round.
      dom_e: Domain e-value criteria for inclusion in tblout.
      num_streamed_chunks: Number of database chunks to stream over.
      streaming_callback: Callback function run after each chunk iteration with
        the iteration number as argument.

    Raises:
      ValueError: If `database_path` does not exist and chunk streaming was
        not requested.
    """
    self.binary_path = binary_path
    self.database_path = database_path
    self.num_streamed_chunks = num_streamed_chunks

    # When streaming, database chunks are fetched remotely on demand, so a
    # missing local database is only an error in the non-streaming case.
    if not os.path.exists(self.database_path) and num_streamed_chunks is None:
      logging.error('Could not find Jackhmmer database %s', database_path)
      raise ValueError(f'Could not find Jackhmmer database {database_path}')

    self.n_cpu = n_cpu
    self.n_iter = n_iter
    self.e_value = e_value
    self.z_value = z_value
    self.filter_f1 = filter_f1
    self.filter_f2 = filter_f2
    self.filter_f3 = filter_f3
    self.incdom_e = incdom_e
    self.dom_e = dom_e
    self.get_tblout = get_tblout
    self.streaming_callback = streaming_callback

  def _query_chunk(self, input_fasta_path: str, database_path: str
                   ) -> Mapping[str, Any]:
    """Queries one database (chunk) with Jackhmmer.

    Args:
      input_fasta_path: Path to the query FASTA file.
      database_path: Path to the database (or database chunk) to search.

    Returns:
      A dict with keys: 'sto' (alignment in Stockholm format), 'tbl' (tblout
      contents, empty unless `get_tblout` is set), 'stderr' (raw bytes from
      the subprocess), 'n_iter' and 'e_value' (the settings used).

    Raises:
      RuntimeError: If the Jackhmmer subprocess exits with a non-zero status.
    """
    with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
      sto_path = os.path.join(query_tmp_dir, 'output.sto')

      # The F1/F2/F3 are the expected proportion to pass each of the filtering
      # stages (which get progressively more expensive), reducing these
      # speeds up the pipeline at the expense of sensitivity. They are
      # currently set very low to make querying Mgnify run in a reasonable
      # amount of time.
      cmd_flags = [
          # Don't pollute stdout with Jackhmmer output.
          '-o', '/dev/null',
          '-A', sto_path,
          '--noali',
          '--F1', str(self.filter_f1),
          '--F2', str(self.filter_f2),
          '--F3', str(self.filter_f3),
          '--incE', str(self.e_value),
          # Report only sequences with E-values <= x in per-sequence output.
          '-E', str(self.e_value),
          '--cpu', str(self.n_cpu),
          '-N', str(self.n_iter)
      ]
      if self.get_tblout:
        tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
        cmd_flags.extend(['--tblout', tblout_path])

      # Explicit None check: a truthiness test would silently drop an
      # explicitly-passed z_value of 0 instead of forwarding it.
      if self.z_value is not None:
        cmd_flags.extend(['-Z', str(self.z_value)])

      if self.dom_e is not None:
        cmd_flags.extend(['--domE', str(self.dom_e)])

      if self.incdom_e is not None:
        cmd_flags.extend(['--incdomE', str(self.incdom_e)])

      cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
                                              database_path]

      logging.info('Launching subprocess "%s"', ' '.join(cmd))
      process = subprocess.Popen(
          cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
      with utils.timing(
          f'Jackhmmer ({os.path.basename(database_path)}) query'):
        _, stderr = process.communicate()
        retcode = process.wait()

      if retcode:
        raise RuntimeError(
            'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))

      # Get e-values for each target name.
      tbl = ''
      if self.get_tblout:
        with open(tblout_path) as f:
          tbl = f.read()

      with open(sto_path) as f:
        sto = f.read()

      raw_output = dict(
          sto=sto,
          tbl=tbl,
          stderr=stderr,
          n_iter=self.n_iter,
          e_value=self.e_value)

    return raw_output

  def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
    """Queries the database using Jackhmmer.

    In the non-streaming case (num_streamed_chunks is None), a single query
    against the local database is run. Otherwise, database chunks are
    downloaded to /tmp/ramdisk one at a time, with the next chunk prefetched
    while Jackhmmer runs on the current one.

    Args:
      input_fasta_path: Path to the query FASTA file.

    Returns:
      A sequence with one raw-output mapping per chunk queried (a single
      element in the non-streaming case); see `_query_chunk` for the keys.
    """
    if self.num_streamed_chunks is None:
      return [self._query_chunk(input_fasta_path, self.database_path)]

    db_basename = os.path.basename(self.database_path)
    db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
    db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'

    # Remove existing files to prevent OOM.
    for f in glob.glob(db_local_chunk('[0-9]*')):
      try:
        os.remove(f)
      except OSError:
        # Best-effort cleanup: failure to delete a stale chunk should not
        # abort the query. (Was a bare print; use the module logger.)
        logging.warning('OSError while deleting %s', f)

    # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk.
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
      chunked_output = []
      for i in range(1, self.num_streamed_chunks + 1):
        # Copy the chunk locally.
        if i == 1:
          future = executor.submit(
              request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
        if i < self.num_streamed_chunks:
          next_future = executor.submit(
              request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))

        # Run Jackhmmer with the chunk.
        future.result()
        chunked_output.append(
            self._query_chunk(input_fasta_path, db_local_chunk(i)))

        # Remove the local copy of the chunk.
        os.remove(db_local_chunk(i))

        # Bug fix: only advance to the prefetch future when one was actually
        # submitted. Previously `future = next_future` ran unconditionally and
        # raised NameError when num_streamed_chunks == 1 (no prefetch exists).
        if i < self.num_streamed_chunks:
          future = next_future
        if self.streaming_callback:
          self.streaming_callback(i)
    return chunked_output