Spaces:
Runtime error
Runtime error
import argparse
import json
import multiprocessing
import os
import re

import pandas as pd
from rdkit import Chem
from rdkit import RDLogger
from tqdm import tqdm

from utils import *
# Raise RDKit's log threshold to CRITICAL for the whole process.
# NOTE(review): presumably this keeps parse warnings for invalid predicted
# SMILES out of the score output -- confirm.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
def extract_smiles(s):
    """Return the text between ``[START_SMILES]`` and ``[END_SMILES]``.

    Args:
        s: Raw string that may wrap a SMILES in the two marker tokens.

    Returns:
        The stripped text between the first start token and the first end
        token that follows it. If either token is missing, ``s`` is returned
        unchanged.
    """
    start_token = "[START_SMILES]"
    end_token = "[END_SMILES]"
    start_pos = s.find(start_token)
    if start_pos == -1:
        # Test the raw find() result: adding len(start_token) first would
        # turn the -1 "not found" sentinel into a valid-looking offset.
        return s
    content_start = start_pos + len(start_token)
    # Only an end token located after the start token can close the span.
    end_pos = s.find(end_token, content_start)
    if end_pos == -1:
        return s
    return s[content_start:end_pos].strip()
def canonicalize_smiles_clear_map(smiles, return_max_frag=True, sanitize=None):
    """Canonicalize a SMILES string after removing atom-map numbers.

    Args:
        smiles: SMILES string to parse; dot-separated fragments are allowed.
        return_max_frag: When true, also return the canonical SMILES of the
            fragment with the most atoms.
        sanitize: Passed to ``Chem.MolFromSmiles``. ``None`` (the default)
            keeps the historical behaviour of using ``not opt.synthon`` from
            the module-level ``opt``.

    Returns:
        ``(smi, max_frag_smi)`` when ``return_max_frag`` is true, otherwise
        ``smi``. Every string is ``''`` when parsing or SMILES generation
        fails.
    """
    if sanitize is None:
        # NOTE(review): ``opt`` is only bound when the file runs as a script;
        # worker processes that re-import the module (multiprocessing "spawn")
        # will not see it unless ``sanitize`` is passed explicitly.
        sanitize = not opt.synthon
    failure = ('', '') if return_max_frag else ''

    mol = Chem.MolFromSmiles(smiles, sanitize=sanitize)
    if mol is None:
        return failure

    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    try:
        smi = Chem.MolToSmiles(mol, isomericSmiles=False)
    except Exception:
        # SMILES generation failed: report it as an invalid prediction rather
        # than aborting the whole scoring run.
        return failure

    if not return_max_frag:
        return smi

    # Measure every fragment that parses on its own.
    fragments = []
    for frag_smi in smi.split("."):
        frag_mol = Chem.MolFromSmiles(frag_smi, sanitize=sanitize)
        if frag_mol is not None:
            fragments.append((frag_smi, len(frag_mol.GetAtoms())))
    if not fragments:
        return smi, ''

    # max() keeps the first of equally sized fragments, the same choice a
    # stable descending sort makes.
    largest = max(fragments, key=lambda frag: frag[1])[0]
    return smi, canonicalize_smiles_clear_map(
        largest, return_max_frag=False, sanitize=sanitize)
