import math
import os
import csv
import random
import torch
from torch.utils import data
import numpy as np
from dateutil import parser
import contigs
from util import *
from kinematics import *
import pandas as pd
import sys
import torch.nn as nn
from icecream import ic

def write_pdb(filename, seq, atoms, Bfacts=None, prefix=None, chains=None):
    L = len(seq)
    ctr = 1
    seq = seq.long()
    if Bfacts is None:
        Bfacts = np.zeros(L)  # default B-factors when none are provided
    with open(filename, 'w+') as f:
        for i, s in enumerate(seq):
            if chains is None:
                chain = 'A'
            else:
                chain = chains[i]

            if len(atoms.shape) == 2:
                f.write("%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
                    "ATOM", ctr, " CA ", num2aa[s],
                    chain, i+1, atoms[i, 0], atoms[i, 1], atoms[i, 2],
                    1.0, Bfacts[i]))
                ctr += 1

            elif atoms.shape[1] == 3:
                for j, atm_j in enumerate((" N  ", " CA ", " C  ")):
                    f.write("%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
                        "ATOM", ctr, atm_j, num2aa[s],
                        chain, i+1, atoms[i, j, 0], atoms[i, j, 1], atoms[i, j, 2],
                        1.0, Bfacts[i]))
                    ctr += 1

            else:
                atms = aa2long[s]
                for j, atm_j in enumerate(atms):
                    if atm_j is not None:
                        f.write("%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
                            "ATOM", ctr, atm_j, num2aa[s],
                            chain, i+1, atoms[i, j, 0], atoms[i, j, 1], atoms[i, j, 2],
                            1.0, Bfacts[i]))
                        ctr += 1
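
# Illustrative sketch (not part of the original module): writing a CA-only PDB from
# random coordinates. Assumes `num2aa` from util covers integer codes 0-19; the file
# name "example_ca_only.pdb" is made up for this example.
def _example_write_pdb():
    L = 10
    seq = torch.randint(0, 20, (L,))        # random integer-encoded sequence
    ca_coords = torch.rand(L, 3) * 10.0     # (L, 3) coordinates -> CA-only branch of write_pdb
    write_pdb("example_ca_only.pdb", seq, ca_coords, Bfacts=np.ones(L))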

def preprocess(xyz_t, t1d, DEVICE, masks_1d, ti_dev=None, ti_flip=None, ang_ref=None):
    B, _, L, _, _ = xyz_t.shape

    seq_tmp = t1d[..., :-1].argmax(dim=-1).reshape(-1, L).to(DEVICE, non_blocking=True)
    alpha, _, alpha_mask, _ = get_torsions(xyz_t.reshape(-1, L, 27, 3), seq_tmp, ti_dev, ti_flip, ang_ref)
    alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[..., 0]))
    alpha[torch.isnan(alpha)] = 0.0
    alpha = alpha.reshape(B, -1, L, 10, 2)
    alpha_mask = alpha_mask.reshape(B, -1, L, 10, 1)
    alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 30)
    #t1d = torch.cat((t1d, chis.reshape(B,-1,L,30)), dim=-1)
    xyz_t = get_init_xyz(xyz_t)
    xyz_prev = xyz_t[:, 0]
    state = t1d[:, 0]
    alpha = alpha[:, 0]
    t2d = xyz_to_t2d(xyz_t)
    return (t2d, alpha, alpha_mask, alpha_t, t1d, xyz_t, xyz_prev, state)

def TemplFeaturizeFixbb(seq, conf_1d=None):
    """
    Template 1D featurizer for fixed-backbone examples.

    Parameters:
        seq (torch.tensor, required): Integer sequence
        conf_1d (torch.tensor, optional): Precalculated confidence tensor
    """
    L = seq.shape[-1]
    t1d = torch.nn.functional.one_hot(seq, num_classes=21)  # one-hot sequence
    if conf_1d is None:
        conf = torch.ones_like(seq)[..., None]
    else:
        conf = conf_1d[:, None]
    t1d = torch.cat((t1d, conf), dim=-1)
    return t1d
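
# Illustrative sketch (not in the original module): featurizing a random 50-residue
# sequence. With no conf_1d supplied, every position gets confidence 1, so the output
# is (L, 22): 21 one-hot channels plus one confidence channel.
def _example_templ_featurize_fixbb():
    seq = torch.randint(0, 21, (50,))
    t1d = TemplFeaturizeFixbb(seq)
    assert t1d.shape == (50, 22)
    return t1d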

def MSAFeaturize_fixbb(msa, params):
    '''
    Input: full MSA information
    Output: Single sequence, with some percentage of amino acids mutated (but no residues 'masked')
    This is modified from autofold2, to remove mutations of the single sequence
    '''
    N, L = msa.shape
    # raw MSA profile
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=22)
    raw_profile = raw_profile.float().mean(dim=0)

    b_seq = list()
    b_msa_clust = list()
    b_msa_seed = list()
    b_msa_extra = list()
    b_mask_pos = list()
    for i_cycle in range(params['MAXCYCLE']):
        assert torch.max(msa) < 22
        msa_onehot = torch.nn.functional.one_hot(msa[:1], num_classes=22)
        msa_fakeprofile_onehot = torch.nn.functional.one_hot(msa[:1], num_classes=26)  # add the extra two indel planes, which will be set to zero
        msa_full_onehot = torch.cat((msa_onehot, msa_fakeprofile_onehot), dim=-1)

        # make fake msa_extra
        msa_extra_onehot = torch.nn.functional.one_hot(msa[:1], num_classes=25)

        # make fake msa_clust and mask_pos
        msa_clust = msa[:1]
        mask_pos = torch.full_like(msa_clust, 1).bool()
        b_seq.append(msa[0].clone())
        b_msa_seed.append(msa_full_onehot[:1].clone())   # masked single sequence onehot (nb no mask so just single sequence onehot)
        b_msa_extra.append(msa_extra_onehot[:1].clone()) # masked single sequence onehot (nb no mask so just single sequence onehot)
        b_msa_clust.append(msa_clust[:1].clone())        # unmasked original single sequence
        b_mask_pos.append(mask_pos[:1].clone())          # mask positions in single sequence (all zeros)

    b_seq = torch.stack(b_seq)
    b_msa_clust = torch.stack(b_msa_clust)
    b_msa_seed = torch.stack(b_msa_seed)
    b_msa_extra = torch.stack(b_msa_extra)
    b_mask_pos = torch.stack(b_mask_pos)

    return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos
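
# Illustrative sketch (not in the original module): featurizing a toy MSA. The shapes
# follow from the function body: only the query row (msa[:1]) is used, so the seed
# features are (MAXCYCLE, 1, L, 48) and the extra features are (MAXCYCLE, 1, L, 25).
def _example_msa_featurize_fixbb():
    msa = torch.randint(0, 21, (10, 64))   # 10 sequences, length 64
    params = {'MAXCYCLE': 3}
    b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos = MSAFeaturize_fixbb(msa, params)
    assert b_seq.shape == (3, 64)
    assert b_msa_seed.shape == (3, 1, 64, 48)  # 22 + 26 one-hot planes concatenated
    return b_msa_seed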

def MSAFeaturize(msa, params):
    '''
    Input: full MSA information
    Output: Single sequence, with some percentage of amino acids mutated (but no residues 'masked')
    This is modified from autofold2, to remove mutations of the single sequence
    '''
    N, L = msa.shape
    # raw MSA profile
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=22)
    raw_profile = raw_profile.float().mean(dim=0)

    b_seq = list()
    b_msa_clust = list()
    b_msa_seed = list()
    b_msa_extra = list()
    b_mask_pos = list()
    for i_cycle in range(params['MAXCYCLE']):
        assert torch.max(msa) < 22
        msa_onehot = torch.nn.functional.one_hot(msa, num_classes=22)
        msa_fakeprofile_onehot = torch.nn.functional.one_hot(msa, num_classes=26)  # add the extra two indel planes, which will be set to zero
        msa_full_onehot = torch.cat((msa_onehot, msa_fakeprofile_onehot), dim=-1)

        # make fake msa_extra
        msa_extra_onehot = torch.nn.functional.one_hot(msa, num_classes=25)

        # make fake msa_clust and mask_pos
        msa_clust = msa
        mask_pos = torch.full_like(msa_clust, 1).bool()
        b_seq.append(msa[0].clone())
        b_msa_seed.append(msa_full_onehot.clone())   # masked MSA onehot (nb no mask applied)
        b_msa_extra.append(msa_extra_onehot.clone()) # masked MSA onehot (nb no mask applied)
        b_msa_clust.append(msa_clust.clone())        # unmasked original MSA
        b_mask_pos.append(mask_pos.clone())          # mask positions (all zeros)

    b_seq = torch.stack(b_seq)
    b_msa_clust = torch.stack(b_msa_clust)
    b_msa_seed = torch.stack(b_msa_seed)
    b_msa_extra = torch.stack(b_msa_extra)
    b_mask_pos = torch.stack(b_mask_pos)

    return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos

def mask_inputs(seq, msa_masked, msa_full, xyz_t, t1d, input_seq_mask=None, input_str_mask=None, input_t1dconf_mask=None, loss_seq_mask=None, loss_str_mask=None):
    """
    Parameters:
        seq (torch.tensor, required): (B,I,L) integer sequence
        msa_masked (torch.tensor, required): (B,I,N_short,L,46)
        msa_full (torch.tensor, required): (B,I,N_long,L,23)
        xyz_t (torch.tensor): (B,T,L,14,3) template crds BEFORE they go into get_init_xyz
        t1d (torch.tensor, required): (B,I,L,22) this is the t1d before tacking on the chi angles
        input_str_mask (torch.tensor, required): Shape (L) rank 1 tensor where structure is masked at False positions
        input_seq_mask (torch.tensor, required): Shape (L) rank 1 tensor where sequence is masked at False positions
    """

    ###########
    B, _, _ = seq.shape
    assert B == 1, 'batch sizes > 1 not supported'
    seq_mask = input_seq_mask[0]
    seq[:, :, ~seq_mask] = 21  # mask token categorical value

    ### msa_masked ###
    ##################
    msa_masked[:, :, :, ~seq_mask, :20] = 0
    msa_masked[:, :, :, ~seq_mask, 20] = 0
    msa_masked[:, :, :, ~seq_mask, 21] = 1  # set to the unknown char

    # index 44/45 is insertion/deletion
    # index 43 is the unknown token
    # index 42 is the masked token
    msa_masked[:, :, :, ~seq_mask, 22:42] = 0
    msa_masked[:, :, :, ~seq_mask, 43] = 1
    msa_masked[:, :, :, ~seq_mask, 42] = 0

    # insertion/deletion stuff
    msa_masked[:, :, :, ~seq_mask, 44:] = 0

    ### msa_full ###
    ################
    msa_full[:, :, :, ~seq_mask, :20] = 0
    msa_full[:, :, :, ~seq_mask, 21] = 1
    msa_full[:, :, :, ~seq_mask, 20] = 0
    msa_full[:, :, :, ~seq_mask, -1] = 0  # NOTE: double check this is insertions/deletions and 0 makes sense

    ### t1d ###
    ###########
    # NOTE: Not adjusting t1d last dim (confidence) from sequence mask
    t1d[:, :, ~seq_mask, :20] = 0
    t1d[:, :, ~seq_mask, 20] = 1  # unknown

    t1d[:, :, :, 21] *= input_t1dconf_mask

    # JG added in here to make sure everything fits
    print('expanding t1d to 24 dims')
    t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], t1d.shape[2], 2)).float()), -1).to(seq.device)

    xyz_t[:, :, ~seq_mask, 3:, :] = float('nan')

    # Structure masking
    str_mask = input_str_mask[0]
    xyz_t[:, :, ~str_mask, :, :] = float('nan')

    return seq, msa_masked, msa_full, xyz_t, t1d
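
# Illustrative sketch (not in the original module): calling mask_inputs on minimal
# dummy tensors. All shapes follow the docstring above; the masks hide the last three
# positions of a length-10 example. Values are random and purely for shape checking.
def _example_mask_inputs():
    B, I, N, L, T = 1, 1, 2, 10, 1
    seq = torch.randint(0, 20, (B, I, L))
    msa_masked = torch.rand(B, I, N, L, 46)
    msa_full = torch.rand(B, I, N, L, 23)
    xyz_t = torch.rand(B, T, L, 14, 3)
    t1d = torch.rand(B, I, L, 22)
    seq_mask = torch.ones(B, L).bool()
    seq_mask[:, -3:] = False                 # mask sequence at the last 3 positions
    str_mask = seq_mask.clone()              # mask structure at the same positions
    conf_mask = torch.ones(L)
    out = mask_inputs(seq, msa_masked, msa_full, xyz_t, t1d,
                      input_seq_mask=seq_mask, input_str_mask=str_mask,
                      input_t1dconf_mask=conf_mask)
    assert out[-1].shape == (B, I, L, 24)    # t1d is padded to 24 channels
    return out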

###########################################################
# Functions for randomly translating/rotating input residues
###########################################################
def get_translated_coords(args):
    '''
    Parses args.res_translate
    '''
    # get positions to translate
    res_translate = []
    for res in args.res_translate.split(":"):
        temp_str = []
        for i in res.split(','):
            temp_str.append(i)
        if temp_str[-1][0].isalpha() is True:
            temp_str.append(2.0)  # set default distance
        for i in temp_str[:-1]:
            if '-' in i:
                start = int(i.split('-')[0][1:])
                while start <= int(i.split('-')[1]):
                    res_translate.append((i.split('-')[0][0] + str(start), float(temp_str[-1])))
                    start += 1
            else:
                res_translate.append((i, float(temp_str[-1])))

    start = 0
    output = []
    for i in res_translate:
        temp = (i[0], i[1], start)
        output.append(temp)
        start += 1
    return output
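
# Illustrative sketch (not in the original module): parsing a res_translate string.
# "A1-3,2.5:B10" asks to translate residues A1-A3 by up to 2.5 A and B10 by the
# default 2.0 A; each residue ends up in its own block index.
def _example_get_translated_coords():
    from types import SimpleNamespace
    args = SimpleNamespace(res_translate='A1-3,2.5:B10')
    out = get_translated_coords(args)
    # [('A1', 2.5, 0), ('A2', 2.5, 1), ('A3', 2.5, 2), ('B10', 2.0, 3)]
    return out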

def get_tied_translated_coords(args, untied_translate=None):
    '''
    Parses args.tie_translate
    '''
    #pdb_idx = list(parsed_pdb['idx'])
    #xyz = parsed_pdb['xyz']
    # get positions to translate
    res_translate = []
    block = 0
    for res in args.tie_translate.split(":"):
        temp_str = []
        for i in res.split(','):
            temp_str.append(i)
        if temp_str[-1][0].isalpha() is True:
            temp_str.append(2.0)  # set default distance
        for i in temp_str[:-1]:
            if '-' in i:
                start = int(i.split('-')[0][1:])
                while start <= int(i.split('-')[1]):
                    res_translate.append((i.split('-')[0][0] + str(start), float(temp_str[-1]), block))
                    start += 1
            else:
                res_translate.append((i, float(temp_str[-1]), block))
        block += 1

    # sanity check
    if untied_translate is not None:
        checker = [i[0] for i in res_translate]
        untied_check = [i[0] for i in untied_translate]
        for i in checker:
            if i in untied_check:
                print(f'WARNING: residue {i} is specified both in --res_translate and --tie_translate. Residue {i} will be ignored in --res_translate, and instead only moved in a tied block (--tie_translate)')
        final_output = res_translate
        for i in untied_translate:
            if i[0] not in checker:
                final_output.append((i[0], i[1], i[2] + block + 1))
    else:
        final_output = res_translate

    return final_output

def translate_coords(parsed_pdb, res_translate):
    '''
    Takes a parsed list in the format [(chain_residue, distance, tying_block)] and randomly
    translates residues accordingly.
    '''
    pdb_idx = parsed_pdb['pdb_idx']
    xyz = np.copy(parsed_pdb['xyz'])
    translated_coord_dict = {}
    # get number of blocks
    temp = [int(i[2]) for i in res_translate]
    blocks = np.max(temp)
    for block in range(blocks + 1):
        init_dist = 1.01
        while init_dist > 1:  # gives equal probability to any direction (keeps sampling until the point is within the unit sphere)
            x = random.uniform(-1, 1)
            y = random.uniform(-1, 1)
            z = random.uniform(-1, 1)
            init_dist = np.sqrt(x**2 + y**2 + z**2)
        x = x / init_dist
        y = y / init_dist
        z = z / init_dist
        translate_dist = random.uniform(0, 1)  # now choose distance (as proportion of maximum) that coordinates will be translated
        for res in res_translate:
            if res[2] == block:
                res_idx = pdb_idx.index((res[0][0], int(res[0][1:])))
                original_coords = np.copy(xyz[res_idx, :, :])
                for i in range(14):
                    if parsed_pdb['mask'][res_idx, i]:
                        xyz[res_idx, i, 0] += np.float32(x * translate_dist * float(res[1]))
                        xyz[res_idx, i, 1] += np.float32(y * translate_dist * float(res[1]))
                        xyz[res_idx, i, 2] += np.float32(z * translate_dist * float(res[1]))
                translated_coords = xyz[res_idx, :, :]
                translated_coord_dict[res[0]] = (original_coords.tolist(), translated_coords.tolist())
    return xyz[:, :, :], translated_coord_dict
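
# Illustrative sketch (not in the original module): translating two residues as one
# tied block. The dummy parsed_pdb mimics the fields this function reads ('pdb_idx',
# 'xyz' with 14 atoms per residue, and the per-atom 'mask').
def _example_translate_coords():
    parsed_pdb = {
        'pdb_idx': [('A', 1), ('A', 2)],
        'xyz': np.random.rand(2, 14, 3).astype(np.float32),
        'mask': np.ones((2, 14), dtype=bool),
    }
    res_translate = [('A1', 2.0, 0), ('A2', 2.0, 0)]   # both residues in block 0
    xyz_new, moved = translate_coords(parsed_pdb, res_translate)
    return xyz_new, moved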

def parse_block_rotate(args):
    block_translate = []
    block = 0
    for res in args.block_rotate.split(":"):
        temp_str = []
        for i in res.split(','):
            temp_str.append(i)
        if temp_str[-1][0].isalpha() is True:
            temp_str.append(10)  # set default angle to 10 degrees
        for i in temp_str[:-1]:
            if '-' in i:
                start = int(i.split('-')[0][1:])
                while start <= int(i.split('-')[1]):
                    block_translate.append((i.split('-')[0][0] + str(start), float(temp_str[-1]), block))
                    start += 1
            else:
                block_translate.append((i, float(temp_str[-1]), block))
        block += 1
    return block_translate

def rotate_block(xyz, block_rotate, pdb_index):
    rotated_coord_dict = {}
    # get number of blocks
    temp = [int(i[2]) for i in block_rotate]
    blocks = np.max(temp)
    xyz_out = torch.clone(xyz)  # clone once so rotations from every block are preserved
    for block in range(blocks + 1):
        idxs = [pdb_index.index((i[0][0], int(i[0][1:]))) for i in block_rotate if i[2] == block]
        angle = [i[1] for i in block_rotate if i[2] == block][0]
        block_xyz = xyz[idxs, :, :]
        com = [float(torch.mean(block_xyz[:, :, i])) for i in range(3)]
        origin_xyz = np.copy(block_xyz)
        for i in range(np.shape(origin_xyz)[0]):
            for j in range(14):
                origin_xyz[i, j] = origin_xyz[i, j] - com
        rotated_xyz = rigid_rotate(origin_xyz, angle, angle, angle)
        recovered_xyz = np.copy(rotated_xyz)
        for i in range(np.shape(origin_xyz)[0]):
            for j in range(14):
                recovered_xyz[i, j] = rotated_xyz[i, j] + com
        recovered_xyz = torch.tensor(recovered_xyz)
        rotated_coord_dict[f'rotated_block_{block}_original'] = block_xyz
        rotated_coord_dict[f'rotated_block_{block}_rotated'] = recovered_xyz
        for i in range(len(idxs)):
            xyz_out[idxs[i]] = recovered_xyz[i]
    return xyz_out, rotated_coord_dict

def rigid_rotate(xyz, a=180, b=180, c=180):
    #TODO fix this to make it truly uniform
    a = (a / 180) * math.pi
    b = (b / 180) * math.pi
    c = (c / 180) * math.pi
    alpha = random.uniform(-a, a)
    beta = random.uniform(-b, b)
    gamma = random.uniform(-c, c)
    rotated = []
    for i in range(np.shape(xyz)[0]):
        for j in range(14):
            try:
                x = xyz[i, j, 0]
                y = xyz[i, j, 1]
                z = xyz[i, j, 2]
                x2 = x * math.cos(alpha) - y * math.sin(alpha)
                y2 = x * math.sin(alpha) + y * math.cos(alpha)
                x3 = x2 * math.cos(beta) - z * math.sin(beta)
                z2 = x2 * math.sin(beta) + z * math.cos(beta)
                y3 = y2 * math.cos(gamma) - z2 * math.sin(gamma)
                z3 = y2 * math.sin(gamma) + z2 * math.cos(gamma)
                rotated.append([x3, y3, z3])
            except Exception:  # missing atoms fall through as NaN coordinates
                rotated.append([float('nan'), float('nan'), float('nan')])
    rotated = np.array(rotated)
    rotated = np.reshape(rotated, [np.shape(xyz)[0], 14, 3])
    return rotated
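
# Illustrative sketch (not in the original module): applying a random rigid rotation
# of up to +/-15 degrees per axis to a 5-residue, 14-atom coordinate array.
def _example_rigid_rotate():
    xyz = np.random.rand(5, 14, 3)
    rotated = rigid_rotate(xyz, a=15, b=15, c=15)
    assert rotated.shape == (5, 14, 3)
    return rotated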

######## from old pred_util.py
def find_contigs(mask):
    """
    Find contiguous regions in a mask that are True with no False in between

    Parameters:
        mask (torch.tensor or np.array, required): 1D boolean array
    Returns:
        contigs (list): List of tuples, each containing the beginning (inclusive) and
        end (exclusive) indices of a contiguous True region
    """
    assert len(mask.shape) == 1  # 1D tensor of bools
    contigs = []
    found_contig = False
    for i, b in enumerate(mask):

        if b and not found_contig:    # found the beginning of a contig
            contig = [i]
            found_contig = True

        elif b and found_contig:      # currently have contig, continuing it
            pass

        elif not b and found_contig:  # found the end, record previous index as end, reset indicator
            contig.append(i)
            found_contig = False
            contigs.append(tuple(contig))

        else:                         # currently don't have a contig, and didn't find one
            pass

    # fence post bug - check if the very last entry was True and we didn't get to finish
    if b:
        contig.append(i + 1)
        found_contig = False
        contigs.append(tuple(contig))

    return contigs
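
# Illustrative sketch (not in the original module): the mask below has True runs at
# indices 1-3 and 6, so find_contigs returns half-open (start, end) pairs for each run.
def _example_find_contigs():
    mask = np.array([False, True, True, True, False, False, True])
    assert find_contigs(mask) == [(1, 4), (6, 7)]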

def reindex_chains(pdb_idx):
    """
    Given a list of (chain, index) tuples, create a reordered indexing with a 500-residue
    index jump inserted at every chain break.

    Parameters:
        pdb_idx (list, required): List of tuples (chainID, index)
    """
    new_breaks, new_idx = [], []
    current_chain = None

    chain_and_idx_to_torch = {}

    for i, T in enumerate(pdb_idx):
        chain, idx = T

        if chain != current_chain:
            new_breaks.append(i)
            current_chain = chain

            # create new space for chain id listings
            chain_and_idx_to_torch[chain] = {}

        # map original pdb (chain, idx) pair to index in tensor
        chain_and_idx_to_torch[chain][idx] = i

        # append tensor index to list
        new_idx.append(i)

    new_idx = np.array(new_idx)

    # now we have ordered list and know where the chainbreaks are in the new order
    num_additions = 0
    for i in new_breaks[1:]:  # skip the first trivial one
        new_idx[np.where(new_idx == (i + num_additions * 500))[0][0]:] += 500
        num_additions += 1

    return new_idx, chain_and_idx_to_torch, new_breaks[1:]

class ObjectView(object):
    '''
    Easy wrapper to access dictionary values with "dot" notation instead
    '''
    def __init__(self, d):
        self.__dict__ = d

def split_templates(xyz_t, t1d, multi_templates, mappings, multi_tmpl_conf=None):
    templates = multi_templates.split(":")
    if multi_tmpl_conf is not None:
        multi_tmpl_conf = [float(i) for i in multi_tmpl_conf.split(",")]
        assert len(templates) == len(multi_tmpl_conf), "Number of templates must equal number of confidences specified in --multi_tmpl_conf flag"
    for idx, template in enumerate(templates):
        parts = template.split(",")
        template_mask = torch.zeros(xyz_t.shape[2]).bool()
        for part in parts:
            start = int(part.split("-")[0][1:])
            end = int(part.split("-")[1]) + 1
            chain = part[0]
            for i in range(start, end):
                try:
                    ref_pos = mappings['complex_con_ref_pdb_idx'].index((chain, i))
                    hal_pos_0 = mappings['complex_con_hal_idx0'][ref_pos]
                except (KeyError, ValueError):  # fall back to the single-chain mapping
                    ref_pos = mappings['con_ref_pdb_idx'].index((chain, i))
                    hal_pos_0 = mappings['con_hal_idx0'][ref_pos]
                template_mask[hal_pos_0] = True

        xyz_t_temp = torch.clone(xyz_t)
        xyz_t_temp[:, :, ~template_mask, :, :] = float('nan')
        t1d_temp = torch.clone(t1d)
        t1d_temp[:, :, ~template_mask, :20] = 0
        t1d_temp[:, :, ~template_mask, 20] = 1
        if multi_tmpl_conf is not None:
            t1d_temp[:, :, template_mask, 21] = multi_tmpl_conf[idx]
        if idx != 0:
            xyz_t_out = torch.cat((xyz_t_out, xyz_t_temp), dim=1)
            t1d_out = torch.cat((t1d_out, t1d_temp), dim=1)
        else:
            xyz_t_out = xyz_t_temp
            t1d_out = t1d_temp
    return xyz_t_out, t1d_out

class ContigMap():
    '''
    New class for doing mapping.
    Supports multichain or multiple crops from a single receptor chain.
    Also supports indexing jump (+200) or not, based on contig input.
    Default chain outputs are inpainted chains as A (and B, C etc. if multiple chains), and all
    fragments of the receptor chain on the next one (generally B).
    Output chains can be specified; this must have the same number of elements as the contig string.
    '''
    def __init__(self, parsed_pdb, contigs=None, inpaint_seq=None, inpaint_str=None, length=None, ref_idx=None, hal_idx=None, idx_rf=None, inpaint_seq_tensor=None, inpaint_str_tensor=None, topo=False):
        # sanity checks
        if contigs is None and ref_idx is None:
            sys.exit("Must either specify a contig string or precise mapping")
        if idx_rf is not None or hal_idx is not None or ref_idx is not None:
            if idx_rf is None or hal_idx is None or ref_idx is None:
                sys.exit("If you're specifying specific contig mappings, the reference and output positions must be specified, AND the indexing for RoseTTAFold (idx_rf)")

        self.chain_order = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        if length is not None:
            if '-' not in length:
                self.length = [int(length), int(length)+1]
            else:
                self.length = [int(length.split("-")[0]), int(length.split("-")[1])+1]
        else:
            self.length = None
        self.ref_idx = ref_idx
        self.hal_idx = hal_idx
        self.idx_rf = idx_rf
        self.inpaint_seq = ','.join(inpaint_seq).split(",") if inpaint_seq is not None else None
        self.inpaint_str = ','.join(inpaint_str).split(",") if inpaint_str is not None else None
        self.inpaint_seq_tensor = inpaint_seq_tensor
        self.inpaint_str_tensor = inpaint_str_tensor
        self.parsed_pdb = parsed_pdb
        self.topo = topo
        if ref_idx is None:
            # using default contig generation, which outputs in rosetta-like format
            self.contigs = contigs
            self.sampled_mask, self.contig_length, self.n_inpaint_chains = self.get_sampled_mask()
            self.receptor_chain = self.chain_order[self.n_inpaint_chains]
            self.receptor, self.receptor_hal, self.receptor_rf, self.inpaint, self.inpaint_hal, self.inpaint_rf = self.expand_sampled_mask()
            self.ref = self.inpaint + self.receptor
            self.hal = self.inpaint_hal + self.receptor_hal
            self.rf = self.inpaint_rf + self.receptor_rf
        else:
            # specifying precise mappings
            self.ref = ref_idx
            self.hal = hal_idx
            self.rf = idx_rf
        self.mask_1d = [False if i == ('_', '_') else True for i in self.ref]

        # take care of sequence and structure masking
        if self.inpaint_seq_tensor is None:
            if self.inpaint_seq is not None:
                self.inpaint_seq = self.get_inpaint_seq_str(self.inpaint_seq)
            else:
                self.inpaint_seq = np.array([True if i != ('_', '_') else False for i in self.ref])
        else:
            self.inpaint_seq = self.inpaint_seq_tensor

        if self.inpaint_str_tensor is None:
            if self.inpaint_str is not None:
                self.inpaint_str = self.get_inpaint_seq_str(self.inpaint_str)
            else:
                self.inpaint_str = np.array([True if i != ('_', '_') else False for i in self.ref])
        else:
            self.inpaint_str = self.inpaint_str_tensor

        # get 0-indexed input/output (for trb file)
        self.ref_idx0, self.hal_idx0, self.ref_idx0_inpaint, self.hal_idx0_inpaint, self.ref_idx0_receptor, self.hal_idx0_receptor = self.get_idx0()

    def get_sampled_mask(self):
        '''
        Function to get a sampled mask from a contig.
        '''
        length_compatible = False
        count = 0
        while length_compatible is False:
            inpaint_chains = 0
            contig_list = self.contigs
            sampled_mask = []
            sampled_mask_length = 0
            # allow receptor chain to be last in contig string
            if all([i[0].isalpha() for i in contig_list[-1].split(",")]):
                contig_list[-1] = f'{contig_list[-1]},0'
            for con in contig_list:
                if (all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0') or self.topo is True:
                    # receptor chain
                    sampled_mask.append(con)
                else:
                    inpaint_chains += 1
                    # chain to be inpainted. These are the only chains that count towards the length of the contig
                    subcons = con.split(",")
                    subcon_out = []
                    for subcon in subcons:
                        if subcon[0].isalpha():
                            subcon_out.append(subcon)
                            if '-' in subcon:
                                sampled_mask_length += (int(subcon.split("-")[1]) - int(subcon.split("-")[0][1:]) + 1)
                            else:
                                sampled_mask_length += 1
                        else:
                            if '-' in subcon:
                                length_inpaint = random.randint(int(subcon.split("-")[0]), int(subcon.split("-")[1]))
                                subcon_out.append(f'{length_inpaint}-{length_inpaint}')
                                sampled_mask_length += length_inpaint
                            elif subcon == '0':
                                subcon_out.append('0')
                            else:
                                length_inpaint = int(subcon)
                                subcon_out.append(f'{length_inpaint}-{length_inpaint}')
                                sampled_mask_length += int(subcon)
                    sampled_mask.append(','.join(subcon_out))
            # check length is compatible
            if self.length is not None:
                if sampled_mask_length >= self.length[0] and sampled_mask_length < self.length[1]:
                    length_compatible = True
            else:
                length_compatible = True
            count += 1
            if count == 100000:  # contig string incompatible with this length
                sys.exit("Contig string incompatible with --length range")
        return sampled_mask, sampled_mask_length, inpaint_chains

    def expand_sampled_mask(self):
        chain_order = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        receptor = []
        inpaint = []
        receptor_hal = []
        inpaint_hal = []
        receptor_idx = 1
        inpaint_idx = 1
        inpaint_chain_idx = -1
        receptor_chain_break = []
        inpaint_chain_break = []
        for con in self.sampled_mask:
            if (all([i[0].isalpha() for i in con.split(",")[:-1]]) and con.split(",")[-1] == '0') or self.topo is True:
                # receptor chain
                subcons = con.split(",")[:-1]
                assert all([i[0] == subcons[0][0] for i in subcons]), "If specifying fragmented receptor in a single block of the contig string, they MUST derive from the same chain"
                assert all(int(subcons[i].split("-")[0][1:]) < int(subcons[i+1].split("-")[0][1:]) for i in range(len(subcons)-1)), "If specifying multiple fragments from the same chain, pdb indices must be in ascending order!"
                for idx, subcon in enumerate(subcons):
                    ref_to_add = [(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]), int(subcon.split("-")[1])+1)]
                    receptor.extend(ref_to_add)
                    receptor_hal.extend([(self.receptor_chain, i) for i in np.arange(receptor_idx, receptor_idx+len(ref_to_add))])
                    receptor_idx += len(ref_to_add)
                    if idx != len(subcons)-1:
                        idx_jump = int(subcons[idx+1].split("-")[0][1:]) - int(subcon.split("-")[1]) - 1
                        receptor_chain_break.append((receptor_idx-1, idx_jump))  # actual chain break in pdb chain
                    else:
                        receptor_chain_break.append((receptor_idx-1, 200))  # 200 aa chain break
            else:
                inpaint_chain_idx += 1
                for subcon in con.split(","):
                    if subcon[0].isalpha():
                        ref_to_add = [(subcon[0], i) for i in np.arange(int(subcon.split("-")[0][1:]), int(subcon.split("-")[1])+1)]
                        inpaint.extend(ref_to_add)
                        inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx, inpaint_idx+len(ref_to_add))])
                        inpaint_idx += len(ref_to_add)
                    else:
                        inpaint.extend([('_', '_')] * int(subcon.split("-")[0]))
                        inpaint_hal.extend([(chain_order[inpaint_chain_idx], i) for i in np.arange(inpaint_idx, inpaint_idx+int(subcon.split("-")[0]))])
                        inpaint_idx += int(subcon.split("-")[0])
                inpaint_chain_break.append((inpaint_idx-1, 200))

        if self.topo is True or inpaint_hal == []:
            receptor_hal = [(i[0], i[1]) for i in receptor_hal]
        else:
            receptor_hal = [(i[0], i[1] + inpaint_hal[-1][1]) for i in receptor_hal]  # rosetta-like numbering
        # get rf indexes, with chain breaks
        inpaint_rf = np.arange(0, len(inpaint))
        receptor_rf = np.arange(len(inpaint)+200, len(inpaint)+len(receptor)+200)
        for ch_break in inpaint_chain_break[:-1]:
            receptor_rf[:] += 200
            inpaint_rf[ch_break[0]:] += ch_break[1]
        for ch_break in receptor_chain_break[:-1]:
            receptor_rf[ch_break[0]:] += ch_break[1]

        return receptor, receptor_hal, receptor_rf.tolist(), inpaint, inpaint_hal, inpaint_rf.tolist()

    def get_inpaint_seq_str(self, inpaint_s):
        '''
        Function to generate inpaint_str or inpaint_seq masks specific to this contig.
        '''
        s_mask = np.copy(self.mask_1d)
        inpaint_s_list = []
        for i in inpaint_s:
            if '-' in i:
                inpaint_s_list.extend([(i[0], p) for p in range(int(i.split("-")[0][1:]), int(i.split("-")[1])+1)])
            else:
                inpaint_s_list.append((i[0], int(i[1:])))
        for res in inpaint_s_list:
            if res in self.ref:
                s_mask[self.ref.index(res)] = False  # mask this residue

        return np.array(s_mask)

    def get_idx0(self):
        ref_idx0 = []
        hal_idx0 = []
        ref_idx0_inpaint = []
        hal_idx0_inpaint = []
        ref_idx0_receptor = []
        hal_idx0_receptor = []
        for idx, val in enumerate(self.ref):
            if val != ('_', '_'):
                assert val in self.parsed_pdb['pdb_idx'], f"{val} is not in pdb file!"
                hal_idx0.append(idx)
                ref_idx0.append(self.parsed_pdb['pdb_idx'].index(val))
        for idx, val in enumerate(self.inpaint):
            if val != ('_', '_'):
                hal_idx0_inpaint.append(idx)
                ref_idx0_inpaint.append(self.parsed_pdb['pdb_idx'].index(val))
        for idx, val in enumerate(self.receptor):
            if val != ('_', '_'):
                hal_idx0_receptor.append(idx)
                ref_idx0_receptor.append(self.parsed_pdb['pdb_idx'].index(val))

        return ref_idx0, hal_idx0, ref_idx0_inpaint, hal_idx0_inpaint, ref_idx0_receptor, hal_idx0_receptor

def get_mappings(rm):
    mappings = {}
    mappings['con_ref_pdb_idx'] = [i for i in rm.inpaint if i != ('_', '_')]
    mappings['con_hal_pdb_idx'] = [rm.inpaint_hal[i] for i in range(len(rm.inpaint_hal)) if rm.inpaint[i] != ("_", "_")]
    mappings['con_ref_idx0'] = rm.ref_idx0_inpaint
    mappings['con_hal_idx0'] = rm.hal_idx0_inpaint
    if rm.inpaint != rm.ref:
        mappings['complex_con_ref_pdb_idx'] = [i for i in rm.ref if i != ("_", "_")]
        mappings['complex_con_hal_pdb_idx'] = [rm.hal[i] for i in range(len(rm.hal)) if rm.ref[i] != ("_", "_")]
        mappings['receptor_con_ref_pdb_idx'] = [i for i in rm.receptor if i != ("_", "_")]
        mappings['receptor_con_hal_pdb_idx'] = [rm.receptor_hal[i] for i in range(len(rm.receptor_hal)) if rm.receptor[i] != ("_", "_")]
        mappings['complex_con_ref_idx0'] = rm.ref_idx0
        mappings['complex_con_hal_idx0'] = rm.hal_idx0
        mappings['receptor_con_ref_idx0'] = rm.ref_idx0_receptor
        mappings['receptor_con_hal_idx0'] = rm.hal_idx0_receptor
    mappings['inpaint_str'] = rm.inpaint_str
    mappings['inpaint_seq'] = rm.inpaint_seq
    mappings['sampled_mask'] = rm.sampled_mask
    mappings['mask_1d'] = rm.mask_1d
    return mappings
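
# Illustrative sketch (not in the original module): building a ContigMap from a contig
# string and reading the resulting mappings. The dummy parsed_pdb only provides the
# 'pdb_idx' field ContigMap needs here; the contig "A1-5,10-10" keeps residues A1-A5
# and inpaints a 10-residue segment after them.
def _example_contig_map():
    parsed_pdb = {'pdb_idx': [('A', i) for i in range(1, 6)]}
    cm = ContigMap(parsed_pdb, contigs=['A1-5,10-10'])
    mappings = get_mappings(cm)
    # the five kept residues map onto the first five positions of the new chain A
    assert mappings['con_hal_idx0'] == [0, 1, 2, 3, 4]
    return mappings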

def lddt_unbin(pred_lddt):
    nbin = pred_lddt.shape[1]
    bin_step = 1.0 / nbin
    lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
    pred_lddt = nn.Softmax(dim=1)(pred_lddt)
    return torch.sum(lddt_bins[None, :, None] * pred_lddt, dim=1)
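
# Illustrative sketch (not in the original module): converting random per-residue lDDT
# logits over 50 bins into scalar lDDT values in (0, 1] for a length-128 example.
def _example_lddt_unbin():
    pred_lddt_logits = torch.randn(1, 50, 128)   # (batch, n_bins, L)
    plddt = lddt_unbin(pred_lddt_logits)
    assert plddt.shape == (1, 128)
    return plddt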