"""Residue-level protein graph construction and structure tokenization.

Pipeline: parse a PDB/CIF structure -> build a geometric graph with
GVP-style scalar/vector features -> encode anchor-centred subgraphs with a
pretrained ``AutoGraphEncoder`` -> quantize the graph embeddings into
discrete structure tokens with pretrained clustering models.
"""

import math
import os
from pathlib import Path
from typing import List

import biotite
import joblib
import numpy as np
import scipy.spatial as spa
import torch
import torch.nn.functional as F
from Bio import PDB
from Bio.SeqUtils import seq1
from biotite.sequence import ProteinSequence
from biotite.structure import filter_backbone, get_chains
from biotite.structure.io import pdb, pdbx
from biotite.structure.residues import get_residues
from torch_geometric.data import Batch, Data
from torch_scatter import scatter_max, scatter_mean, scatter_sum
from tqdm import tqdm

from .encoder import AutoGraphEncoder


def _normalize(tensor, dim=-1):
    """Normalize `tensor` along `dim`, mapping NaNs (zero-norm rows) to 0."""
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))
    )


def _rbf(D, D_min=0.0, D_max=20.0, D_count=16, device="cpu"):
    """Radial-basis-function embedding of distances.

    From https://github.com/jingraham/neurips19-graph-protein-design.
    If `D` has shape [...dims], the result has shape [...dims, D_count].
    """
    D_mu = torch.linspace(D_min, D_max, D_count, device=device).view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)
    return torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))


def _orientations(X_ca):
    """Forward/backward CA->CA unit vectors per residue, shape [L, 2, 3].

    The chain termini get zero-padded vectors (no neighbour on that side).
    """
    forward = _normalize(X_ca[1:] - X_ca[:-1])
    backward = _normalize(X_ca[:-1] - X_ca[1:])
    forward = F.pad(forward, [0, 0, 0, 1])
    backward = F.pad(backward, [0, 0, 1, 0])
    return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)


def _sidechains(X):
    """Imputed side-chain (C-beta) direction per residue, shape [L, 3].

    `X` is [L, atoms, 3] with atom order N, CA, C (further atoms unused).
    """
    n, origin, c = X[:, 0], X[:, 1], X[:, 2]
    c, n = _normalize(c - origin), _normalize(n - origin)
    bisector = _normalize(c + n)
    # dim=-1 made explicit: torch.cross without `dim` is deprecated and
    # identical here since the last dimension has size 3.
    perp = _normalize(torch.cross(c, n, dim=-1))
    return -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)


def _positional_embeddings(edge_index, num_embeddings=16, period_range=[2, 1000]):
    """Sinusoidal embedding of each edge's sequence-position offset.

    From https://github.com/jingraham/neurips19-graph-protein-design.
    `period_range` is unused; kept only for signature compatibility.
    """
    d = edge_index[0] - edge_index[1]
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d.unsqueeze(-1) * frequency
    return torch.cat((torch.cos(angles), torch.sin(angles)), -1)