def compute_rank(input_smiles, prediction, raw=False, alpha=1.0):
    """Score the candidate predictions of one sample.

    Args:
        input_smiles: Canonical input SMILES; only its length is used, for
            the length penalties applied in non-raw mode.
        prediction: List with one entry per augmentation, each a list of
            hashable candidates whose element ``[0]`` is a SMILES string
            (``""`` marks an invalid candidate). Each inner list is replaced
            in place by its valid, de-duplicated candidates in their original
            order.
        raw: When true there must be exactly one augmentation, and the score
            is just the reciprocal-rank term.
        alpha: Weight of the position in the reciprocal-rank term
            ``1 / (alpha * k + 1)``.

    Returns:
        ``(rank, invalid_rates)``: ``rank`` maps each distinct valid
        candidate to its score, and ``invalid_rates[k]`` counts the
        augmentations whose k-th candidate was invalid.
    """
    # An empty prediction has no positions to count.
    invalid_rates = [0] * len(prediction[0]) if prediction else []
    rank = {}
    highest = {}

    if raw:
        # no test augmentation
        assert len(prediction) == 1

    for j, beam in enumerate(prediction):
        for k, candidate in enumerate(beam):
            if candidate[0] == "":
                invalid_rates[k] += 1
        # Error detection and de-duplication: drop invalid candidates and
        # keep the first occurrence of each remaining one, in order.
        unique_valid = list(dict.fromkeys(c for c in beam if c[0] != ""))
        prediction[j] = unique_valid
        for k, data in enumerate(unique_valid):
            score = 1 / (alpha * k + 1)
            if raw:
                rank[data] = score
            else:
                rank[data] = rank.get(data, 0) + score
                highest[data] = min(k, highest.get(data, k))

    if not raw:
        # NOTE(review): the penalties are applied only in non-raw mode
        # because ``highest`` is only filled there -- confirm against the
        # original layout.
        for key in rank:
            rank[key] += highest[key] * -1
            rank[key] += abs(len(key[0]) - len(input_smiles)) * -0.2
            rank[key] += len(key[0]) * -0.2
    return rank, invalid_rates
def read_dataset(opt):
    """Load the JSON-lines prediction file named by ``opt.path``.

    Each non-blank line is parsed as a JSON object. The code reads the keys
    ``input``, ``targets`` and ``predictions`` plus an index key (``ds_idx``
    when the first record has it, otherwise ``index``).

    Args:
        opt: Options object; ``path``, ``raw`` and ``augmentation`` are read.

    Returns:
        ``(input_list, gt_list, pred_list)``, one entry per distinct record
        index, ordered by that index.

    Raises:
        ValueError: If the file contains no records.
    """
    print(f'Reading {opt.path}...')
    with open(opt.path, 'r', encoding='utf-8') as f:
        # Iterate the file directly and skip blank lines, which json.loads
        # would reject.
        test_tgt = [json.loads(line) for line in f if line.strip()]
    if not test_tgt:
        raise ValueError(f'No samples found in {opt.path}')

    if opt.raw:
        # Raw mode keeps every ``augmentation``-th record.
        test_tgt = test_tgt[::opt.augmentation]

    # Keep the first record seen for each index, then restore index order.
    idx_key = 'ds_idx' if 'ds_idx' in test_tgt[0] else 'index'
    filtered_tgt = {}
    for dic in test_tgt:
        filtered_tgt.setdefault(dic[idx_key], dic)
    test_tgt = sorted(filtered_tgt.values(), key=lambda x: x[idx_key])
    print(f'{len(test_tgt)} samples read.')

    input_list = [extract_smiles(i['input']) for i in test_tgt]
    # Strip the marker tokens and turn space-separated molecules into a
    # dot-separated SMILES.
    gt_list = [
        i['targets']
        .replace('[START_SMILES]', '')
        .replace('[END_SMILES]', '')
        .replace('SPL1T-TH1S-Pl3A5E', '')
        .strip()
        .replace(' ', '.')
        for i in test_tgt
    ]
    pred_list = [
        [smi.strip().replace(' ', '.') for smi in i['predictions']]
        for i in test_tgt
    ]
    return input_list, gt_list, pred_list
