Spaces:
Running
Running
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -5,7 +5,7 @@ from rdkit.Chem import AllChem
|
|
| 5 |
from rdkit.Chem import Draw
|
| 6 |
import os
|
| 7 |
import numpy as np
|
| 8 |
-
import seaborn as sns
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
from matplotlib.lines import Line2D
|
| 11 |
from rdkit import RDLogger
|
|
@@ -46,6 +46,7 @@ class Metrics(object):
|
|
| 46 |
|
| 47 |
return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
|
| 48 |
|
|
|
|
| 49 |
def sim_reward(mol_gen, fps_r):
|
| 50 |
|
| 51 |
gen_scaf = []
|
|
@@ -152,6 +153,7 @@ def sample_z_edge(batch_size, vertexes, edges):
|
|
| 152 |
|
| 153 |
return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
|
| 154 |
|
|
|
|
| 155 |
def sample_z( batch_size, z_dim):
|
| 156 |
|
| 157 |
''' Random noise. '''
|
|
@@ -176,10 +178,7 @@ def mol_sample(sample_directory, model_name, mol, edges, nodes, idx, i):
|
|
| 176 |
print("Valid matrices and smiles are saved")
|
| 177 |
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path):
|
| 183 |
|
| 184 |
gen_smiles = []
|
| 185 |
for line in mols:
|
|
@@ -222,20 +221,20 @@ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, sav
|
|
| 222 |
#m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
|
| 223 |
#m0.update(m1)
|
| 224 |
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
|
| 227 |
#m0 = {k: np.array(v).mean() for k, v in m0.items()}
|
| 228 |
#loss.update(m0)
|
| 229 |
loss.update({'Valid': valid})
|
| 230 |
-
loss.update({'Unique
|
| 231 |
loss.update({'Novel': novel})
|
| 232 |
#loss.update({'QED': statistics.mean(qed)})
|
| 233 |
#loss.update({'SA': statistics.mean(sa)})
|
| 234 |
#loss.update({'LogP': statistics.mean(logp)})
|
| 235 |
#loss.update({'IntDiv': IntDiv})
|
| 236 |
|
| 237 |
-
#wandb.log({"maxlen": maxlen})
|
| 238 |
-
|
| 239 |
for tag, value in loss.items():
|
| 240 |
|
| 241 |
log += ", {}: {:.4f}".format(tag, value)
|
|
@@ -246,24 +245,23 @@ def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, sav
|
|
| 246 |
print("\n")
|
| 247 |
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
attentions_pos =
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
axes[i].
|
| 264 |
-
|
| 265 |
-
pltsavedir
|
| 266 |
-
plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
| 267 |
|
| 268 |
|
| 269 |
def plot_grad_flow(named_parameters, model, iter, epoch):
|
|
@@ -298,36 +296,8 @@ def plot_grad_flow(named_parameters, model, iter, epoch):
|
|
| 298 |
Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
|
| 299 |
pltsavedir = "/home/atabey/gradients/tryout"
|
| 300 |
plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
| 301 |
-
|
| 302 |
-
"""
|
| 303 |
-
def _genDegree():
|
| 304 |
|
| 305 |
-
|
| 306 |
-
dataset is used.
|
| 307 |
-
Can be called without arguments and saves the tensor for later use. If tensor was created
|
| 308 |
-
before, it just loads the degree tensor.
|
| 309 |
-
'''
|
| 310 |
-
|
| 311 |
-
degree_path = os.path.join(self.degree_dir, self.dataset_name + '-degree.pt')
|
| 312 |
-
if not os.path.exists(degree_path):
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
max_degree = -1
|
| 316 |
-
for data in self.dataset:
|
| 317 |
-
d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
|
| 318 |
-
max_degree = max(max_degree, int(d.max()))
|
| 319 |
-
|
| 320 |
-
# Compute the in-degree histogram tensor
|
| 321 |
-
deg = torch.zeros(max_degree + 1, dtype=torch.long)
|
| 322 |
-
for data in self.dataset:
|
| 323 |
-
d = geoutils.degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
|
| 324 |
-
deg += torch.bincount(d, minlength=deg.numel())
|
| 325 |
-
torch.save(deg, 'DrugGEN/data/' + self.dataset_name + '-degree.pt')
|
| 326 |
-
else:
|
| 327 |
-
deg = torch.load(degree_path, map_location=lambda storage, loc: storage)
|
| 328 |
-
|
| 329 |
-
return deg
|
| 330 |
-
"""
|
| 331 |
def get_mol(smiles_or_mol):
|
| 332 |
'''
|
| 333 |
Loads SMILES/molecule into RDKit's object
|
|
@@ -345,6 +315,7 @@ def get_mol(smiles_or_mol):
|
|
| 345 |
return mol
|
| 346 |
return smiles_or_mol
|
| 347 |
|
|
|
|
| 348 |
def mapper(n_jobs):
|
| 349 |
'''
|
| 350 |
Returns function for map call.
|
|
@@ -369,6 +340,8 @@ def mapper(n_jobs):
|
|
| 369 |
|
| 370 |
return _mapper
|
| 371 |
return n_jobs.map
|
|
|
|
|
|
|
| 372 |
def remove_invalid(gen, canonize=True, n_jobs=1):
|
| 373 |
"""
|
| 374 |
Removes invalid molecules from the dataset
|
|
@@ -378,6 +351,8 @@ def remove_invalid(gen, canonize=True, n_jobs=1):
|
|
| 378 |
return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
|
| 379 |
return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
|
| 380 |
x is not None]
|
|
|
|
|
|
|
| 381 |
def fraction_valid(gen, n_jobs=1):
|
| 382 |
"""
|
| 383 |
Computes a number of valid molecules
|
|
@@ -387,11 +362,15 @@ def fraction_valid(gen, n_jobs=1):
|
|
| 387 |
"""
|
| 388 |
gen = mapper(n_jobs)(get_mol, gen)
|
| 389 |
return 1 - gen.count(None) / len(gen)
|
|
|
|
|
|
|
| 390 |
def canonic_smiles(smiles_or_mol):
|
| 391 |
mol = get_mol(smiles_or_mol)
|
| 392 |
if mol is None:
|
| 393 |
return None
|
| 394 |
return Chem.MolToSmiles(mol)
|
|
|
|
|
|
|
| 395 |
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
| 396 |
"""
|
| 397 |
Computes a number of unique molecules
|
|
@@ -410,9 +389,11 @@ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
|
| 410 |
gen = gen[:k]
|
| 411 |
canonic = set(mapper(n_jobs)(canonic_smiles, gen))
|
| 412 |
if None in canonic and check_validity:
|
| 413 |
-
|
|
|
|
| 414 |
return 0 if len(gen) == 0 else len(canonic) / len(gen)
|
| 415 |
|
|
|
|
| 416 |
def novelty(gen, train, n_jobs=1):
|
| 417 |
gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
|
| 418 |
gen_smiles_set = set(gen_smiles) - {None}
|
|
@@ -420,7 +401,6 @@ def novelty(gen, train, n_jobs=1):
|
|
| 420 |
return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
| 421 |
|
| 422 |
|
| 423 |
-
|
| 424 |
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
| 425 |
batch_size=5000, agg='max',
|
| 426 |
device='cpu', p=1):
|
|
|
|
| 5 |
from rdkit.Chem import Draw
|
| 6 |
import os
|
| 7 |
import numpy as np
|
| 8 |
+
#import seaborn as sns
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
from matplotlib.lines import Line2D
|
| 11 |
from rdkit import RDLogger
|
|
|
|
| 46 |
|
| 47 |
return (np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)/max_len).mean()
|
| 48 |
|
| 49 |
+
|
| 50 |
def sim_reward(mol_gen, fps_r):
|
| 51 |
|
| 52 |
gen_scaf = []
|
|
|
|
| 153 |
|
| 154 |
return np.random.normal(0,1, size=(batch_size, vertexes, vertexes, edges)) # 128, 9, 9, 5
|
| 155 |
|
| 156 |
+
|
| 157 |
def sample_z( batch_size, z_dim):
|
| 158 |
|
| 159 |
''' Random noise. '''
|
|
|
|
| 178 |
print("Valid matrices and smiles are saved")
|
| 179 |
|
| 180 |
|
| 181 |
+
def logging(log_path, start_time, mols, train_smiles, i,idx, loss,model_num, save_path, get_maxlen=False):
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
gen_smiles = []
|
| 184 |
for line in mols:
|
|
|
|
| 221 |
#m1 =all_scores_chem(fake_mol, mols, vert, norm=True)
|
| 222 |
#m0.update(m1)
|
| 223 |
|
| 224 |
+
if get_maxlen:
|
| 225 |
+
maxlen = Metrics.max_component(mols, 45)
|
| 226 |
+
loss.update({"MaxLen": maxlen})
|
| 227 |
|
| 228 |
#m0 = {k: np.array(v).mean() for k, v in m0.items()}
|
| 229 |
#loss.update(m0)
|
| 230 |
loss.update({'Valid': valid})
|
| 231 |
+
loss.update({'Unique': unique})
|
| 232 |
loss.update({'Novel': novel})
|
| 233 |
#loss.update({'QED': statistics.mean(qed)})
|
| 234 |
#loss.update({'SA': statistics.mean(sa)})
|
| 235 |
#loss.update({'LogP': statistics.mean(logp)})
|
| 236 |
#loss.update({'IntDiv': IntDiv})
|
| 237 |
|
|
|
|
|
|
|
| 238 |
for tag, value in loss.items():
|
| 239 |
|
| 240 |
log += ", {}: {:.4f}".format(tag, value)
|
|
|
|
| 245 |
print("\n")
|
| 246 |
|
| 247 |
|
| 248 |
+
#def plot_attn(dataset_name, heads,attn_w, model, iter, epoch):
|
| 249 |
+
#
|
| 250 |
+
# cols = 4
|
| 251 |
+
# rows = int(heads/cols)
|
| 252 |
+
#
|
| 253 |
+
# fig, axes = plt.subplots( rows,cols, figsize = (30, 14))
|
| 254 |
+
# axes = axes.flat
|
| 255 |
+
# attentions_pos = attn_w[0]
|
| 256 |
+
# attentions_pos = attentions_pos.cpu().detach().numpy()
|
| 257 |
+
# for i,att in enumerate(attentions_pos):
|
| 258 |
+
#
|
| 259 |
+
# #im = axes[i].imshow(att, cmap='gray')
|
| 260 |
+
# sns.heatmap(att,vmin = 0, vmax = 1,ax = axes[i])
|
| 261 |
+
# axes[i].set_title(f'head - {i} ')
|
| 262 |
+
# axes[i].set_ylabel('layers')
|
| 263 |
+
# pltsavedir = "/home/atabey/attn/second"
|
| 264 |
+
# plt.savefig(os.path.join(pltsavedir, "attn" + model + "_" + dataset_name + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
|
|
|
| 265 |
|
| 266 |
|
| 267 |
def plot_grad_flow(named_parameters, model, iter, epoch):
|
|
|
|
| 296 |
Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
|
| 297 |
pltsavedir = "/home/atabey/gradients/tryout"
|
| 298 |
plt.savefig(os.path.join(pltsavedir, "weights_" + model + "_" + str(iter) + "_" + str(epoch) + ".png"), dpi= 500,bbox_inches='tight')
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
def get_mol(smiles_or_mol):
|
| 302 |
'''
|
| 303 |
Loads SMILES/molecule into RDKit's object
|
|
|
|
| 315 |
return mol
|
| 316 |
return smiles_or_mol
|
| 317 |
|
| 318 |
+
|
| 319 |
def mapper(n_jobs):
|
| 320 |
'''
|
| 321 |
Returns function for map call.
|
|
|
|
| 340 |
|
| 341 |
return _mapper
|
| 342 |
return n_jobs.map
|
| 343 |
+
|
| 344 |
+
|
| 345 |
def remove_invalid(gen, canonize=True, n_jobs=1):
|
| 346 |
"""
|
| 347 |
Removes invalid molecules from the dataset
|
|
|
|
| 351 |
return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
|
| 352 |
return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
|
| 353 |
x is not None]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
def fraction_valid(gen, n_jobs=1):
|
| 357 |
"""
|
| 358 |
Computes a number of valid molecules
|
|
|
|
| 362 |
"""
|
| 363 |
gen = mapper(n_jobs)(get_mol, gen)
|
| 364 |
return 1 - gen.count(None) / len(gen)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
def canonic_smiles(smiles_or_mol):
|
| 368 |
mol = get_mol(smiles_or_mol)
|
| 369 |
if mol is None:
|
| 370 |
return None
|
| 371 |
return Chem.MolToSmiles(mol)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
| 375 |
"""
|
| 376 |
Computes a number of unique molecules
|
|
|
|
| 389 |
gen = gen[:k]
|
| 390 |
canonic = set(mapper(n_jobs)(canonic_smiles, gen))
|
| 391 |
if None in canonic and check_validity:
|
| 392 |
+
canonic = [i for i in canonic if i is not None]
|
| 393 |
+
#raise ValueError("Invalid molecule passed to unique@k")
|
| 394 |
return 0 if len(gen) == 0 else len(canonic) / len(gen)
|
| 395 |
|
| 396 |
+
|
| 397 |
def novelty(gen, train, n_jobs=1):
|
| 398 |
gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
|
| 399 |
gen_smiles_set = set(gen_smiles) - {None}
|
|
|
|
| 401 |
return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
| 402 |
|
| 403 |
|
|
|
|
| 404 |
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
| 405 |
batch_size=5000, agg='max',
|
| 406 |
device='cpu', p=1):
|