def generate_graph(pdb_file, max_distance=10):
    """Build a residue-level geometric graph from a PDB file.

    Args:
        pdb_file: path to a PDB file.
        max_distance: CA-CA distance cutoff (Angstrom) for creating edges.

    Returns:
        torch_geometric Data with:
            node_s [L, 20] (zeros), node_v [L, 3, 3],
            edge_index [2, E], edge_s [E, 32], edge_v [E, 1, 3],
            distances (L x L numpy CA distance matrix),
            aa_seq (one-letter amino-acid sequence).

    Raises:
        KeyError: if a residue is missing one of the N/CA/C/O atoms.
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure = pdb_parser.get_structure("protein", pdb_file)

    seq = []
    aa_coords = {"N": [], "CA": [], "C": [], "O": []}
    # NOTE(review): iterates every model, so multi-model files concatenate
    # their residues — confirm this is intended for NMR-style inputs.
    for model in structure:
        for chain in model:
            for residue in chain:
                if residue.get_id()[0] == " ":  # skip hetero atoms / waters
                    seq.append(residue.get_resname())
                    for atom_name in aa_coords.keys():
                        atom = residue[atom_name]
                        aa_coords[atom_name].append(atom.get_coord().tolist())

    one_letter_seq = "".join([seq1(aa) for aa in seq])

    coords = list(zip(aa_coords["N"], aa_coords["CA"], aa_coords["C"], aa_coords["O"]))
    # via np.array: building a tensor from a list of tuples-of-lists directly
    # is slow and warns in recent torch versions
    coords = torch.tensor(np.array(coords))  # [L, 4, 3]

    # residues with any non-finite coordinate are pushed to infinity so they
    # never fall inside the distance cutoff
    finite = torch.isfinite(coords.sum(dim=(1, 2)))
    coords[~finite] = np.inf

    ca_coords = coords[:, 1]
    node_s = torch.zeros(len(ca_coords), 20)

    # edges: CA pairs closer than max_distance, self-loops removed
    distances = spa.distance_matrix(ca_coords, ca_coords)
    edge_index = torch.tensor(np.array(np.where(distances < max_distance)))
    not_loop = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, not_loop]

    # node features: backbone orientations + imputed side-chain direction
    orientations = _orientations(ca_coords)
    sidechains = _sidechains(coords)
    node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)

    # edge features: RBF of CA distance + sequence-offset embedding
    pos_embeddings = _positional_embeddings(edge_index)
    E_vectors = ca_coords[edge_index[0]] - ca_coords[edge_index[1]]
    rbf = _rbf(E_vectors.norm(dim=-1), D_count=16)
    edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
    edge_v = _normalize(E_vectors).unsqueeze(-2)

    node_s, node_v, edge_s, edge_v = map(
        torch.nan_to_num, (node_s, node_v, edge_s, edge_v)
    )
    return Data(
        node_s=node_s,
        node_v=node_v,
        edge_index=edge_index,
        edge_s=edge_s,
        edge_v=edge_v,
        distances=distances,
        aa_seq=one_letter_seq,
    )


def load_structure(fpath, chain=None):
    """Load a pdb/cif file as a backbone-only biotite AtomArray.

    Args:
        fpath: filepath to either a .pdb or .cif file.
        chain: chain id or list of chain ids to keep (default: all chains).

    Returns:
        biotite.structure.AtomArray restricted to backbone atoms and the
        requested chains.

    Raises:
        ValueError: unsupported file extension, empty structure, or a
            requested chain that is not present.
    """
    if fpath.endswith("cif"):
        with open(fpath) as fin:
            structure = pdbx.get_structure(pdbx.PDBxFile.read(fin), model=1)
    elif fpath.endswith("pdb"):
        with open(fpath) as fin:
            structure = pdb.get_structure(pdb.PDBFile.read(fin), model=1)
    else:
        # previously fell through and raised a confusing NameError
        raise ValueError(f"Unsupported structure file (expected .pdb or .cif): {fpath}")

    structure = structure[filter_backbone(structure)]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError("No chains found in the input file.")

    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f"Chain {chain_id} not found in input file")

    chain_filter = [a.chain_id in chain_ids for a in structure]
    return structure[chain_filter]


def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """Per-residue coordinates of the named atoms, e.g. ["N", "CA", "C"].

    Returns an [L, len(atoms), 3] array; residues missing an atom get NaN.
    """

    def filterfn(s, axis=None):
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        counts = filters.sum(0)  # renamed: previously shadowed builtin `sum`
        if not np.all(counts <= np.ones(filters.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        coords[counts == 0] = float("nan")
        return coords

    # `data` is the array itself, so filterfn receives each residue's atoms
    return biotite.structure.apply_residue_wise(struct, struct, filterfn)


def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """Return N/CA/C coordinates of `structure` as an [L, 3, 3] array.

    The original docstring promised a (coords, seq) tuple, but only coords
    was ever returned; the sequence conversion is kept because it raises on
    residues without a one-letter code (an implicit validation step).
    """
    coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    residue_identities = get_residues(structure)[1]
    # unused result; kept for its side effect of rejecting exotic residues
    seq = "".join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return coords


def extract_seq_from_pdb(pdb_file, chain=None):
    """Return the one-letter sequence of a (optionally chain-filtered) file.

    Args:
        pdb_file: path to a .pdb or .cif file.
        chain: chain id or list of chain ids to keep (default: all).
    """
    structure = load_structure(pdb_file, chain)
    residue_identities = get_residues(structure)[1]
    return "".join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])


def generate_pos_subgraph(
    graph_data,
    subgraph_depth=None,
    subgraph_interval=1,
    max_distance=10,
    anchor_nodes=None,
    pure_subgraph=False,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Extract one local subgraph around each anchor residue.

    Args:
        graph_data: Data from generate_graph (node/edge features, distances).
        subgraph_depth: number of nearest neighbours considered per anchor
            (default 50).
        subgraph_interval: stride used to pick anchors when none are given.
        max_distance: distance cutoff for neighbours/edges (must match the
            cutoff used to build `graph_data`, since edge features are reused).
        anchor_nodes: explicit anchor residue indices (default: every
            `subgraph_interval`-th residue).
        pure_subgraph: if False, attach an `index_map` {original id -> local id}.
        device: device to run the (batched) tensor work on.

    Returns:
        dict mapping anchor node id -> torch_geometric Data subgraph.
    """

    def _to_device(x):
        return x.to(device) if torch.is_tensor(x) else torch.tensor(x, device=device)

    graph_data = Data(
        node_s=_to_device(graph_data.node_s),
        node_v=_to_device(graph_data.node_v),
        edge_index=_to_device(graph_data.edge_index),
        edge_s=_to_device(graph_data.edge_s),
        edge_v=_to_device(graph_data.edge_v),
        distances=_to_device(graph_data.distances),
        aa_seq=graph_data.aa_seq,
    )
    distances = graph_data.distances

    if subgraph_depth is None:
        subgraph_depth = 50
    if anchor_nodes is None:
        anchor_nodes = list(range(0, len(graph_data.aa_seq), subgraph_interval))

    # k nearest neighbours for ALL nodes at once; -1 marks neighbours that
    # lie beyond the distance cutoff.
    # BUG FIX: honour the subgraph_depth parameter (was hard-coded k = 50).
    k = subgraph_depth
    nearest_indices = torch.argsort(distances, dim=1)[:, :k]  # (num_nodes, k)
    within_cutoff = torch.gather(distances, 1, nearest_indices) < max_distance
    nearest_indices = torch.where(
        within_cutoff, nearest_indices, torch.tensor(-1, device=device)
    )

    subgraph_dict = {}
    for anchor_node in anchor_nodes:
        try:
            k_neighbors = nearest_indices[anchor_node]
            k_neighbors = k_neighbors[k_neighbors != -1]
            if len(k_neighbors) == 0:
                continue  # isolated anchor: no subgraph
            # NOTE(review): threshold 30 vs cap 40 is inconsistent in the
            # original code; preserved as-is — confirm the intended cap.
            if len(k_neighbors) > 30:
                k_neighbors = k_neighbors[:40]
            k_neighbors, _ = torch.sort(k_neighbors)

            sub_matrix = distances.index_select(0, k_neighbors).index_select(
                1, k_neighbors
            )
            sub_edges = torch.nonzero(sub_matrix < max_distance, as_tuple=False)
            sub_edge_index = sub_edges[sub_edges[:, 0] != sub_edges[:, 1]]
            if len(sub_edge_index) == 0:
                continue  # no edges: nothing to encode

            # Map each local edge back to its position in the full edge list
            # so the precomputed edge features can be reused.
            full_edge_index = graph_data.edge_index
            original_edge_index = k_neighbors[sub_edge_index]
            matches = torch.stack(
                [
                    (full_edge_index[0] == e[0]) & (full_edge_index[1] == e[1])
                    for e in original_edge_index
                ]
            )  # (num_sub_edges, num_full_edges)
            # BUG FIX: take the COLUMN index (position in the full edge list).
            # The previous code used [0], the row index, which merely
            # enumerates the sub-edges and selected the wrong feature rows.
            edge_to_feature_idx = torch.nonzero(matches, as_tuple=True)[1]
            if len(edge_to_feature_idx) == 0:
                continue

            result = Data(
                edge_index=sub_edge_index.T,
                edge_s=graph_data.edge_s[edge_to_feature_idx],
                edge_v=graph_data.edge_v[edge_to_feature_idx],
                node_s=graph_data.node_s[k_neighbors],
                node_v=graph_data.node_v[k_neighbors],
            )
            if not pure_subgraph:
                result.index_map = {
                    int(old_id.item()): new_id
                    for new_id, old_id in enumerate(k_neighbors)
                }
            subgraph_dict[anchor_node] = result
        except Exception as e:
            # best-effort: a bad anchor should not abort the whole protein
            print(f"Error processing anchor node {anchor_node}: {str(e)}")
            continue

    return subgraph_dict