def main(opt):
    """Score the predictions in ``opt.path`` and print top-k statistics.

    Steps: load the dataset, canonicalize inputs, predictions and targets in
    worker pools, rank the candidates of every sample with ``compute_rank``,
    then print top-k accuracy, max-fragment accuracy, invalid-SMILES rates
    and the unique-candidate rate. With ``opt.detailed`` it also prints
    per-sample details and breakdowns by atom-count difference, chirality
    and ring change, and writes them to ``detailed_results.csv``.

    Args:
        opt: Parsed command-line options. ``opt.augmentation`` is overwritten
            with 1 when ``opt.raw`` is set.

    Raises:
        ValueError: If there is no complete sample to score.
    """
    input_list, gt_list, pred_list = read_dataset(opt)
    if opt.raw:
        opt.augmentation = 1
    print('Reading predictions from file ...')

    # inputs
    print("Input Length", len(gt_list))
    ras_src_smiles = input_list[::opt.augmentation]
    with multiprocessing.Pool(processes=opt.process_number) as pool:
        ras_src_smiles = pool.map(func=canonicalize_smiles_clear_map, iterable=ras_src_smiles)
    ras_src_smiles = [i[0] for i in ras_src_smiles]

    # predictions
    print("Prediction Length", len(pred_list))
    pred_lines = [i.split('>')[0] for d in pred_list for i in d]
    group_size = opt.augmentation * opt.beam_size
    data_size = len(pred_lines) // group_size if opt.length == -1 else opt.length
    if data_size <= 0:
        # Fail early with a clear message instead of an IndexError or
        # ZeroDivisionError further down.
        raise ValueError(
            'No complete sample to score: {} prediction lines for '
            'augmentation={} and beam_size={}'.format(
                len(pred_lines), opt.augmentation, opt.beam_size))
    pred_lines = pred_lines[:data_size * group_size]
    print("Canonicalizing predictions using Process Number ", opt.process_number)
    with multiprocessing.Pool(processes=opt.process_number) as pool:
        raw_predictions = pool.map(func=canonicalize_smiles_clear_map, iterable=pred_lines)
    # data_len x augmentation x beam_size
    predictions = [[[] for _ in range(opt.augmentation)] for _ in range(data_size)]
    for i, line in enumerate(raw_predictions):
        sample_idx, offset = divmod(i, group_size)
        predictions[sample_idx][offset // opt.beam_size].append(line)

    # ground truth
    print("Origin Length", len(gt_list))
    targets = [''.join(gt_list[i].strip().split(' '))
               for i in tqdm(range(0, data_size * opt.augmentation, opt.augmentation))]
    with multiprocessing.Pool(processes=opt.process_number) as pool:
        targets = pool.map(func=canonicalize_smiles_clear_map, iterable=targets)
    print("predictions Length", len(predictions), len(predictions[0]), len(predictions[0][0]))
    print("Target Length", len(targets))
    ground_truth = targets
    print("Origin Target Lentgh, ", len(ground_truth))
    print("Cutted Length, ", data_size)
    print('\n')

    accuracy = [0] * opt.n_best
    topn_accuracy_chirality = [0] * opt.n_best
    topn_accuracy_wochirality = [0] * opt.n_best
    topn_accuracy_ringopening = [0] * opt.n_best
    topn_accuracy_ringformation = [0] * opt.n_best
    topn_accuracy_woring = [0] * opt.n_best
    total_chirality = 0
    total_ringopening = 0
    total_ringformation = 0
    atomsize_topk = []
    accurate_indices = [[] for _ in range(opt.n_best)]
    max_frag_accuracy = [0] * opt.n_best
    # The report below indexes these lists by rank up to n_best, so size
    # them for whichever of beam_size / n_best is larger.
    rate_slots = max(opt.beam_size, opt.n_best)
    invalid_rates = [0] * rate_slots
    sorted_invalid_rates = [0] * rate_slots
    unique_rates = 0
    ranked_results = []

    def credit_from(counter, first_hit):
        # A hit at rank ``first_hit`` counts for every top-k with k > first_hit.
        for k in range(first_hit, opt.n_best):
            counter[k] += 1

    for i in tqdm(range(len(predictions))):
        accurate_flag = False
        # Atom-count difference target - input; stays None when either
        # molecule cannot be parsed, so no stale value is reused.
        size = None
        if opt.detailed:
            chirality_flag = False
            ringopening_flag = False
            ringformation_flag = False
            pro_mol = Chem.MolFromSmiles(ras_src_smiles[i])
            rea_mol = Chem.MolFromSmiles(ground_truth[i][0])
            if pro_mol is not None and rea_mol is not None:
                pro_ringnum = pro_mol.GetRingInfo().NumRings()
                rea_ringnum = rea_mol.GetRingInfo().NumRings()
                size = len(rea_mol.GetAtoms()) - len(pro_mol.GetAtoms())
                if "@" in ras_src_smiles[i] or "@" in ground_truth[i][0]:
                    total_chirality += 1
                    chirality_flag = True
                if pro_ringnum < rea_ringnum:
                    total_ringopening += 1
                    ringopening_flag = True
                if pro_ringnum > rea_ringnum:
                    total_ringformation += 1
                    ringformation_flag = True

        rank, invalid_rate = compute_rank(
            ras_src_smiles[i], predictions[i], raw=opt.raw, alpha=opt.score_alpha)
        if opt.detailed:
            inputs = input_list[i * opt.augmentation:(i + 1) * opt.augmentation]
            gt = gt_list[i * opt.augmentation:(i + 1) * opt.augmentation]
            rank_ = {k[0]: v for k, v in sorted(rank.items(), key=lambda item: item[1], reverse=True)}
            print('Index', i)
            print('inputs', json.dumps(inputs, indent=4))
            print('targets', json.dumps(gt, indent=4))
            print('input', ras_src_smiles[i])
            print('target', targets[i][0])
            print('rank', json.dumps(rank_, indent=4))
            print('invalid_rate', json.dumps(invalid_rate, indent=4))
            print('\n')

        for j in range(opt.beam_size):
            invalid_rates[j] += invalid_rate[j]
        # Best score first, truncated to the n_best candidates.
        rank = sorted(rank.items(), key=lambda x: x[1], reverse=True)[:opt.n_best]
        ranked_results.append([item[0][0] for item in rank])

        # Exact match on the full canonical SMILES; only the first hit counts.
        for j, item in enumerate(rank):
            if item[0][0] != ground_truth[i][0]:
                continue
            accurate_flag = True
            accurate_indices[j].append(i)
            credit_from(accuracy, j)
            if opt.detailed:
                if size is not None:
                    atomsize_topk.append((size, j))
                if chirality_flag:
                    credit_from(topn_accuracy_chirality, j)
                else:
                    credit_from(topn_accuracy_wochirality, j)
                if ringopening_flag:
                    credit_from(topn_accuracy_ringopening, j)
                if ringformation_flag:
                    credit_from(topn_accuracy_ringformation, j)
                if not ringopening_flag and not ringformation_flag:
                    credit_from(topn_accuracy_woring, j)
            break
        if opt.detailed and not accurate_flag and size is not None:
            # Rank n_best marks "not found within the top n_best".
            atomsize_topk.append((size, opt.n_best))

        # Match on the largest fragment only.
        for j, item in enumerate(rank):
            if item[0][1] == ground_truth[i][1]:
                credit_from(max_frag_accuracy, j)
                break

        # Rank slots left empty after validity filtering and de-duplication.
        for j in range(len(rank), opt.beam_size):
            sorted_invalid_rates[j] += 1
        unique_rates += len(rank)

    for i in range(opt.n_best):
        if i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 19, 49]:
            print("Top-{} Acc:{:.3f}%, MaxFrag {:.3f}%,".format(
                      i + 1, accuracy[i] / data_size * 100, max_frag_accuracy[i] / data_size * 100),
                  " Invalid SMILES:{:.3f}% Sorted Invalid SMILES:{:.3f}%".format(
                      invalid_rates[i] / data_size / opt.augmentation * 100,
                      sorted_invalid_rates[i] / data_size / opt.augmentation * 100))
    # One-line top-1/3/5/10 summary, limited to the ranks actually evaluated.
    print(' '.join([f'{accuracy[i] / data_size * 100:.3f}' for i in [0, 2, 4, 9] if i < opt.n_best]))
    print("Unique Rates:{:.3f}%".format(unique_rates / data_size / opt.beam_size * 100))

    if opt.detailed:
        print_topk = [1, 3, 5, 10]
        save_dict = {}

        def report_size_bucket(differ, hits, total):
            # Print and record the top-k accuracy of one atom-difference bucket.
            for j in range(opt.n_best):
                if (j + 1) in print_topk:
                    bucket_acc = hits[j] / total * 100
                    save_dict[f'top-{j+1}_size_{differ}'] = bucket_acc
                    print("Top-{} Atom differ size {} Acc:{:.3f}%, Number:{:.3f}%".format(
                        j + 1, differ, bucket_acc, total / data_size * 100))

        atomsize_topk.sort(key=lambda x: x[0])
        if atomsize_topk:
            differ_now = atomsize_topk[0][0]
            topn_accuracy_bydiffer = [0] * opt.n_best
            total_bydiffer = 0
            for item in atomsize_topk:
                # Differences of 11 or more are all reported as one bucket.
                if differ_now < 11 and differ_now != item[0]:
                    report_size_bucket(differ_now, topn_accuracy_bydiffer, total_bydiffer)
                    total_bydiffer = 0
                    topn_accuracy_bydiffer = [0] * opt.n_best
                    differ_now = item[0]
                for k in range(item[1], opt.n_best):
                    topn_accuracy_bydiffer[k] += 1
                total_bydiffer += 1
            report_size_bucket(differ_now, topn_accuracy_bydiffer, total_bydiffer)

        # Every denominator below is checked so an empty category is skipped
        # instead of raising ZeroDivisionError.
        total_wochirality = data_size - total_chirality
        total_woring = data_size - total_ringopening - total_ringformation
        for i in range(opt.n_best):
            if (i + 1) in print_topk:
                if total_chirality > 0:
                    print("Top-{} Accuracy with chirality:{:.3f}%".format(i + 1, topn_accuracy_chirality[i] / total_chirality * 100))
                    save_dict[f'top-{i+1}_chilarity'] = topn_accuracy_chirality[i] / total_chirality * 100
                if total_wochirality > 0:
                    print("Top-{} Accuracy without chirality:{:.3f}%".format(i + 1, topn_accuracy_wochirality[i] / total_wochirality * 100))
                    save_dict[f'top-{i+1}_wochilarity'] = topn_accuracy_wochirality[i] / total_wochirality * 100
                if total_ringopening > 0:
                    print("Top-{} Accuracy ring Opening:{:.3f}%".format(i + 1, topn_accuracy_ringopening[i] / total_ringopening * 100))
                    save_dict[f'top-{i+1}_ringopening'] = topn_accuracy_ringopening[i] / total_ringopening * 100
                if total_ringformation > 0:
                    print("Top-{} Accuracy ring Formation:{:.3f}%".format(i + 1, topn_accuracy_ringformation[i] / total_ringformation * 100))
                    save_dict[f'top-{i+1}_ringformation'] = topn_accuracy_ringformation[i] / total_ringformation * 100
                if total_woring > 0:
                    print("Top-{} Accuracy without ring:{:.3f}%".format(i + 1, topn_accuracy_woring[i] / total_woring * 100))
                    save_dict[f'top-{i+1}_wocring'] = topn_accuracy_woring[i] / total_woring * 100
        print(total_chirality)
        print(total_ringformation)
        print(total_ringopening)
        df = pd.DataFrame(save_dict, index=[0])
        df.to_csv("detailed_results.csv")

    if opt.save_accurate_indices != "":
        # Only the samples whose top-1 candidate is correct are written.
        with open(opt.save_accurate_indices, "w") as f:
            for index in accurate_indices[0]:
                f.write(str(index))
                f.write("\n")

    if opt.save_file != "":
        with open(opt.save_file, "w") as f:
            for res in ranked_results:
                for smi in res:
                    f.write(smi)
                    f.write("\n")
                # Pad with blank lines so every sample occupies n_best lines.
                f.write("\n" * (opt.n_best - len(res)))
if __name__ == "__main__":
    # Script entry point. ``opt`` is bound at module level on purpose:
    # canonicalize_smiles_clear_map reads it as a global.
    parser = argparse.ArgumentParser(
        description='score.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Ranking sizes and the input file.
    parser.add_argument('--beam_size', default=10, type=int, help='Beam size')
    parser.add_argument('--n_best', default=10, type=int, help='n best')
    parser.add_argument(
        '--path',
        required=True,
        type=str,
        help="Path to file containing the predictions and ground truth.",
    )

    # Scoring and runtime knobs.
    parser.add_argument('--augmentation', default=20, type=int)
    parser.add_argument('--score_alpha', default=1.0, type=float)
    parser.add_argument('--length', default=-1, type=int)
    parser.add_argument('--process_number', default=multiprocessing.cpu_count(), type=int)

    # Boolean switches, all off unless given.
    for switch in ('--synthon', '--detailed', '--raw'):
        parser.add_argument(switch, action="store_true", default=False)

    # Optional output files; an empty string disables the output.
    for output_option in ('--save_file', '--save_accurate_indices'):
        parser.add_argument(output_option, default="", type=str)

    opt = parser.parse_args()
    print(opt)
    main(opt)