Spaces:
Running
Running
Update trainer.py
Browse files- trainer.py +12 -12
trainer.py
CHANGED
|
@@ -422,7 +422,7 @@ class Trainer(object):
|
|
| 422 |
|
| 423 |
''' Loading the atom and bond decoders'''
|
| 424 |
|
| 425 |
-
with open("
|
| 426 |
|
| 427 |
return pickle.load(f)
|
| 428 |
|
|
@@ -430,7 +430,7 @@ class Trainer(object):
|
|
| 430 |
|
| 431 |
''' Loading the atom and bond decoders'''
|
| 432 |
|
| 433 |
-
with open("
|
| 434 |
|
| 435 |
return pickle.load(f)
|
| 436 |
|
|
@@ -531,15 +531,15 @@ class Trainer(object):
|
|
| 531 |
|
| 532 |
|
| 533 |
# protein data
|
| 534 |
-
full_smiles = [line for line in open("
|
| 535 |
-
drug_smiles = [line for line in open("
|
| 536 |
|
| 537 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 538 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 539 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 540 |
|
| 541 |
-
akt1_human_adj = torch.load("
|
| 542 |
-
akt1_human_annot = torch.load("
|
| 543 |
|
| 544 |
if self.resume:
|
| 545 |
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
|
|
@@ -733,14 +733,14 @@ class Trainer(object):
|
|
| 733 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 734 |
|
| 735 |
|
| 736 |
-
drug_smiles = [line for line in open("
|
| 737 |
|
| 738 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 739 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 740 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 741 |
|
| 742 |
-
akt1_human_adj = torch.load("
|
| 743 |
-
akt1_human_annot = torch.load("
|
| 744 |
|
| 745 |
self.G.eval()
|
| 746 |
#self.D.eval()
|
|
@@ -782,8 +782,8 @@ class Trainer(object):
|
|
| 782 |
#metric_calc_mol = []
|
| 783 |
metric_calc_dr = []
|
| 784 |
date = time.time()
|
| 785 |
-
if not os.path.exists("
|
| 786 |
-
os.makedirs("
|
| 787 |
with torch.inference_mode():
|
| 788 |
|
| 789 |
dataloader_iterator = iter(self.inf_drugs_loader)
|
|
@@ -893,7 +893,7 @@ class Trainer(object):
|
|
| 893 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
| 894 |
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
| 895 |
|
| 896 |
-
with open("
|
| 897 |
for molecules in inference_drugs:
|
| 898 |
|
| 899 |
f.write(molecules)
|
|
|
|
| 422 |
|
| 423 |
''' Loading the atom and bond decoders'''
|
| 424 |
|
| 425 |
+
with open("data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
|
| 426 |
|
| 427 |
return pickle.load(f)
|
| 428 |
|
|
|
|
| 430 |
|
| 431 |
''' Loading the atom and bond decoders'''
|
| 432 |
|
| 433 |
+
with open("data/decoders/" + dictionary_name +"_" + self.drugs_name +'.pkl', 'rb') as f:
|
| 434 |
|
| 435 |
return pickle.load(f)
|
| 436 |
|
|
|
|
| 531 |
|
| 532 |
|
| 533 |
# protein data
|
| 534 |
+
full_smiles = [line for line in open("data/chembl_train.smi", 'r').read().splitlines()]
|
| 535 |
+
drug_smiles = [line for line in open("data/akt_train.smi", 'r').read().splitlines()]
|
| 536 |
|
| 537 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 538 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 539 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 540 |
|
| 541 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 542 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 543 |
|
| 544 |
if self.resume:
|
| 545 |
self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
|
|
|
|
| 733 |
self.G2.load_state_dict(torch.load(G2_path, map_location=lambda storage, loc: storage))
|
| 734 |
|
| 735 |
|
| 736 |
+
drug_smiles = [line for line in open("data/akt_test.smi", 'r').read().splitlines()]
|
| 737 |
|
| 738 |
drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
|
| 739 |
drug_scaf = [MurckoScaffold.GetScaffoldForMol(x) for x in drug_mols]
|
| 740 |
fps_r = [Chem.RDKFingerprint(x) for x in drug_scaf]
|
| 741 |
|
| 742 |
+
akt1_human_adj = torch.load("data/akt/AKT1_human_adj.pt").reshape(1,-1).to(self.device).float()
|
| 743 |
+
akt1_human_annot = torch.load("data/akt/AKT1_human_annot.pt").reshape(1,-1).to(self.device).float()
|
| 744 |
|
| 745 |
self.G.eval()
|
| 746 |
#self.D.eval()
|
|
|
|
| 782 |
#metric_calc_mol = []
|
| 783 |
metric_calc_dr = []
|
| 784 |
date = time.time()
|
| 785 |
+
if not os.path.exists("experiments/inference/{}".format(self.submodel)):
|
| 786 |
+
os.makedirs("experiments/inference/{}".format(self.submodel))
|
| 787 |
with torch.inference_mode():
|
| 788 |
|
| 789 |
dataloader_iterator = iter(self.inf_drugs_loader)
|
|
|
|
| 893 |
inference_drugs = [Chem.MolToSmiles(line) for line in fake_mol_g if line is not None]
|
| 894 |
inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
|
| 895 |
|
| 896 |
+
with open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "a") as f:
|
| 897 |
for molecules in inference_drugs:
|
| 898 |
|
| 899 |
f.write(molecules)
|