def convert_graph(graph):
    """Cast graph features to the dtypes the encoder expects.

    Dtype conversion only (float32 features, int64 indices) — this does NOT
    move the graph between devices.
    """
    return Data(
        node_s=graph.node_s.to(torch.float32),
        node_v=graph.node_v.to(torch.float32),
        edge_index=graph.edge_index.to(torch.int64),
        edge_s=graph.edge_s.to(torch.float32),
        edge_v=graph.edge_v.to(torch.float32),
    )


def predict_structure(model, cluster_models, dataloader, datalabels, device):
    """Assign discrete structure labels to every protein in `dataloader`.

    Args:
        model: AutoGraphEncoder providing get_embedding.
        cluster_models: paths to joblib-serialized clustering models.
        dataloader: iterable of batched subgraphs (one batch per protein).
        datalabels: iterable of per-protein dicts with 'name' and 'aa_seq'.
        device: torch device for the encoder.

    Returns:
        {cluster_model_name: {protein_name: {'seq': ..., 'struct': [...]}}}
    """
    struc_label_dict = {}
    cluster_model_dict = {}
    for cluster_model_path in cluster_models:
        # basename + first dot keeps the original naming convention
        cluster_model_name = os.path.basename(str(cluster_model_path)).split(".")[0]
        struc_label_dict[cluster_model_name] = {}
        cluster_model_dict[cluster_model_name] = joblib.load(cluster_model_path)

    with torch.no_grad():
        for batch, label_dict in zip(dataloader, datalabels):
            batch.to(device)
            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            node_embeddings = model.get_embedding(h_V, batch.edge_index, h_E)
            graph_embeddings = scatter_mean(node_embeddings, batch.batch, dim=0).to(
                device
            )
            norm_graph_embeddings = F.normalize(graph_embeddings, p=2, dim=1)
            # BUG FIX: build the per-protein entry for EVERY cluster model.
            # The old code initialized only the last model's dict (the loop
            # variable leaked from the setup loop above), raising KeyError
            # when more than one cluster model was supplied.
            for name, cluster_model in cluster_model_dict.items():
                batch_structure_labels = cluster_model.predict(
                    norm_graph_embeddings.cpu()
                ).tolist()
                struc_label_dict[name][label_dict["name"]] = {
                    "seq": label_dict["aa_seq"],
                    "struct": batch_structure_labels,
                }
    return struc_label_dict


