Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from rdkit import Chem, Geometry | |
| from src import const | |
def create_conformer(coords):
    """Return an RDKit conformer with one 3D position per entry of ``coords``.

    ``coords`` is iterated in order; every entry must unpack into exactly
    three values (x, y, z). The entry's index in the iteration is used as
    the atom index of the position.
    """
    conf = Chem.Conformer()
    for atom_idx, xyz in enumerate(coords):
        px, py, pz = xyz
        conf.SetAtomPosition(atom_idx, Geometry.Point3D(px, py, pz))
    return conf
def build_molecules(one_hot, x, node_mask, is_geom, margins=const.MARGINS_EDM):
    """Build one molecule per sample of a batch.

    For each index ``i`` in ``range(len(one_hot))`` the nodes whose squeezed
    ``node_mask[i]`` entry equals 1 are kept, their atom type is taken as the
    argmax over dim 1 of ``one_hot[i]``, and the result is passed with the
    matching rows of ``x[i]`` to ``build_molecule``.

    NOTE(review): presumably ``one_hot`` is (batch, nodes, atom_types),
    ``x`` is (batch, nodes, 3) and ``node_mask`` is (batch, nodes, 1) --
    confirm against the caller.

    Returns:
        A list with the molecule built for every sample, in batch order.
    """

    def _molecule_for_sample(sample_idx):
        # Boolean selector for the real (unmasked) nodes of this sample.
        keep = node_mask[sample_idx].squeeze() == 1
        types = one_hot[sample_idx][keep].argmax(dim=1).detach().cpu()
        coords = x[sample_idx][keep].detach().cpu()
        return build_molecule(coords, types, is_geom, margins=margins)

    return [_molecule_for_sample(sample_idx) for sample_idx in range(len(one_hot))]
def build_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """Assemble an RDKit ``RWMol`` from atom positions and atom type indices.

    Bonds are inferred by ``build_xae_molecule``; every non-zero entry of the
    returned adjacency matrix becomes one bond whose type is looked up in
    ``const.BOND_DICT``. The positions are attached as a single conformer.

    Args:
        positions: tensor of atom coordinates (one row per atom).
        atom_types: tensor of atom type indices (one entry per atom).
        is_geom: selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``
            for translating type indices to element symbols.
        margins: bond-length margins forwarded to ``build_xae_molecule``.

    Returns:
        The populated ``Chem.RWMol``.
    """
    symbol_table = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    X, A, E = build_xae_molecule(positions, atom_types, is_geom=is_geom, margins=margins)

    rw_mol = Chem.RWMol()
    for type_idx in (entry.item() for entry in X):
        rw_mol.AddAtom(Chem.Atom(symbol_table[type_idx]))

    # Each row of nonzero(A) is an (i, j) index pair of a bonded atom pair.
    for begin, end in torch.nonzero(A).tolist():
        bond_type = const.BOND_DICT[E[begin, end].item()]
        rw_mol.AddBond(begin, end, bond_type)

    coords = positions.detach().cpu().numpy().astype(np.float64)
    rw_mol.AddConformer(create_conformer(coords))
    return rw_mol
def build_xae_molecule(positions, atom_types, is_geom, margins=const.MARGINS_EDM):
    """Infer bonds from pairwise distances and return the triplet (X, A, E).

    Args:
        positions: N x 3 (already masked to keep the final number of nodes)
        atom_types: N
        is_geom: selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``
            for translating type indices to element symbols.
        margins: bond-length margins forwarded to ``get_bond_order``.

    Returns:
        X: N (int), the ``atom_types`` argument returned unchanged
        A: N x N (bool), binary adjacency matrix
        E: N x N (int), bond type, 0 if no bond, such that A = E.bool()

    Only entries (i, j) with j < i are ever written, so A and E are
    lower-triangular.
    """
    num_atoms = positions.shape[0]
    adjacency = torch.zeros((num_atoms, num_atoms), dtype=torch.bool)
    bond_types = torch.zeros((num_atoms, num_atoms), dtype=torch.int)
    symbol_table = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    # cdist expects a batch dimension; add it, then drop it from the result.
    batched = positions.unsqueeze(0)
    pairwise = torch.cdist(batched, batched, p=2).squeeze(0)

    for row in range(num_atoms):
        for col in range(row):
            # The pair is ordered by type index before the symbol lookup.
            low, high = sorted([atom_types[row], atom_types[col]])
            order = get_bond_order(
                symbol_table[low.item()],
                symbol_table[high.item()],
                pairwise[row, col],
                margins=margins,
            )
            # TODO: a batched version of get_bond_order to avoid the for loop
            if order <= 0:
                continue
            # Warning: the graph should be DIRECTED
            adjacency[row, col] = 1
            bond_types[row, col] = order
    return atom_types, adjacency, bond_types
def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM):
    """Classify the bond between two atoms from their distance.

    The distance is multiplied by 100 and compared against the reference
    lengths in ``const.BONDS_1`` / ``BONDS_2`` / ``BONDS_3`` plus the matching
    entry of ``margins``.

    Args:
        atom1, atom2: keys into the ``const.BONDS_*`` tables.
        distance: inter-atomic distance before the x100 rescaling
            (presumably Angstrom, with tables in picometres -- confirm).
        check_exists: when true, return 0 for pairs missing from
            ``const.BONDS_1`` instead of failing on the lookup.
        margins: indexable with the margins for single, double and triple
            bonds at positions 0, 1 and 2.

    Returns:
        3 for a triple bond, 2 for double, 1 for single, 0 for no bond.
    """
    scaled = 100 * distance  # We change the metric

    # Check exists for large molecules where some atom pairs do not have a
    # typical bond length.
    if check_exists and (atom1 not in const.BONDS_1 or atom2 not in const.BONDS_1[atom1]):
        return 0  # No bond

    # margin1, margin2 and margin3 have been tuned to maximize the stability
    # of the QM9 true samples
    if not (scaled < const.BONDS_1[atom1][atom2] + margins[0]):
        return 0  # No bond

    # From here on the pair is at least single-bonded; try to upgrade it.
    if not (atom1 in const.BONDS_2 and atom2 in const.BONDS_2[atom1]):
        return 1  # Single
    if not (scaled < const.BONDS_2[atom1][atom2] + margins[1]):
        return 1  # Single

    if not (atom1 in const.BONDS_3 and atom2 in const.BONDS_3[atom1]):
        return 2  # Double
    if scaled < const.BONDS_3[atom1][atom2] + margins[2]:
        return 3  # Triple
    return 2  # Double