Spaces:
Runtime error
Runtime error
Upload graph_decoder/visualize_utils.py with huggingface_hub
Browse files
graph_decoder/visualize_utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from rdkit import Chem
|
| 3 |
+
from rdkit.Chem import Draw, AllChem
|
| 4 |
+
from rdkit.Geometry import Point3D
|
| 5 |
+
from rdkit import RDLogger
|
| 6 |
+
import numpy as np
|
| 7 |
+
import rdkit.Chem
|
| 8 |
+
|
| 9 |
+
class MolecularVisualization:
|
| 10 |
+
def __init__(self, atom_decoder):
|
| 11 |
+
self.atom_decoder = atom_decoder
|
| 12 |
+
|
| 13 |
+
def mol_from_graphs(self, node_list, adjacency_matrix):
|
| 14 |
+
"""
|
| 15 |
+
Convert graphs to rdkit molecules
|
| 16 |
+
node_list: the nodes of a batch of nodes (bs x n)
|
| 17 |
+
adjacency_matrix: the adjacency_matrix of the molecule (bs x n x n)
|
| 18 |
+
"""
|
| 19 |
+
# dictionary to map integer value to the char of atom
|
| 20 |
+
atom_decoder = self.atom_decoder
|
| 21 |
+
|
| 22 |
+
# create empty editable mol object
|
| 23 |
+
mol = Chem.RWMol()
|
| 24 |
+
|
| 25 |
+
# add atoms to mol and keep track of index
|
| 26 |
+
node_to_idx = {}
|
| 27 |
+
for i in range(len(node_list)):
|
| 28 |
+
if node_list[i] == -1:
|
| 29 |
+
continue
|
| 30 |
+
a = Chem.Atom(atom_decoder[int(node_list[i])])
|
| 31 |
+
molIdx = mol.AddAtom(a)
|
| 32 |
+
node_to_idx[i] = molIdx
|
| 33 |
+
|
| 34 |
+
for ix, row in enumerate(adjacency_matrix):
|
| 35 |
+
for iy, bond in enumerate(row):
|
| 36 |
+
# only traverse half the symmetric matrix
|
| 37 |
+
if iy <= ix:
|
| 38 |
+
continue
|
| 39 |
+
if bond == 1:
|
| 40 |
+
bond_type = Chem.rdchem.BondType.SINGLE
|
| 41 |
+
elif bond == 2:
|
| 42 |
+
bond_type = Chem.rdchem.BondType.DOUBLE
|
| 43 |
+
elif bond == 3:
|
| 44 |
+
bond_type = Chem.rdchem.BondType.TRIPLE
|
| 45 |
+
elif bond == 4:
|
| 46 |
+
bond_type = Chem.rdchem.BondType.AROMATIC
|
| 47 |
+
else:
|
| 48 |
+
continue
|
| 49 |
+
mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
mol = mol.GetMol()
|
| 53 |
+
except rdkit.Chem.KekulizeException:
|
| 54 |
+
print("Can't kekulize molecule")
|
| 55 |
+
mol = None
|
| 56 |
+
return mol
|
| 57 |
+
|
| 58 |
+
def visualize_chain(self, nodes_list, adjacency_matrix):
|
| 59 |
+
RDLogger.DisableLog('rdApp.*')
|
| 60 |
+
# convert graphs to the rdkit molecules
|
| 61 |
+
mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i]) for i in range(nodes_list.shape[0])]
|
| 62 |
+
|
| 63 |
+
# find the coordinates of atoms in the final molecule
|
| 64 |
+
final_molecule = mols[-1]
|
| 65 |
+
AllChem.Compute2DCoords(final_molecule)
|
| 66 |
+
|
| 67 |
+
coords = []
|
| 68 |
+
for i, atom in enumerate(final_molecule.GetAtoms()):
|
| 69 |
+
positions = final_molecule.GetConformer().GetAtomPosition(i)
|
| 70 |
+
coords.append((positions.x, positions.y, positions.z))
|
| 71 |
+
|
| 72 |
+
# align all the molecules
|
| 73 |
+
for i, mol in enumerate(mols):
|
| 74 |
+
AllChem.Compute2DCoords(mol)
|
| 75 |
+
conf = mol.GetConformer()
|
| 76 |
+
for j, atom in enumerate(mol.GetAtoms()):
|
| 77 |
+
x, y, z = coords[j]
|
| 78 |
+
conf.SetAtomPosition(j, Point3D(x, y, z))
|
| 79 |
+
|
| 80 |
+
# create list of molecule images
|
| 81 |
+
mol_images = []
|
| 82 |
+
for frame, mol in enumerate(mols):
|
| 83 |
+
img = Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
|
| 84 |
+
mol_images.append(img)
|
| 85 |
+
|
| 86 |
+
return mol_images
|