File size: 8,879 Bytes
82d55c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by: [email protected]
# Description: evaluate RPcontact predictions against ground-truth distance labels
import glob
import os
import pickle
import random
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import pandas as pd
import torch
from Bio import SeqIO
from sklearn.preprocessing import OneHotEncoder
import numpy as np
from predict import check_path, one_hot_encode, get_bin_pred, doSavePredict
def get_bin_label(df_label, distance_cutoff):
    """Binarize a distance matrix into integer contact labels.

    Entries strictly below ``distance_cutoff`` become 1, all others 0.
    The result has the same container type as the comparison result
    (whatever ``df_label < distance_cutoff`` returns), cast with ``astype(int)``.
    """
    return (df_label < distance_cutoff).astype(int)
def view_evaluate_contact_prob(df_label, bin_pred, ax=None, markersize=5):
    """Plot predicted contacts against ground-truth contacts at 8, 5 and 3.5 Å.

    Args:
        df_label: 2-D matrix of distances (DataFrame or array); NaN entries
            are drawn as "Missing in PDB".
        bin_pred: binarized prediction with the same shape as ``df_label``;
            expected to hold 0/1 values. It is compared position by position,
            so index/column labels on either input are ignored.
        ax: object to draw on. When ``None`` the pyplot state machine
            (``plt``) is used; otherwise an Axes-like object providing
            ``set_xlim``, ``set_ylim``, ``set_title`` and ``plot``.
        markersize: marker size passed to every ``plot`` call.

    Returns:
        A tuple ``(summary, confusing_matrix)``. ``summary`` is the string
        ``"<TP at 3.5Å>/<TP at 5Å>/<TP at 8Å>"``. ``confusing_matrix`` has the
        shape of ``df_label`` and holds 0 by default, 1 for ground-truth
        contacts at 8 Å, and 2 where a predicted contact is also a
        ground-truth contact.

    Raises:
        ValueError: if ``bin_pred`` and ``df_label`` differ in shape.
    """
    # (cutoff in Å, ground-truth colour, true-positive edge colour),
    # ordered from the loosest to the strictest cutoff so that stricter
    # levels are drawn on top of looser ones.
    levels = [
        (8, '#b0d9db', '#ecbbd8'),
        (5, '#61b3b6', '#9d4e7d'),
        (3.5, 'k', 'r'),
    ]
    false_positive_color = '#f5e0c4'

    # Work on plain arrays so every comparison is positional; pandas
    # arithmetic would otherwise align on labels.
    distances = np.asarray(df_label, dtype=float)
    pred = np.asarray(bin_pred).astype(int)
    if pred.shape != distances.shape:
        raise ValueError(
            f'bin_pred shape {pred.shape} does not match df_label shape {distances.shape}'
        )

    confusing_matrix = np.zeros_like(df_label)
    r, p = confusing_matrix.shape
    if ax is None:
        ax = plt
        ax.xlim([-2, p + 2])
        ax.ylim([-2, r + 2])
    else:
        ax.set_xlim([-2, p + 2])
        ax.set_ylim([-2, r + 2])
        ax.set_title('performance')

    # NaN distances compare as False; silence the warning some NumPy
    # versions emit for such comparisons.
    with np.errstate(invalid='ignore'):
        contacts = [(distances < cutoff).astype(int) for cutoff, _, _ in levels]

    # Masks are transposed before np.where so that x is the column index
    # and y is the row index of the matrix.
    loosest = contacts[0]
    ax.plot(*np.where((pred - loosest).T == 1), ".", c=false_positive_color,
            markersize=markersize, label='False Positive')
    ax.plot(*np.where(np.isnan(distances).T), ".", c='gray',
            markersize=markersize, label='Missing in PDB')
    confusing_matrix[loosest == 1] = 1  # ground truth at the loosest cutoff

    tps = []
    for (cutoff, gt_color, tp_color), contact in zip(levels, contacts):
        ax.plot(*np.where(contact.T == 1), ".", c=gt_color,
                markersize=markersize, label=f'Ground truth ({cutoff}Å)')
        hits = (contact + pred) == 2
        tps.append(int(np.count_nonzero(hits)))
        confusing_matrix[hits] = 2  # true positive
        tp_line = ax.plot(*np.where(hits.T), "o", c=tp_color,
                          markersize=markersize, label=f'True Positive ({cutoff}Å)')[0]
        # Fill with the ground-truth colour, outline with the TP colour.
        tp_line.set_markerfacecolor(gt_color)
        tp_line.set_markeredgecolor(tp_color)

    return '/'.join(str(count) for count in reversed(tps)), confusing_matrix