def get_embeds(model, dataloader, device, pooling="mean"):
    """Compute L2-normalized pooled graph embeddings for every batch.

    Args:
        pooling: 'mean', 'sum' or 'max' pooling over each graph's nodes.

    Raises:
        ValueError: on an unknown pooling mode.
    """
    epoch_iterator = tqdm(dataloader)
    embeds = []
    with torch.no_grad():
        for batch in epoch_iterator:
            batch.to(device)
            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            node_embeds = model.get_embedding(h_V, batch.edge_index, h_E).cpu()
            if pooling == "mean":
                graph_embeds = scatter_mean(node_embeds, batch.batch.cpu(), dim=0)
            elif pooling == "sum":
                graph_embeds = scatter_sum(node_embeds, batch.batch.cpu(), dim=0)
            elif pooling == "max":
                graph_embeds, _ = scatter_max(node_embeds, batch.batch.cpu(), dim=0)
            else:
                raise ValueError("pooling should be mean, sum or max")
            embeds.append(graph_embeds)
    embeds = torch.cat(embeds, dim=0)
    return F.normalize(embeds, p=2, dim=1)


def process_pdb_file(
    pdb_file,
    subgraph_depth,
    subgraph_interval,
    max_distance,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Build the graph and all anchor subgraphs for one PDB file.

    Returns:
        (subgraphs, result_dict, anchor_count). On failure subgraphs is None
        and result_dict["error"] holds the message.
    """
    result_dict = {"name": Path(pdb_file).name}
    try:
        graph = generate_graph(pdb_file, max_distance)
    except Exception as e:
        print(f"Error in processing {pdb_file}")
        result_dict["error"] = str(e)
        return None, result_dict, 0

    result_dict["aa_seq"] = graph.aa_seq
    anchor_nodes = list(range(0, len(graph.node_s), subgraph_interval))
    try:
        subgraph_dict = generate_pos_subgraph(
            graph,
            subgraph_depth,
            subgraph_interval,
            max_distance,
            anchor_nodes=anchor_nodes,
            pure_subgraph=True,
            device=device,
        )
        # cast features to the dtypes the encoder expects (not a device move)
        for key in subgraph_dict.keys():
            subgraph_dict[key] = convert_graph(subgraph_dict[key])
    except Exception as e:
        print(f"Error processing subgraph {e}")
        # BUG FIX: callers read result_dict["error"]; previously this path
        # left it unset and pdb_converter raised KeyError.
        result_dict["error"] = str(e)
        return None, result_dict, 0

    # anchors in ascending residue order
    subgraphs = [subgraph_dict[key] for key in sorted(subgraph_dict.keys())]
    return subgraphs, result_dict, len(anchor_nodes)


def pdb_converter(
    pdb_files,
    subgraph_depth,
    subgraph_interval,
    max_distance,
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=32,
):
    """Convert PDB files into a subgraph data loader plus per-file metadata.

    Files that fail to parse are reported and skipped. Returns
    (generator of batched subgraphs — one batch per protein, result dicts).
    """
    error_proteins, error_messages = [], []
    dataset, results, node_counts = [], [], []

    # batch_size only chunks the progress bar; files are processed one by one
    for i in tqdm(range(0, len(pdb_files), batch_size), desc="Processing PDB files"):
        for pdb_file in pdb_files[i : i + batch_size]:
            pdb_subgraphs, result_dict, node_count = process_pdb_file(
                pdb_file, subgraph_depth, subgraph_interval, max_distance, device=device
            )
            if pdb_subgraphs is None:
                error_proteins.append(result_dict["name"])
                error_messages.append(result_dict["error"])
                continue
            dataset.append(pdb_subgraphs)
            results.append(result_dict)
            node_counts.append(node_count)

    if error_proteins:
        print(f"Found {len(error_proteins)} errors:")
        for name, msg in zip(error_proteins, error_messages):
            print(f"{name}: {msg}")

    def collate_fn(batch):
        batch_graphs = []
        for d in batch:
            batch_graphs.extend(d)
        batch_graphs = Batch.from_data_list(batch_graphs)
        # node scalars are zeroed to match how the encoder was trained
        batch_graphs.node_s = torch.zeros_like(batch_graphs.node_s)
        return batch_graphs

    def data_loader():
        for item in dataset:
            yield collate_fn([item])

    return data_loader(), results


class PdbQuantizer:
    """Quantize protein structures into discrete structure tokens.

    Loads a pretrained AutoGraphEncoder plus a clustering model for the
    chosen vocabulary size, and maps each residue-anchored subgraph to a
    cluster label.
    """

    def __init__(
        self,
        structure_vocab_size=2048,
        max_distance=10,
        subgraph_depth=None,
        subgraph_interval=1,
        anchor_nodes=None,
        model_path=None,
        cluster_dir=None,
        cluster_model=None,
        device=None,
        batch_size=16,
    ) -> None:
        assert structure_vocab_size in [20, 64, 128, 512, 1024, 2048, 4096]
        self.batch_size = batch_size
        self.max_distance = max_distance
        self.subgraph_depth = subgraph_depth
        self.subgraph_interval = subgraph_interval
        self.anchor_nodes = anchor_nodes
        self.structure_vocab_size = structure_vocab_size

        # default weights/clusters ship in the package's static/ directory
        if model_path is None:
            self.model_path = str(Path(__file__).parent / "static" / "AE.pt")
        else:
            self.model_path = model_path

        if cluster_dir is None:
            self.cluster_dir = str(Path(__file__).parent / "static")
            # file name only: joined with cluster_dir below
            self.cluster_model = [f"{structure_vocab_size}.joblib"]
        else:
            self.cluster_dir = cluster_dir
            self.cluster_model = cluster_model
            if self.cluster_model is None:
                # previously failed later with an opaque TypeError
                raise ValueError(
                    "cluster_model must be provided when cluster_dir is set"
                )

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # encoder hyper-parameters must match the pretrained checkpoint
        node_dim = (256, 32)
        edge_dim = (64, 2)
        model = AutoGraphEncoder(
            node_in_dim=(20, 3),
            node_h_dim=node_dim,
            edge_in_dim=(32, 1),
            edge_h_dim=edge_dim,
            num_layers=6,
        )
        # map_location so GPU-saved weights also load on CPU-only hosts
        model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.model = model.to(self.device).eval()

        self.cluster_models = [
            os.path.join(self.cluster_dir, m) for m in self.cluster_model
        ]

    def __call__(self, pdb_files, return_residue_seq=False):
        """Tokenize one file or a list of files.

        Returns:
            {cluster_model_name: {file_name: {'struct': [...], 'seq': ...}}}
            ('seq' is removed unless return_residue_seq is True).
        """
        if isinstance(pdb_files, str):
            pdb_files = [pdb_files]
        elif not isinstance(pdb_files, list):
            raise ValueError("pdb_files should be either a string or a list of strings")

        data_loader, results = pdb_converter(
            pdb_files,
            self.subgraph_depth,
            self.subgraph_interval,
            self.max_distance,
            device=self.device,
            batch_size=self.batch_size,
        )
        structures = predict_structure(
            self.model, self.cluster_models, data_loader, results, self.device
        )
        if not return_residue_seq:
            for model_name in structures:
                for entry in structures[model_name].values():
                    entry.pop("seq", None)
        return structures