def seed_everything(seed=2022):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` into ``os.environ``.
    """
    print('seed_everything to ',seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Seeding makes each program run reproducible; successive draws within
    # one run still differ from each other.
    # https://blog.csdn.net/qq_42951560/article/details/112174334
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # Mini-batch lengths keep changing, so cuDNN autotuning would mostly
    # waste time.
    cudnn.benchmark = False
def _str2bool(value):
    """Parse a command-line token as a boolean (argparse ``type=`` callable)."""
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def getParam(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings; when ``None`` argparse
            reads ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``rootdir``, ``fasta``, ``out``,
        ``ffeat``, ``fmodel``, ``device``, ``flabel`` and ``draw``.
    """
    parser = ArgumentParser()
    # data
    string_options = (
        ('--rootdir', ''),
        ('--fasta', './example/inputs/8DMB_W.8DMB_P.fasta'),
        ('--out', './example/outputs/'),
        ('--ffeat', './example/inputs/{pdbid}.pickle'),
        ('--fmodel', './weight/model_roc_0_38=0.845.pt'),
        ('--device', 'cpu'),
        ('--flabel', './example/inputs/{pdbid}.pickle'),
    )
    for flag, default in string_options:
        parser.add_argument(flag, default=default, type=str)
    # ``type=bool`` would turn any non-empty string (including "False")
    # into True, so the flag is parsed explicitly.
    parser.add_argument('--draw', default=True, type=_str2bool)
    return parser.parse_args(argv)
def _load_pickle(path):
    """Return the object unpickled from ``path``."""
    # NOTE(security): pickle can execute arbitrary code while loading;
    # only point --ffeat / --flabel at trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _build_model_input(seq, emb, alpha, device):
    """Build one model input tensor from a sequence and its embedding.

    The one-hot encoding of ``seq`` and ``emb`` are concatenated along
    axis 1, a leading axis is added, the last two axes are swapped and the
    result is moved to ``device`` as float32.
    """
    features = np.concatenate([one_hot_encode(seq, alpha=alpha), emb], axis=1)
    tensor = torch.from_numpy(np.expand_dims(features, 0)).transpose(-1, -2)
    return tensor.to(device, dtype=torch.float)


def main():
    """Predict contact maps for every FASTA record and optionally plot them.

    Each record id and sequence is expected in the form ``<rna>.<protein>``.
    Predictions are written through ``doSavePredict``; when ``--draw`` is
    set, a comparison figure against the label file is saved per model as
    ``<out>/<pdbid>_<i>_evaluate.png``.

    Raises:
        FileNotFoundError: if no file matches ``--fmodel``.
    """
    args = getParam()
    out = args.out
    device = args.device
    check_path(out)
    # pdbid = fasta.rsplit('/',1)[0].split('.')[0]
    seed_everything(seed=2022)

    # Sorted so that the index used in the output file name is stable.
    model_paths = sorted(glob.glob(args.fmodel))
    if not model_paths:
        raise FileNotFoundError(f'no model file matches {args.fmodel!r}')
    models = [
        (model_path, torch.load(model_path, map_location=torch.device(device)))
        for model_path in model_paths
    ]
    print('loading existed model', args.fmodel)
    for _, model in models:
        model.eval()

    with torch.no_grad():
        for record in SeqIO.parse(args.fasta, 'fasta'):
            pdbid, seq = record.id, record.seq
            rnaid, proid = pdbid.split('.')
            rnaseq, proseq = seq.split('.')

            rna_emb = _load_pickle(args.ffeat.format_map({'pdbid': rnaid}))
            pro_emb = _load_pickle(args.ffeat.format_map({'pdbid': proid}))
            x_rna = _build_model_input(rnaseq, rna_emb, 'ACGU', device)
            x_pro = _build_model_input(proseq, pro_emb, 'GAVLIFWYDNEKQMSTCPHR', device)
            print('input data shape for rna and protein:', x_rna.shape, x_pro.shape)

            for i, (model_path, model) in enumerate(models):
                outputs = model(x_pro, x_rna)  # original author's note: [1, 299, 74, 1]
                outputs = torch.squeeze(outputs, -1).permute(0, 2, 1)
                pred = outputs[0].cpu().detach().numpy()
                des = f'predict by {__file__}\n#{model_path}'
                doSavePredict(pdbid, {'rna': rnaseq, 'protein': proseq}, pred, out, des)

                if not args.draw:
                    continue

                # Keep the top L = rows + columns scores as predicted contacts.
                df_pred = pd.DataFrame(pred)
                top = sum(pred.shape)
                threshold = df_pred.stack().nlargest(top).iloc[-1]
                df_label = _load_pickle(args.flabel.format_map({'pdbid': pdbid})).squeeze()
                bin_pred = get_bin_pred(df_pred, threshold=threshold)

                # One fresh figure per model, closed after saving, so plots
                # from different models never share a canvas or pile up in
                # memory.
                fig = plt.figure(figsize=(20, 15))
                view_evaluate_contact_prob(df_label, bin_pred, ax=None)
                plt.title(f'Predicted contact map of {pdbid}\nPredidcted by RPcontact, top L=r+p')
                plt.xlabel(proid)
                plt.ylabel(rnaid)
                handles, labels = plt.gca().get_legend_handles_labels()
                plt.legend(handles, labels, bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1,
                           borderaxespad=1, frameon=False)
                # Use the same scale on both axes.
                plt.gca().set_aspect('equal')
                plt.tight_layout()
                plt.savefig(f'{out}/{pdbid}_{i}_evaluate.png', dpi=900)
                plt.show()
                plt.close(fig)

            print(f'predict {pdbid} with {len(seq)} nts')


if __name__ == '__main__':
    main()
|